diff --git a/R/bcf.R b/R/bcf.R index 52bddee0..b7af4f6f 100644 --- a/R/bcf.R +++ b/R/bcf.R @@ -74,6 +74,8 @@ NULL #' @param Z_train Vector of (continuous or binary) treatment assignments. #' @param y_train Outcome to be modeled by the ensemble. #' @param propensity_train (Optional) Vector of propensity scores. If not provided, this will be estimated from the data. +#' If `NULL` and `previous_model_json` is provided with an internally estimated propensity model, that model's +#' propensity estimates are re-used rather than re-fitted. #' @param rfx_group_ids_train (Optional) Group labels used for an additive random effects model. #' @param rfx_basis_train (Optional) Basis for "random-slope" regression in an additive random effects model. #' If `rfx_group_ids_train` is provided with a regression basis, an intercept-only random effects model @@ -95,7 +97,12 @@ NULL #' @param num_gfr Number of "warm-start" iterations run using the grow-from-root algorithm (He and Hahn, 2021). Default: 5. #' @param num_burnin Number of "burn-in" iterations of the MCMC sampler. Default: 0. #' @param num_mcmc Number of "retained" iterations of the MCMC sampler. Default: 100. -#' @param previous_model_json (Optional) JSON string containing a previous BCF model. This can be used to "continue" a sampler interactively after inspecting the samples or to run parallel chains "warm-started" from existing forest samples. Default: `NULL`. +#' @param previous_model_json (Optional) JSON string containing a previous BCF model. This can be used to "continue" a +#' sampler interactively after inspecting the samples or to run parallel chains "warm-started" from existing forest +#' samples. If the previous model used an internally estimated propensity score (i.e. `propensity_train` was not +#' supplied to that run), the fitted propensity model is carried forward and re-used rather than being re-estimated. +#' This ensures that multi-chain warm-starts remain consistent with the propensity scores used in the initial run. +#' Default: `NULL`. #' @param previous_model_warmstart_sample_num (Optional) Sample number from `previous_model_json` that will be used to warmstart this BCF sampler. One-indexed (so that the first sample is used for warm-start by setting `previous_model_warmstart_sample_num = 1`). Default: `NULL`. If `num_chains` in the `general_params` list is > 1, then each successive chain will be initialized from a different sample, counting backwards from `previous_model_warmstart_sample_num`. That is, if `previous_model_warmstart_sample_num = 10` and `num_chains = 4`, then chain 1 will be initialized from sample 10, chain 2 from sample 9, chain 3 from sample 8, and chain 4 from sample 7. If `previous_model_json` is provided but `previous_model_warmstart_sample_num` is NULL, the last sample in the previous model will be used to initialize the first chain, counting backwards as noted before. If more chains are requested than there are samples in `previous_model_json`, a warning will be raised and only the last sample will be used. #' @param general_params (Optional) A list of general (non-forest-specific) model parameters, each of which has a default value processed internally, so this argument list is optional. #' @@ -1282,26 +1289,67 @@ bcf <- function( internal_propensity_model <- FALSE if ((is.null(propensity_train)) && (propensity_covariate != "none")) { internal_propensity_model <- TRUE - # Estimate using the last of several iterations of GFR BART - num_gfr_propensity <- 10 - num_burnin_propensity <- 0 - num_mcmc_propensity <- 10 - bart_model_propensity <- bart( - X_train = X_train, - y_train = as.numeric(Z_train), - X_test = X_test, - num_gfr = num_gfr_propensity, - num_burnin = num_burnin_propensity, - num_mcmc = num_mcmc_propensity - ) - propensity_train <- rowMeans(bart_model_propensity$y_hat_train) - if ((is.null(dim(propensity_train))) && (!is.null(propensity_train))) { - propensity_train <- as.matrix(propensity_train) - } - if (has_test) { - propensity_test <- rowMeans(bart_model_propensity$y_hat_test) - if ((is.null(dim(propensity_test))) && (!is.null(propensity_test))) { - propensity_test <- as.matrix(propensity_test) + if ( + has_prev_model && + previous_bcf_model$model_params$internal_propensity_model + ) { + # Reuse the propensity model from the warm-started BCF model rather than + # re-fitting from scratch. Training propensities come from the previous + # model's stored predictions; test propensities are re-predicted on the + # (potentially new) test set. + bart_model_propensity <- previous_bcf_model$bart_propensity_model + propensity_train <- predict( + bart_model_propensity, + X = X_train, + terms = "y_hat", + type = "mean" + ) + if ((is.null(dim(propensity_train))) && (!is.null(propensity_train))) { + propensity_train <- as.matrix(propensity_train) + } + if (has_test) { + propensity_test <- predict( + bart_model_propensity, + X = X_test, + terms = "y_hat", + type = "mean" + ) + if ((is.null(dim(propensity_test))) && (!is.null(propensity_test))) { + propensity_test <- as.matrix(propensity_test) + } + } + } else { + # Estimate using the last of several iterations of GFR BART + num_gfr_propensity <- 10 + num_burnin_propensity <- 0 + num_mcmc_propensity <- 10 + bart_model_propensity <- bart( + X_train = X_train, + y_train = as.numeric(Z_train), + X_test = X_test, + num_gfr = num_gfr_propensity, + num_burnin = num_burnin_propensity, + num_mcmc = num_mcmc_propensity + ) + propensity_train <- predict( + bart_model_propensity, + X = X_train, + terms = "y_hat", + type = "mean" + ) + if ((is.null(dim(propensity_train))) && (!is.null(propensity_train))) { + propensity_train <- as.matrix(propensity_train) + } + if (has_test) { + propensity_test <- predict( + bart_model_propensity, + X = X_test, + terms = "y_hat", + type = "mean" + ) + if ((is.null(dim(propensity_test))) && (!is.null(propensity_test))) { + propensity_test <- as.matrix(propensity_test) + } } } } diff --git a/src/py_stochtree.cpp b/src/py_stochtree.cpp index a9355500..5b7ff265 100644 --- a/src/py_stochtree.cpp +++ b/src/py_stochtree.cpp @@ -21,12 +21,6 @@ namespace py = pybind11; using data_size_t = StochTree::data_size_t; -enum ForestLeafModel { - kConstant, - kUnivariateRegression, - kMultivariateRegression -}; - class ForestSamplerCpp; class ForestDatasetCpp { diff --git a/src/stochtree_types.h b/src/stochtree_types.h index 9f4e77df..0e17038f 100644 --- a/src/stochtree_types.h +++ b/src/stochtree_types.h @@ -8,9 +8,3 @@ #include #include #include - -enum ForestLeafModel { - kConstant, - kUnivariateRegression, - kMultivariateRegression -}; diff --git a/stochtree/bcf.py b/stochtree/bcf.py index f1666e1d..1d80d562 100644 --- a/stochtree/bcf.py +++ b/stochtree/bcf.py @@ -121,6 +121,8 @@ def sample( Outcome to be modeled by the ensemble. propensity_train : np.array Optional vector of propensity scores. If not provided, this will be estimated from the data. + If ``None`` and ``previous_model_json`` is provided with an internally estimated propensity + model, that model's propensity estimates are re-used rather than re-fitted. rfx_group_ids_train : np.array, optional Optional group labels used for an additive random effects model. rfx_basis_train : np.array, optional @@ -172,6 +174,11 @@ def sample( JSON string containing a previous BCF model. This can be used to "continue" a sampler interactively after inspecting the samples or to run parallel chains "warm-started" from existing forest samples. + If the previous model used an internally estimated propensity score + (i.e. ``propensity_train`` was not supplied to that run), the fitted + propensity model is carried forward and re-used rather than being + re-estimated. This ensures that multi-chain warm-starts remain + consistent with the propensity scores used in the initial run. Defaults to `None`. previous_model_warmstart_sample_num : int, optional Sample number from `previous_model_json` that will be used to @@ -1465,6 +1472,22 @@ def sample( ) propensity_covariate = "none" self.internal_propensity_model = True + elif has_prev_model and previous_bcf_model.internal_propensity_model: + # Reuse the propensity model from the warm-started BCF model rather + # than re-fitting from scratch. Re-predict on the current (preprocessed) + # train and test sets so we don't rely on y_hat_train being present + # after JSON round-trip deserialization. + self.bart_propensity_model = previous_bcf_model.bart_propensity_model + propensity_train = np.expand_dims( + self.bart_propensity_model.predict(X=X_train_processed, terms="y_hat", type="mean"), + 1, + ) + if self.has_test: + propensity_test = np.expand_dims( + self.bart_propensity_model.predict(X=X_test_processed, terms="y_hat", type="mean"), + 1, + ) + self.internal_propensity_model = True else: self.bart_propensity_model = BARTModel() num_gfr_propensity = 10 @@ -1480,11 +1503,9 @@ def sample( num_mcmc=num_mcmc_propensity, general_params={"random_seed": random_seed}, ) - propensity_train = np.mean( - self.bart_propensity_model.y_hat_train, axis=1, keepdims=True - ) - propensity_test = np.mean( - self.bart_propensity_model.y_hat_test, axis=1, keepdims=True + propensity_test = np.expand_dims( + self.bart_propensity_model.predict(X=X_test_processed, terms="y_hat", type="mean"), + 1, ) else: self.bart_propensity_model.sample( @@ -1495,9 +1516,10 @@ def sample( num_mcmc=num_mcmc_propensity, general_params={"random_seed": random_seed}, ) - propensity_train = np.mean( - self.bart_propensity_model.y_hat_train, axis=1, keepdims=True - ) + propensity_train = np.expand_dims( + self.bart_propensity_model.predict(X=X_train_processed, terms="y_hat", type="mean"), + 1, + ) self.internal_propensity_model = True else: self.internal_propensity_model = False diff --git a/test/R/testthat/test-bcf.R b/test/R/testthat/test-bcf.R index de971440..0d117ddf 100644 --- a/test/R/testthat/test-bcf.R +++ b/test/R/testthat/test-bcf.R @@ -997,3 +997,55 @@ test_that("BCF factor-valued treatment handling", { regexp = "exactly 2 levels" ) }) + +test_that("Warmstart BCF reuses internal propensity model", { + skip_on_cran() + + set.seed(42) + n <- 100 + p <- 5 + X <- matrix(runif(n * p), ncol = p) + pi_X <- 0.25 + 0.5 * X[, 1] + Z <- rbinom(n, 1, pi_X) + y <- pi_X * 3 + X[, 2] * Z + rnorm(n, 0, 1) + n_test <- 20 + n_train <- n - n_test + train_inds <- 1:n_train + test_inds <- (n_train + 1):n + + X_train <- X[train_inds, ] + X_test <- X[test_inds, ] + Z_train <- Z[train_inds] + Z_test <- Z[test_inds] + y_train <- y[train_inds] + + # Fit first model without propensity — triggers internal propensity BART + m1 <- bcf( + X_train = X_train, Z_train = Z_train, y_train = y_train, + X_test = X_test, Z_test = Z_test, + num_gfr = 5, num_burnin = 0, num_mcmc = 10 + ) + expect_true(m1$model_params$internal_propensity_model) + + # Propensity predictions from the first model's propensity BART + pi_train_m1 <- predict(m1$bart_propensity_model, X = X_train, terms = "y_hat", type = "mean") + + # Warm-start second model from first — propensity model should be reused + m1_json <- saveBCFModelToJsonString(m1) + m2 <- bcf( + X_train = X_train, Z_train = Z_train, y_train = y_train, + X_test = X_test, Z_test = Z_test, + num_gfr = 0, num_burnin = 0, num_mcmc = 10, + previous_model_json = m1_json, + previous_model_warmstart_sample_num = 10L + ) + expect_true(m2$model_params$internal_propensity_model) + + # Propensity model reused: predictions on train set should be identical + pi_train_m2 <- predict(m2$bart_propensity_model, X = X_train, terms = "y_hat", type = "mean") + expect_equal(pi_train_m1, pi_train_m2) + + # Output shapes should be correct + expect_equal(dim(m2$y_hat_train), c(n_train, 10)) + expect_equal(dim(m2$y_hat_test), c(n_test, 10)) +}) diff --git a/test/python/test_bcf.py b/test/python/test_bcf.py index ab44c482..1ea84554 100644 --- a/test/python/test_bcf.py +++ b/test/python/test_bcf.py @@ -1068,3 +1068,49 @@ def test_internal_propensity_with_categorical_dataframe(self): assert bcf_model.y_hat_test is not None assert bcf_model.tau_hat_train is not None assert bcf_model.tau_hat_test is not None + + def test_warmstart_reuses_internal_propensity(self): + # When a BCF model fitted without user-supplied propensities is used to + # warm-start a new run, the second run should reuse the internal + # propensity model rather than re-fitting it from scratch. + rng = np.random.default_rng(7) + n = 100 + X = rng.uniform(0, 1, (n, 5)) + pi_X = 0.25 + 0.5 * X[:, 0] + Z = rng.binomial(1, pi_X, n).astype(float) + y = pi_X * 3 + X[:, 1] * Z + rng.normal(0, 1, n) + n_train = 80 + X_train, X_test = X[:n_train], X[n_train:] + Z_train, Z_test = Z[:n_train], Z[n_train:] + y_train = y[:n_train] + + # Fit first model — no propensity provided, so internal model is fitted + m1 = BCFModel() + m1.sample( + X_train=X_train, Z_train=Z_train, y_train=y_train, + X_test=X_test, Z_test=Z_test, + num_gfr=5, num_burnin=0, num_mcmc=10, + general_params={"random_seed": 1}, + ) + assert m1.internal_propensity_model + # Propensity predictions from the first model (via predict, which is + # what the warm-start path uses after JSON round-trip) + pi_train_m1 = m1.bart_propensity_model.predict(X=X_train, terms="y_hat", type="mean") + + # Warm-start a second model from the first — propensity should be reused + m2 = BCFModel() + m2.sample( + X_train=X_train, Z_train=Z_train, y_train=y_train, + X_test=X_test, Z_test=Z_test, + num_gfr=0, num_burnin=0, num_mcmc=10, + previous_model_json=m1.to_json(), + previous_model_warmstart_sample_num=9, + general_params={"random_seed": 2}, + ) + assert m2.internal_propensity_model + # Propensities used in m2 should match those from the reused propensity model + pi_train_m2 = m2.bart_propensity_model.predict(X_train, terms="y_hat", type="mean") + np.testing.assert_array_equal(pi_train_m1, pi_train_m2) + # Output shapes should be correct + assert m2.y_hat_train.shape == (n_train, 10) + assert m2.y_hat_test.shape == (n - n_train, 10)