diff --git a/NEWS.md b/NEWS.md index 5aafe8fd..adc417a9 100644 --- a/NEWS.md +++ b/NEWS.md @@ -3,11 +3,13 @@ ## New Features * Added support for parametric treatment effect term in BCF [#309](https://github.com/StochasticTree/stochtree/pull/309/) +* Added support for observation-level weights passed as data arguments to BART and BCF [#333](https://github.com/StochasticTree/stochtree/pull/333) ## Bug Fixes * Fixed multi-chain BCF bugs with the parametric intercept term in R and Python [#326](https://github.com/StochasticTree/stochtree/pull/326) * Fixed indexing bugs for multivariate treatment BCF in Python [#326](https://github.com/StochasticTree/stochtree/pull/326) +* Convert binary factor-valued treatments to 0/1 binary numeric treatment in `bcf()` R function [#332](https://github.com/StochasticTree/stochtree/pull/332) ## Documentation and Other Maintenance diff --git a/R/bart.R b/R/bart.R index 4ceb6ad5..63aa06da 100644 --- a/R/bart.R +++ b/R/bart.R @@ -88,6 +88,11 @@ NULL #' We do not currently support (but plan to in the near future), test set evaluation for group labels #' that were not in the training set. #' @param rfx_basis_test (Optional) Test set basis for "random-slope" regression in additive random effects model. +#' @param observation_weights (Optional) Numeric vector of observation weights of length `nrow(X_train)`. Weights are +#' applied as `y_i | - ~ N(mu(X_i), sigma^2 / w_i)`, so larger weights increase an observation's influence on the fit. +#' All weights must be non-negative. Default: `NULL` (all observations equally weighted). Compatible with Gaussian +#' (continuous/identity) and probit outcome models; not compatible with cloglog link functions. Note: these are +#' referred to internally in the C++ layer as "variance weights" (`var_weights`), since they scale the residual variance. #' @param num_gfr Number of "warm-start" iterations run using the grow-from-root algorithm (He and Hahn, 2021). Default: 5. 
#' @param num_burnin Number of "burn-in" iterations of the MCMC sampler. Default: 0. #' @param num_mcmc Number of "retained" iterations of the MCMC sampler. Default: 100. @@ -189,6 +194,7 @@ bart <- function( leaf_basis_test = NULL, rfx_group_ids_test = NULL, rfx_basis_test = NULL, + observation_weights = NULL, num_gfr = 5, num_burnin = 0, num_mcmc = 100, @@ -507,6 +513,20 @@ bart <- function( include_mean_forest = FALSE } + # observation_weights compatibility checks + if (!is.null(observation_weights)) { + if (link_is_cloglog) { + stop( + "observation_weights are not compatible with cloglog link functions." + ) + } + if (include_variance_forest) { + warning( + "Results may be unreliable when observation_weights are deployed alongside a variance forest model." + ) + } + } + # Set the variance forest priors if not set if (include_variance_forest) { if (is.null(a_forest)) { @@ -531,6 +551,26 @@ bart <- function( stop("variable_weights cannot have any negative weights") } + # Observation weight validation + if (!is.null(observation_weights)) { + if (!is.numeric(observation_weights)) { + stop("observation_weights must be a numeric vector") + } + if (length(observation_weights) != nrow(X_train)) { + stop("length(observation_weights) must equal nrow(X_train)") + } + if (any(observation_weights < 0)) { + stop("observation_weights cannot have any negative values") + } + if (all(observation_weights == 0) && num_gfr > 0) { + stop( + "observation_weights are all zero (prior sampling mode) but num_gfr > 0. ", + "GFR warm-start is data-dependent and ill-defined with zero weights. ", + "Set num_gfr = 0 when using all-zero observation_weights." 
+ ) + } + } + # Check covariates are matrix or dataframe if ((!is.data.frame(X_train)) && (!is.matrix(X_train))) { stop("X_train must be a matrix or dataframe") @@ -1217,13 +1257,20 @@ bart <- function( # Data if (leaf_regression) { - forest_dataset_train <- createForestDataset(X_train, leaf_basis_train) + forest_dataset_train <- createForestDataset( + X_train, + leaf_basis_train, + observation_weights + ) if (has_test) { forest_dataset_test <- createForestDataset(X_test, leaf_basis_test) } requires_basis <- TRUE } else { - forest_dataset_train <- createForestDataset(X_train) + forest_dataset_train <- createForestDataset( + X_train, + variance_weights = observation_weights + ) if (has_test) { forest_dataset_test <- createForestDataset(X_test) } @@ -4460,7 +4507,10 @@ createBARTModelFromCombinedJsonString <- function(json_string_list) { "outcome", "outcome_model" ) - outcome_model_link <- json_object_default$get_string("link", "outcome_model") + outcome_model_link <- json_object_default$get_string( + "link", + "outcome_model" + ) } else { outcome_model_outcome <- "continuous" outcome_model_link <- "identity" diff --git a/R/bcf.R b/R/bcf.R index 125e2994..52bddee0 100644 --- a/R/bcf.R +++ b/R/bcf.R @@ -87,6 +87,11 @@ NULL #' We do not currently support (but plan to in the near future), test set evaluation for group labels #' that were not in the training set. #' @param rfx_basis_test (Optional) Test set basis for "random-slope" regression in additive random effects model. +#' @param observation_weights (Optional) Numeric vector of observation weights of length `nrow(X_train)`. Weights are +#' applied as `y_i | - ~ N(mu(X_i), sigma^2 / w_i)`, so larger weights increase an observation's influence on the fit. +#' All weights must be non-negative. Default: `NULL` (all observations equally weighted). Applied to both the +#' prognostic and treatment effect forests. 
Compatible with Gaussian (continuous/identity) and probit outcome models; +#' not compatible with cloglog link functions. #' @param num_gfr Number of "warm-start" iterations run using the grow-from-root algorithm (He and Hahn, 2021). Default: 5. #' @param num_burnin Number of "burn-in" iterations of the MCMC sampler. Default: 0. #' @param num_mcmc Number of "retained" iterations of the MCMC sampler. Default: 100. @@ -237,6 +242,7 @@ bcf <- function( propensity_test = NULL, rfx_group_ids_test = NULL, rfx_basis_test = NULL, + observation_weights = NULL, num_gfr = 5, num_burnin = 0, num_mcmc = 100, @@ -616,6 +622,36 @@ bcf <- function( include_variance_forest = FALSE } + # observation_weights validation and compatibility checks + if (!is.null(observation_weights)) { + if (!is.numeric(observation_weights)) { + stop("observation_weights must be a numeric vector") + } + if (length(observation_weights) != nrow(X_train)) { + stop("length(observation_weights) must equal nrow(X_train)") + } + if (any(observation_weights < 0)) { + stop("observation_weights cannot have any negative values") + } + if (all(observation_weights == 0) && num_gfr > 0) { + stop( + "observation_weights are all zero (prior sampling mode) but num_gfr > 0. ", + "GFR warm-start is data-dependent and ill-defined with zero weights. ", + "Set num_gfr = 0 when using all-zero observation_weights." + ) + } + if (link_is_cloglog) { + stop( + "observation_weights are not compatible with cloglog link functions." + ) + } + if (include_variance_forest) { + warning( + "Results may be unreliable when observation_weights are deployed alongside a variance forest model." 
+ ) + } + } + # Set the variance forest priors if not set if (include_variance_forest) { if (is.null(a_forest)) { @@ -933,7 +969,10 @@ bcf <- function( } message( "Z_train is a factor; converting to 0/1 using level order: ", - lvls[1], " = 0, ", lvls[2], " = 1" + lvls[1], + " = 0, ", + lvls[2], + " = 1" ) Z_train <- as.integer(Z_train) - 1L } @@ -944,7 +983,10 @@ bcf <- function( } message( "Z_test is a factor; converting to 0/1 using level order: ", - lvls[1], " = 0, ", lvls[2], " = 1" + lvls[1], + " = 0, ", + lvls[2], + " = 1" ) Z_test <- as.integer(Z_test) - 1L } @@ -1681,7 +1723,11 @@ bcf <- function( } # Data - forest_dataset_train <- createForestDataset(X_train, tau_basis_train) + forest_dataset_train <- createForestDataset( + X_train, + tau_basis_train, + observation_weights + ) if (has_test) { forest_dataset_test <- createForestDataset(X_test, tau_basis_test) } @@ -2360,7 +2406,9 @@ bcf <- function( if (sample_tau_0 && !is.null(previous_tau_0_samples)) { tau_0_old <- tau_0 # previous model stores tau_0 in original scale; convert to standardized scale - tau_0 <- as.numeric(previous_tau_0_samples[, warmstart_index] / previous_y_scale) + tau_0 <- as.numeric( + previous_tau_0_samples[, warmstart_index] / previous_y_scale + ) Z_basis_ws <- as.matrix(tau_basis_train) outcome_train$subtract_vector( as.numeric(Z_basis_ws %*% matrix(tau_0 - tau_0_old, ncol = 1)) @@ -3456,7 +3504,10 @@ predict.bcfmodel <- function( } warning( "Z is a factor; recoding to 0/1 using level order: ", - lvls[1], " = 0, ", lvls[2], " = 1" + lvls[1], + " = 0, ", + lvls[2], + " = 1" ) Z <- as.integer(Z) - 1L } @@ -4820,7 +4871,9 @@ createBCFModelFromJson <- function(json_object) { # Version inference and presence-check helpers .ver <- inferStochtreeJsonVersion(json_object) - has_field <- function(name) json_contains_field_cpp(json_object$json_ptr, name) + has_field <- function(name) { + json_contains_field_cpp(json_object$json_ptr, name) + } has_subfolder_field <- function(subfolder, name) { 
json_contains_field_subfolder_cpp(json_object$json_ptr, subfolder, name) } @@ -5036,11 +5089,23 @@ createBCFModelFromJson <- function(json_object) { } if (model_params[["adaptive_coding"]]) { if (has_subfolder_field("parameters", "b1_samples")) { - output[["b_1_samples"]] <- json_object$get_vector("b1_samples", "parameters") - output[["b_0_samples"]] <- json_object$get_vector("b0_samples", "parameters") + output[["b_1_samples"]] <- json_object$get_vector( + "b1_samples", + "parameters" + ) + output[["b_0_samples"]] <- json_object$get_vector( + "b0_samples", + "parameters" + ) } else { - output[["b_1_samples"]] <- json_object$get_vector("b_1_samples", "parameters") - output[["b_0_samples"]] <- json_object$get_vector("b_0_samples", "parameters") + output[["b_1_samples"]] <- json_object$get_vector( + "b_1_samples", + "parameters" + ) + output[["b_0_samples"]] <- json_object$get_vector( + "b_0_samples", + "parameters" + ) warning(sprintf( "JSON fields 'b_1_samples'/'b_0_samples' are deprecated; please re-save the model to use 'b1_samples'/'b0_samples' (inferred version: %s).", .ver @@ -5133,9 +5198,15 @@ createBCFModelFromCombinedJson <- function(json_object_list) { # Version inference and presence-check helpers .ver <- inferStochtreeJsonVersion(json_object_default) - has_field <- function(name) json_contains_field_cpp(json_object_default$json_ptr, name) + has_field <- function(name) { + json_contains_field_cpp(json_object_default$json_ptr, name) + } has_subfolder_field <- function(subfolder, name) { - json_contains_field_subfolder_cpp(json_object_default$json_ptr, subfolder, name) + json_contains_field_subfolder_cpp( + json_object_default$json_ptr, + subfolder, + name + ) } # Unpack the forests @@ -5209,9 +5280,13 @@ createBCFModelFromCombinedJson <- function(json_object_list) { "standardize" ) if (has_field("sigma2_init")) { - model_params[["initial_sigma2"]] <- json_object_default$get_scalar("sigma2_init") + model_params[["initial_sigma2"]] <- 
json_object_default$get_scalar( + "sigma2_init" + ) } else { - model_params[["initial_sigma2"]] <- json_object_default$get_scalar("initial_sigma2") + model_params[["initial_sigma2"]] <- json_object_default$get_scalar( + "initial_sigma2" + ) warning(sprintf( "JSON field 'initial_sigma2' is deprecated; please re-save the model to use 'sigma2_init' (inferred version: %s).", .ver @@ -5319,7 +5394,10 @@ createBCFModelFromCombinedJson <- function(json_object_list) { "outcome", "outcome_model" ) - outcome_model_link <- json_object_default$get_string("link", "outcome_model") + outcome_model_link <- json_object_default$get_string( + "link", + "outcome_model" + ) } else { outcome_model_outcome <- "continuous" outcome_model_link <- "identity" @@ -5551,7 +5629,10 @@ createBCFModelFromCombinedJsonString <- function(json_string_list) { # We don't support merging BCF models with independent propensity models # this way at the moment if ( - json_contains_field_cpp(json_object_list[[i]]$json_ptr, "internal_propensity_model") && + json_contains_field_cpp( + json_object_list[[i]]$json_ptr, + "internal_propensity_model" + ) && json_object_list[[i]]$get_boolean("internal_propensity_model") ) { stop( @@ -5566,9 +5647,15 @@ createBCFModelFromCombinedJsonString <- function(json_string_list) { # Version inference and presence-check helpers .ver <- inferStochtreeJsonVersion(json_object_default) - has_field <- function(name) json_contains_field_cpp(json_object_default$json_ptr, name) + has_field <- function(name) { + json_contains_field_cpp(json_object_default$json_ptr, name) + } has_subfolder_field <- function(subfolder, name) { - json_contains_field_subfolder_cpp(json_object_default$json_ptr, subfolder, name) + json_contains_field_subfolder_cpp( + json_object_default$json_ptr, + subfolder, + name + ) } # Unpack the forests @@ -5642,9 +5729,13 @@ createBCFModelFromCombinedJsonString <- function(json_string_list) { "standardize" ) if (has_field("sigma2_init")) { - 
model_params[["initial_sigma2"]] <- json_object_default$get_scalar("sigma2_init") + model_params[["initial_sigma2"]] <- json_object_default$get_scalar( + "sigma2_init" + ) } else { - model_params[["initial_sigma2"]] <- json_object_default$get_scalar("initial_sigma2") + model_params[["initial_sigma2"]] <- json_object_default$get_scalar( + "initial_sigma2" + ) warning(sprintf( "JSON field 'initial_sigma2' is deprecated; please re-save the model to use 'sigma2_init' (inferred version: %s).", .ver @@ -5752,7 +5843,10 @@ createBCFModelFromCombinedJsonString <- function(json_string_list) { "outcome", "outcome_model" ) - outcome_model_link <- json_object_default$get_string("link", "outcome_model") + outcome_model_link <- json_object_default$get_string( + "link", + "outcome_model" + ) } else { outcome_model_outcome <- "continuous" outcome_model_link <- "identity" diff --git a/man/bart.Rd b/man/bart.Rd index 9cd721b6..a1794208 100644 --- a/man/bart.Rd +++ b/man/bart.Rd @@ -14,6 +14,7 @@ bart( leaf_basis_test = NULL, rfx_group_ids_test = NULL, rfx_basis_test = NULL, + observation_weights = NULL, num_gfr = 5, num_burnin = 0, num_mcmc = 100, @@ -59,6 +60,12 @@ that were not in the training set.} \item{rfx_basis_test}{(Optional) Test set basis for "random-slope" regression in additive random effects model.} +\item{observation_weights}{(Optional) Numeric vector of observation weights of length \code{nrow(X_train)}. Weights are +applied as \code{y_i | - ~ N(mu(X_i), sigma^2 / w_i)}, so larger weights increase an observation's influence on the fit. +All weights must be non-negative. Default: \code{NULL} (all observations equally weighted). Compatible with Gaussian +(continuous/identity) and probit outcome models; not compatible with cloglog link functions. 
Note: these are +referred to internally in the C++ layer as "variance weights" (\code{var_weights}), since they scale the residual variance.} + \item{num_gfr}{Number of "warm-start" iterations run using the grow-from-root algorithm (He and Hahn, 2021). Default: 5.} \item{num_burnin}{Number of "burn-in" iterations of the MCMC sampler. Default: 0.} diff --git a/man/bcf.Rd b/man/bcf.Rd index 2e3ab149..6860f410 100644 --- a/man/bcf.Rd +++ b/man/bcf.Rd @@ -16,6 +16,7 @@ bcf( propensity_test = NULL, rfx_group_ids_test = NULL, rfx_basis_test = NULL, + observation_weights = NULL, num_gfr = 5, num_burnin = 0, num_mcmc = 100, @@ -61,6 +62,12 @@ that were not in the training set.} \item{rfx_basis_test}{(Optional) Test set basis for "random-slope" regression in additive random effects model.} +\item{observation_weights}{(Optional) Numeric vector of observation weights of length \code{nrow(X_train)}. Weights are +applied as \code{y_i | - ~ N(mu(X_i), sigma^2 / w_i)}, so larger weights increase an observation's influence on the fit. +All weights must be non-negative. Default: \code{NULL} (all observations equally weighted). Applied to both the +prognostic and treatment effect forests. Compatible with Gaussian (continuous/identity) and probit outcome models; +not compatible with cloglog link functions.} + \item{num_gfr}{Number of "warm-start" iterations run using the grow-from-root algorithm (He and Hahn, 2021). Default: 5.} \item{num_burnin}{Number of "burn-in" iterations of the MCMC sampler. Default: 0.} @@ -81,7 +88,7 @@ that were not in the training set.} \item \code{sigma2_global_scale} Scale parameter in the \code{IG(sigma2_global_shape, sigma2_global_scale)} global error variance model. Default: \code{0}. \item \code{variable_weights} Numeric weights reflecting the relative probability of splitting on each variable. Does not need to sum to 1 but cannot be negative. Defaults to \code{rep(1/ncol(X_train), ncol(X_train))} if not set here. 
Note that if the propensity score is included as a covariate in either forest, its weight will default to \code{1/ncol(X_train)}. A workaround if you wish to provide a custom weight for the propensity score is to include it as a column in \code{X_train} and then set \code{propensity_covariate} to \code{'none'} adjust \code{keep_vars} accordingly for the \code{prognostic} or \code{treatment_effect} forests. \item \code{propensity_covariate} Whether to include the propensity score as a covariate in either or both of the forests. Enter \code{"none"} for neither, \code{"prognostic"} for the prognostic forest, \code{"treatment_effect"} for the treatment forest, and \code{"both"} for both forests. If this is not \code{"none"} and a propensity score is not provided, it will be estimated from (\code{X_train}, \code{Z_train}) using \code{stochtree::bart()}. Default: \code{"mu"}. -\item \code{adaptive_coding} Whether or not to use an "adaptive coding" scheme in which a binary treatment variable is not coded manually as (0,1) or (-1,1) but learned via parameters \code{b_0} and \code{b_1} that attach to the outcome model \verb{[b_0 (1-Z) + b_1 Z] tau(X)}. This is ignored when Z is not binary. Default: \code{TRUE}. +\item \code{adaptive_coding} Whether or not to use an "adaptive coding" scheme in which a binary treatment variable is not coded manually as (0,1) or (-1,1) but learned via parameters \code{b_0} and \code{b_1} that attach to the outcome model \verb{[b_0 (1-Z) + b_1 Z] tau(X)}. This is ignored when Z is not binary. Default: \code{TRUE}. \item \code{control_coding_init} Initial value of the "control" group coding parameter. This is ignored when Z is not binary. Default: \code{-0.5}. \item \code{treated_coding_init} Initial value of the "treatment" group coding parameter. This is ignored when Z is not binary. Default: \code{0.5}. \item \code{rfx_prior_var} Prior on the (diagonals of the) covariance of the additive group-level random regression coefficients. 
Must be a vector of length \code{ncol(rfx_basis_train)}. Default: \code{rep(1, ncol(rfx_basis_train))} diff --git a/stochtree/bart.py b/stochtree/bart.py index 66945ada..0b7e6343 100644 --- a/stochtree/bart.py +++ b/stochtree/bart.py @@ -85,6 +85,7 @@ def sample( leaf_basis_test: Optional[np.ndarray] = None, rfx_group_ids_test: Optional[np.ndarray] = None, rfx_basis_test: Optional[np.ndarray] = None, + observation_weights: Optional[np.ndarray] = None, num_gfr: int = 5, num_burnin: int = 0, num_mcmc: int = 100, @@ -120,6 +121,14 @@ def sample( test set evaluation for group labels that were not in the training set. rfx_basis_test : np.array, optional Optional test set basis for "random-slope" regression in additive random effects model. + observation_weights : np.array, optional + Optional vector of observation weights of length ``n_train``. Weights are applied as + ``y_i | - ~ N(mu(X_i), sigma^2 / w_i)``, so larger weights increase an observation's + influence on the fit. All weights must be non-negative. Defaults to ``None`` (all + observations equally weighted). Compatible with Gaussian (continuous/identity) and + probit outcome models; not compatible with cloglog link functions. Note: these are + referred to internally in the C++ layer as "variance weights" (``var_weights``), since + they scale the residual variance. num_gfr : int, optional Number of "warm-start" iterations run using the grow-from-root algorithm (He and Hahn, 2021). Defaults to `5`. num_burnin : int, optional @@ -424,6 +433,17 @@ def sample( self.include_mean_forest = True if num_trees_mean > 0 else False self.include_variance_forest = True if num_trees_variance > 0 else False + # observation_weights compatibility checks + if observation_weights is not None: + if link_is_cloglog: + raise ValueError( + "observation_weights are not compatible with cloglog link functions." 
+ ) + if self.include_variance_forest: + warnings.warn( + "Results may be unreliable when observation_weights are deployed alongside a variance forest model." + ) + # Check data inputs if not isinstance(X_train, pd.DataFrame) and not isinstance( X_train, np.ndarray @@ -462,7 +482,24 @@ def sample( if rfx_basis_test is not None: if not isinstance(rfx_basis_test, np.ndarray): raise ValueError("rfx_basis_test must be a numpy array") - + if observation_weights is not None: + if not isinstance(observation_weights, np.ndarray): + raise ValueError("observation_weights must be a numpy array") + observation_weights_ = np.squeeze(observation_weights) + if observation_weights_.ndim != 1: + raise ValueError("observation_weights must be a 1-dimensional numpy array") + if np.any(observation_weights_ < 0): + raise ValueError("observation_weights cannot have any negative values") + + # Validate that observation_weights are not all-zero when num_gfr > 0 + if observation_weights is not None and num_gfr > 0: + if np.all(observation_weights == 0): + raise ValueError( + "observation_weights are all zero (prior sampling mode) but num_gfr > 0. " + "GFR warm-start is data-dependent and ill-defined with zero weights. " + "Set num_gfr=0 when using all-zero observation_weights." 
+ ) + # Convert everything to standard shape (2-dimensional) if isinstance(X_train, np.ndarray): if X_train.ndim == 1: @@ -1343,6 +1380,8 @@ def sample( forest_dataset_train.add_covariates(X_train_processed) if self.has_basis: forest_dataset_train.add_basis(leaf_basis_train) + if observation_weights is not None: + forest_dataset_train.add_variance_weights(observation_weights_) if self.has_test: forest_dataset_test = Dataset() forest_dataset_test.add_covariates(X_test_processed) diff --git a/stochtree/bcf.py b/stochtree/bcf.py index c0156543..f1666e1d 100644 --- a/stochtree/bcf.py +++ b/stochtree/bcf.py @@ -96,6 +96,7 @@ def sample( propensity_test: np.array = None, rfx_group_ids_test: np.array = None, rfx_basis_test: np.array = None, + observation_weights: Optional[np.ndarray] = None, num_gfr: int = 5, num_burnin: int = 0, num_mcmc: int = 100, @@ -136,6 +137,15 @@ def sample( test set evaluation for group labels that were not in the training set. rfx_basis_test : np.array, optional Optional test set basis for "random-slope" regression in additive random effects model. + observation_weights : np.array, optional + Optional vector of observation weights of length ``n_train``. Weights are applied as + ``y_i | - ~ N(mu(X_i), sigma^2 / w_i)``, so larger weights increase an observation's + influence on the fit. All weights must be non-negative. Defaults to ``None`` (all + observations equally weighted). Applied to both the prognostic and treatment effect + forests. Compatible with Gaussian (continuous/identity) and probit outcome models; + not compatible with cloglog link functions. Note: these are referred to internally in + the C++ layer as "variance weights" (``var_weights``), since they scale the residual + variance. num_gfr : int, optional Number of "warm-start" iterations run using the grow-from-root algorithm (He and Hahn, 2021). Defaults to `5`. 
num_burnin : int, optional @@ -610,6 +620,30 @@ def sample( # Determine whether conditional variance model will be fit self.include_variance_forest = True if num_trees_variance > 0 else False + # observation_weights validation and compatibility checks + if observation_weights is not None: + if not isinstance(observation_weights, np.ndarray): + raise ValueError("observation_weights must be a numpy array") + observation_weights_ = np.squeeze(observation_weights) + if observation_weights_.ndim != 1: + raise ValueError("observation_weights must be a 1-dimensional numpy array") + if np.any(observation_weights_ < 0): + raise ValueError("observation_weights cannot have any negative values") + if np.all(observation_weights_ == 0) and num_gfr > 0: + raise ValueError( + "observation_weights are all zero (prior sampling mode) but num_gfr > 0. " + "GFR warm-start is data-dependent and ill-defined with zero weights. " + "Set num_gfr=0 when using all-zero observation_weights." + ) + if link_is_cloglog: + raise ValueError( + "observation_weights are not compatible with cloglog link functions." + ) + if self.include_variance_forest: + warnings.warn( + "Results may be unreliable when observation_weights are deployed alongside a variance forest model." 
+ ) + # Check data inputs if not isinstance(X_train, pd.DataFrame) and not isinstance( X_train, np.ndarray @@ -1912,6 +1946,8 @@ def sample( forest_dataset_train = Dataset() forest_dataset_train.add_covariates(X_train_processed) forest_dataset_train.add_basis(tau_basis_train) + if observation_weights is not None: + forest_dataset_train.add_variance_weights(observation_weights_) if self.has_test: forest_dataset_test = Dataset() forest_dataset_test.add_covariates(X_test_processed) diff --git a/test/R/testthat/test-observation-weights.R b/test/R/testthat/test-observation-weights.R new file mode 100644 index 00000000..2e2dfda4 --- /dev/null +++ b/test/R/testthat/test-observation-weights.R @@ -0,0 +1,259 @@ +make_bart_data <- function(n = 100, p = 5, seed = 42) { + set.seed(seed) + X <- matrix(runif(n * p), ncol = p) + y <- sin(X[, 1] * pi) + rnorm(n, 0, 0.1) + n_train <- as.integer(0.8 * n) + list( + X_train = X[1:n_train, ], y_train = y[1:n_train], + X_test = X[(n_train + 1):n, ], n_train = n_train, n_test = n - n_train + ) +} + +make_bcf_data <- function(n = 100, p = 5, seed = 42) { + set.seed(seed) + X <- matrix(runif(n * p), ncol = p) + pi_X <- 0.25 + 0.5 * X[, 1] + Z <- rbinom(n, 1, pi_X) + y <- pi_X * 5 + X[, 2] * 2 * Z + rnorm(n, 0, 1) + n_train <- as.integer(0.8 * n) + list( + X_train = X[1:n_train, ], Z_train = Z[1:n_train], y_train = y[1:n_train], + pi_train = pi_X[1:n_train], X_test = X[(n_train + 1):n, ], + Z_test = Z[(n_train + 1):n], pi_test = pi_X[(n_train + 1):n], + n_train = n_train, n_test = n - n_train + ) +} + +test_that("BART: uniform weights produce identical predictions to no weights", { + skip_on_cran() + d <- make_bart_data() + num_mcmc <- 10 + + set.seed(1) + m1 <- bart( + X_train = d$X_train, y_train = d$y_train, X_test = d$X_test, + num_gfr = 0, num_burnin = 0, num_mcmc = num_mcmc, + general_params = list(random_seed = 1L) + ) + + set.seed(1) + m2 <- bart( + X_train = d$X_train, y_train = d$y_train, X_test = d$X_test, + observation_weights = 
rep(1.0, d$n_train), + num_gfr = 0, num_burnin = 0, num_mcmc = num_mcmc, + general_params = list(random_seed = 1L) + ) + + expect_equal(m1$y_hat_train, m2$y_hat_train) + expect_equal(m1$y_hat_test, m2$y_hat_test) +}) + +test_that("BART: non-uniform weights run and produce correct output shape", { + skip_on_cran() + d <- make_bart_data() + num_mcmc <- 10 + weights <- runif(d$n_train, 0.5, 2.0) + + expect_no_error( + m <- bart( + X_train = d$X_train, y_train = d$y_train, X_test = d$X_test, + observation_weights = weights, + num_gfr = 0, num_burnin = 0, num_mcmc = num_mcmc + ) + ) + expect_equal(dim(m$y_hat_train), c(d$n_train, num_mcmc)) + expect_equal(dim(m$y_hat_test), c(d$n_test, num_mcmc)) +}) + +test_that("BART: all-zero weights (prior mode) run with num_gfr = 0", { + skip_on_cran() + d <- make_bart_data() + num_mcmc <- 10 + + expect_no_error( + m <- bart( + X_train = d$X_train, y_train = d$y_train, + observation_weights = rep(0.0, d$n_train), + num_gfr = 0, num_burnin = 0, num_mcmc = num_mcmc + ) + ) + expect_equal(dim(m$y_hat_train), c(d$n_train, num_mcmc)) +}) + +test_that("BART: non-numeric observation_weights raises error", { + skip_on_cran() + d <- make_bart_data() + expect_error( + bart( + X_train = d$X_train, y_train = d$y_train, + observation_weights = as.character(rep(1, d$n_train)), + num_gfr = 0, num_burnin = 0, num_mcmc = 5 + ), + "numeric" + ) +}) + +test_that("BART: wrong-length observation_weights raises error", { + skip_on_cran() + d <- make_bart_data() + expect_error( + bart( + X_train = d$X_train, y_train = d$y_train, + observation_weights = rep(1.0, d$n_train + 1), + num_gfr = 0, num_burnin = 0, num_mcmc = 5 + ), + "nrow" + ) +}) + +test_that("BART: negative observation_weights raises error", { + skip_on_cran() + d <- make_bart_data() + weights <- rep(1.0, d$n_train) + weights[1] <- -1.0 + expect_error( + bart( + X_train = d$X_train, y_train = d$y_train, + observation_weights = weights, + num_gfr = 0, num_burnin = 0, num_mcmc = 5 + ), + 
"negative" + ) +}) + +test_that("BART: all-zero weights with num_gfr > 0 raises error", { + skip_on_cran() + d <- make_bart_data() + expect_error( + bart( + X_train = d$X_train, y_train = d$y_train, + observation_weights = rep(0.0, d$n_train), + num_gfr = 5, num_burnin = 0, num_mcmc = 10 + ), + "num_gfr" + ) +}) + +test_that("BART: observation_weights with cloglog outcome raises error", { + skip_on_cran() + d <- make_bart_data() + y_ord <- sample(1:3, d$n_train, replace = TRUE) + expect_error( + bart( + X_train = d$X_train, y_train = y_ord, + observation_weights = rep(1.0, d$n_train), + num_gfr = 0, num_burnin = 0, num_mcmc = 5, + general_params = list(outcome_model = OutcomeModel(outcome = "ordinal", link = "cloglog")) + ), + "cloglog" + ) +}) + +test_that("BART: observation_weights with variance forest raises warning", { + skip_on_cran() + d <- make_bart_data() + expect_warning( + bart( + X_train = d$X_train, y_train = d$y_train, + observation_weights = rep(1.0, d$n_train), + num_gfr = 0, num_burnin = 0, num_mcmc = 5, + variance_forest_params = list(num_trees = 5) + ), + "variance forest" + ) +}) + +test_that("BCF: uniform weights produce identical predictions to no weights", { + skip_on_cran() + d <- make_bcf_data() + num_mcmc <- 10 + + set.seed(1) + m1 <- bcf( + X_train = d$X_train, Z_train = d$Z_train, y_train = d$y_train, + propensity_train = d$pi_train, X_test = d$X_test, + Z_test = d$Z_test, propensity_test = d$pi_test, + num_gfr = 0, num_burnin = 0, num_mcmc = num_mcmc, + general_params = list(random_seed = 1L) + ) + + set.seed(1) + m2 <- bcf( + X_train = d$X_train, Z_train = d$Z_train, y_train = d$y_train, + propensity_train = d$pi_train, X_test = d$X_test, + Z_test = d$Z_test, propensity_test = d$pi_test, + observation_weights = rep(1.0, d$n_train), + num_gfr = 0, num_burnin = 0, num_mcmc = num_mcmc, + general_params = list(random_seed = 1L) + ) + + expect_equal(m1$y_hat_train, m2$y_hat_train) + expect_equal(m1$tau_hat_train, m2$tau_hat_train) +}) + 
+test_that("BCF: non-uniform weights run and produce correct output shape", { + skip_on_cran() + d <- make_bcf_data() + num_mcmc <- 10 + weights <- runif(d$n_train, 0.5, 2.0) + + expect_no_error( + m <- bcf( + X_train = d$X_train, Z_train = d$Z_train, y_train = d$y_train, + propensity_train = d$pi_train, X_test = d$X_test, + Z_test = d$Z_test, propensity_test = d$pi_test, + observation_weights = weights, + num_gfr = 0, num_burnin = 0, num_mcmc = num_mcmc + ) + ) + expect_equal(dim(m$y_hat_train), c(d$n_train, num_mcmc)) + expect_equal(dim(m$tau_hat_train), c(d$n_train, num_mcmc)) + expect_equal(dim(m$y_hat_test), c(d$n_test, num_mcmc)) + expect_equal(dim(m$tau_hat_test), c(d$n_test, num_mcmc)) +}) + +test_that("BCF: negative observation_weights raises error", { + skip_on_cran() + d <- make_bcf_data() + weights <- rep(1.0, d$n_train) + weights[1] <- -1.0 + expect_error( + bcf( + X_train = d$X_train, Z_train = d$Z_train, y_train = d$y_train, + propensity_train = d$pi_train, + observation_weights = weights, + num_gfr = 0, num_burnin = 0, num_mcmc = 5 + ), + "negative" + ) +}) + +test_that("BCF: all-zero weights with num_gfr > 0 raises error", { + skip_on_cran() + d <- make_bcf_data() + expect_error( + bcf( + X_train = d$X_train, Z_train = d$Z_train, y_train = d$y_train, + propensity_train = d$pi_train, + observation_weights = rep(0.0, d$n_train), + num_gfr = 5, num_burnin = 0, num_mcmc = 10 + ), + "num_gfr" + ) +}) + +test_that("BCF: observation_weights with cloglog outcome raises error", { + skip_on_cran() + d <- make_bcf_data() + y_bin <- rbinom(d$n_train, 1, 0.5) + expect_error( + bcf( + X_train = d$X_train, Z_train = d$Z_train, y_train = y_bin, + propensity_train = d$pi_train, + observation_weights = rep(1.0, d$n_train), + num_gfr = 0, num_burnin = 0, num_mcmc = 5, + general_params = list(outcome_model = OutcomeModel(outcome = "binary", link = "cloglog")) + ), + "cloglog" + ) +}) diff --git a/test/python/test_observation_weights.py 
b/test/python/test_observation_weights.py new file mode 100644 index 00000000..6c89216a --- /dev/null +++ b/test/python/test_observation_weights.py @@ -0,0 +1,206 @@ +import numpy as np +import pytest + +from stochtree import BARTModel, BCFModel, OutcomeModel + + +def make_bart_data(n=100, p=5, seed=42): + rng = np.random.default_rng(seed) + X = rng.uniform(0, 1, (n, p)) + y = np.sin(X[:, 0] * np.pi) + rng.normal(0, 0.1, n) + n_train = int(0.8 * n) + return X[:n_train], y[:n_train], X[n_train:], n_train, n - n_train + + +def make_bcf_data(n=100, p=5, seed=42): + rng = np.random.default_rng(seed) + X = rng.uniform(0, 1, (n, p)) + pi_X = 0.25 + 0.5 * X[:, 0] + Z = rng.binomial(1, pi_X, n).astype(float) + y = pi_X * 5 + X[:, 1] * 2 * Z + rng.normal(0, 1, n) + n_train = int(0.8 * n) + return ( + X[:n_train], Z[:n_train], y[:n_train], pi_X[:n_train], + X[n_train:], Z[n_train:], pi_X[n_train:], + n_train, n - n_train, + ) + + +class TestBARTObservationWeights: + def test_uniform_weights_match_no_weights(self): + """Uniform weights of 1.0 should produce identical predictions to no weights.""" + X_train, y_train, X_test, n_train, n_test = make_bart_data() + kwargs = dict( + X_train=X_train, y_train=y_train, X_test=X_test, + num_gfr=0, num_burnin=0, num_mcmc=10, + general_params={"random_seed": 1}, + ) + m1 = BARTModel() + m1.sample(**kwargs) + + m2 = BARTModel() + m2.sample(**kwargs, observation_weights=np.ones(n_train)) + + np.testing.assert_array_equal(m1.y_hat_train, m2.y_hat_train) + np.testing.assert_array_equal(m1.y_hat_test, m2.y_hat_test) + + def test_nonuniform_weights_output_shape(self): + """Non-uniform weights: output shapes are correct.""" + X_train, y_train, X_test, n_train, n_test = make_bart_data() + rng = np.random.default_rng(0) + weights = rng.uniform(0.5, 2.0, n_train) + num_mcmc = 10 + + m = BARTModel() + m.sample( + X_train=X_train, y_train=y_train, X_test=X_test, + observation_weights=weights, + num_gfr=0, num_burnin=0, num_mcmc=num_mcmc, + ) + 
assert m.y_hat_train.shape == (n_train, num_mcmc) + assert m.y_hat_test.shape == (n_test, num_mcmc) + + def test_zero_weights_prior_mode(self): + """All-zero weights with num_gfr=0 runs (prior sampling mode).""" + X_train, y_train, _, n_train, _ = make_bart_data() + num_mcmc = 10 + + m = BARTModel() + m.sample( + X_train=X_train, y_train=y_train, + observation_weights=np.zeros(n_train), + num_gfr=0, num_burnin=0, num_mcmc=num_mcmc, + ) + assert m.y_hat_train.shape == (n_train, num_mcmc) + + def test_invalid_type_raises(self): + X_train, y_train, _, n_train, _ = make_bart_data() + with pytest.raises(ValueError, match="numpy array"): + BARTModel().sample( + X_train=X_train, y_train=y_train, + observation_weights=list(np.ones(n_train)), + num_gfr=0, num_burnin=0, num_mcmc=5, + ) + + def test_2d_weights_raises(self): + X_train, y_train, _, n_train, _ = make_bart_data() + with pytest.raises(ValueError, match="1-dimensional"): + BARTModel().sample( + X_train=X_train, y_train=y_train, + observation_weights=np.ones((n_train, 2)), + num_gfr=0, num_burnin=0, num_mcmc=5, + ) + + def test_negative_weights_raises(self): + X_train, y_train, _, n_train, _ = make_bart_data() + weights = np.ones(n_train) + weights[0] = -1.0 + with pytest.raises(ValueError, match="negative"): + BARTModel().sample( + X_train=X_train, y_train=y_train, + observation_weights=weights, + num_gfr=0, num_burnin=0, num_mcmc=5, + ) + + def test_all_zero_with_gfr_raises(self): + X_train, y_train, _, n_train, _ = make_bart_data() + with pytest.raises(ValueError, match="num_gfr"): + BARTModel().sample( + X_train=X_train, y_train=y_train, + observation_weights=np.zeros(n_train), + num_gfr=5, num_burnin=0, num_mcmc=10, + ) + + def test_cloglog_raises(self): + rng = np.random.default_rng(0) + n = 50 + X = rng.uniform(0, 1, (n, 3)) + y = rng.choice([1, 2, 3], n).astype(float) + with pytest.raises(ValueError, match="cloglog"): + BARTModel().sample( + X_train=X, y_train=y, + observation_weights=np.ones(n), + 
num_gfr=0, num_burnin=0, num_mcmc=5, + general_params={"outcome_model": OutcomeModel(outcome="ordinal", link="cloglog")}, + ) + + def test_variance_forest_warns(self): + X_train, y_train, _, n_train, _ = make_bart_data() + with pytest.warns(UserWarning, match="variance forest"): + BARTModel().sample( + X_train=X_train, y_train=y_train, + observation_weights=np.ones(n_train), + num_gfr=0, num_burnin=0, num_mcmc=5, + variance_forest_params={"num_trees": 5}, + ) + + +class TestBCFObservationWeights: + def test_uniform_weights_match_no_weights(self): + """Uniform weights of 1.0 should produce identical predictions to no weights.""" + X_train, Z_train, y_train, pi_train, X_test, Z_test, pi_test, n_train, _ = make_bcf_data() + kwargs = dict( + X_train=X_train, Z_train=Z_train, y_train=y_train, + propensity_train=pi_train, X_test=X_test, Z_test=Z_test, propensity_test=pi_test, + num_gfr=0, num_burnin=0, num_mcmc=10, + general_params={"random_seed": 1}, + ) + m1 = BCFModel() + m1.sample(**kwargs) + + m2 = BCFModel() + m2.sample(**kwargs, observation_weights=np.ones(n_train)) + + np.testing.assert_array_equal(m1.y_hat_train, m2.y_hat_train) + np.testing.assert_array_equal(m1.tau_hat_train, m2.tau_hat_train) + + def test_nonuniform_weights_output_shape(self): + X_train, Z_train, y_train, pi_train, X_test, Z_test, pi_test, n_train, n_test = make_bcf_data() + rng = np.random.default_rng(0) + weights = rng.uniform(0.5, 2.0, n_train) + num_mcmc = 10 + + m = BCFModel() + m.sample( + X_train=X_train, Z_train=Z_train, y_train=y_train, + propensity_train=pi_train, X_test=X_test, Z_test=Z_test, propensity_test=pi_test, + observation_weights=weights, + num_gfr=0, num_burnin=0, num_mcmc=num_mcmc, + ) + assert m.y_hat_train.shape == (n_train, num_mcmc) + assert m.tau_hat_train.shape == (n_train, num_mcmc) + assert m.y_hat_test.shape == (n_test, num_mcmc) + assert m.tau_hat_test.shape == (n_test, num_mcmc) + + def test_negative_weights_raises(self): + X_train, Z_train, y_train, pi_train, 
_, _, _, n_train, _ = make_bcf_data() + weights = np.ones(n_train) + weights[0] = -1.0 + with pytest.raises(ValueError, match="negative"): + BCFModel().sample( + X_train=X_train, Z_train=Z_train, y_train=y_train, + propensity_train=pi_train, + observation_weights=weights, + num_gfr=0, num_burnin=0, num_mcmc=5, + ) + + def test_all_zero_with_gfr_raises(self): + X_train, Z_train, y_train, pi_train, _, _, _, n_train, _ = make_bcf_data() + with pytest.raises(ValueError, match="num_gfr"): + BCFModel().sample( + X_train=X_train, Z_train=Z_train, y_train=y_train, + propensity_train=pi_train, + observation_weights=np.zeros(n_train), + num_gfr=5, num_burnin=0, num_mcmc=10, + ) + + def test_cloglog_raises(self): + X_train, Z_train, y_train, pi_train, _, _, _, n_train, _ = make_bcf_data() + with pytest.raises(ValueError, match="cloglog"): + BCFModel().sample( + X_train=X_train, Z_train=Z_train, y_train=y_train, + propensity_train=pi_train, + observation_weights=np.ones(n_train), + num_gfr=0, num_burnin=0, num_mcmc=5, + general_params={"outcome_model": OutcomeModel(outcome="binary", link="cloglog")}, + )