Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 69 additions & 21 deletions R/bcf.R
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ NULL
#' @param Z_train Vector of (continuous or binary) treatment assignments.
#' @param y_train Outcome to be modeled by the ensemble.
#' @param propensity_train (Optional) Vector of propensity scores. If not provided, this will be estimated from the data.
#' If `NULL` and `previous_model_json` is provided with an internally estimated propensity model, that model's
#' propensity estimates are re-used rather than re-fitted.
#' @param rfx_group_ids_train (Optional) Group labels used for an additive random effects model.
#' @param rfx_basis_train (Optional) Basis for "random-slope" regression in an additive random effects model.
#' If `rfx_group_ids_train` is provided with a regression basis, an intercept-only random effects model
Expand All @@ -95,7 +97,12 @@ NULL
#' @param num_gfr Number of "warm-start" iterations run using the grow-from-root algorithm (He and Hahn, 2021). Default: 5.
#' @param num_burnin Number of "burn-in" iterations of the MCMC sampler. Default: 0.
#' @param num_mcmc Number of "retained" iterations of the MCMC sampler. Default: 100.
#' @param previous_model_json (Optional) JSON string containing a previous BCF model. This can be used to "continue" a sampler interactively after inspecting the samples or to run parallel chains "warm-started" from existing forest samples. Default: `NULL`.
#' @param previous_model_json (Optional) JSON string containing a previous BCF model. This can be used to "continue" a
#' sampler interactively after inspecting the samples or to run parallel chains "warm-started" from existing forest
#' samples. If the previous model used an internally estimated propensity score (i.e. `propensity_train` was not
#' supplied to that run), the fitted propensity model is carried forward and re-used rather than being re-estimated.
#' This ensures that multi-chain warm-starts remain consistent with the propensity scores used in the initial run.
#' Default: `NULL`.
#' @param previous_model_warmstart_sample_num (Optional) Sample number from `previous_model_json` that will be used to warmstart this BCF sampler. One-indexed (so that the first sample is used for warm-start by setting `previous_model_warmstart_sample_num = 1`). Default: `NULL`. If `num_chains` in the `general_params` list is > 1, then each successive chain will be initialized from a different sample, counting backwards from `previous_model_warmstart_sample_num`. That is, if `previous_model_warmstart_sample_num = 10` and `num_chains = 4`, then chain 1 will be initialized from sample 10, chain 2 from sample 9, chain 3 from sample 8, and chain 4 from sample 7. If `previous_model_json` is provided but `previous_model_warmstart_sample_num` is NULL, the last sample in the previous model will be used to initialize the first chain, counting backwards as noted before. If more chains are requested than there are samples in `previous_model_json`, a warning will be raised and only the last sample will be used.
#' @param general_params (Optional) A list of general (non-forest-specific) model parameters, each of which has a default value processed internally, so this argument list is optional.
#'
Expand Down Expand Up @@ -1282,26 +1289,67 @@ bcf <- function(
# Flag recording whether propensity scores were estimated internally (rather
# than supplied by the caller); persisted in the model so warm-started runs
# can detect and reuse the fitted propensity model.
internal_propensity_model <- FALSE
if ((is.null(propensity_train)) && (propensity_covariate != "none")) {
  internal_propensity_model <- TRUE
  # Coerce a propensity vector to a one-column matrix; leaves matrices
  # (and NULL) untouched.
  ensure_matrix <- function(p) {
    if ((!is.null(p)) && (is.null(dim(p)))) {
      p <- as.matrix(p)
    }
    p
  }
  # isTRUE() guards against previous models serialized before the
  # `internal_propensity_model` flag existed: `$` returns NULL there, and a
  # NULL operand to `&&` is an error in R >= 4.3. Such models simply fall
  # through to re-fitting below.
  if (
    has_prev_model &&
      isTRUE(previous_bcf_model$model_params$internal_propensity_model)
  ) {
    # Reuse the propensity model from the warm-started BCF model rather than
    # re-fitting from scratch, keeping multi-chain warm-starts consistent
    # with the propensity scores used in the initial run. Training
    # propensities come from re-predicting on the current training set; test
    # propensities are re-predicted on the (potentially new) test set.
    bart_model_propensity <- previous_bcf_model$bart_propensity_model
    propensity_train <- ensure_matrix(predict(
      bart_model_propensity,
      X = X_train,
      terms = "y_hat",
      type = "mean"
    ))
    if (has_test) {
      propensity_test <- ensure_matrix(predict(
        bart_model_propensity,
        X = X_test,
        terms = "y_hat",
        type = "mean"
      ))
    }
  } else {
    # Estimate propensities from the data with a short BART run (several
    # grow-from-root warm-start iterations followed by a few MCMC draws),
    # treating the treatment assignment Z as the outcome.
    num_gfr_propensity <- 10
    num_burnin_propensity <- 0
    num_mcmc_propensity <- 10
    bart_model_propensity <- bart(
      X_train = X_train,
      y_train = as.numeric(Z_train),
      X_test = X_test,
      num_gfr = num_gfr_propensity,
      num_burnin = num_burnin_propensity,
      num_mcmc = num_mcmc_propensity
    )
    propensity_train <- ensure_matrix(predict(
      bart_model_propensity,
      X = X_train,
      terms = "y_hat",
      type = "mean"
    ))
    if (has_test) {
      propensity_test <- ensure_matrix(predict(
        bart_model_propensity,
        X = X_test,
        terms = "y_hat",
        type = "mean"
      ))
    }
  }
}
Expand Down
6 changes: 0 additions & 6 deletions src/py_stochtree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,6 @@
namespace py = pybind11;
using data_size_t = StochTree::data_size_t;

enum ForestLeafModel {
kConstant,
kUnivariateRegression,
kMultivariateRegression
};

class ForestSamplerCpp;

class ForestDatasetCpp {
Expand Down
6 changes: 0 additions & 6 deletions src/stochtree_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,3 @@
#include <stochtree/partition_tracker.h>
#include <stochtree/random_effects.h>
#include <stochtree/tree_sampler.h>

enum ForestLeafModel {
kConstant,
kUnivariateRegression,
kMultivariateRegression
};
38 changes: 30 additions & 8 deletions stochtree/bcf.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,8 @@ def sample(
Outcome to be modeled by the ensemble.
propensity_train : np.array
Optional vector of propensity scores. If not provided, this will be estimated from the data.
If ``None`` and ``previous_model_json`` is provided with an internally estimated propensity
model, that model's propensity estimates are re-used rather than re-fitted.
rfx_group_ids_train : np.array, optional
Optional group labels used for an additive random effects model.
rfx_basis_train : np.array, optional
Expand Down Expand Up @@ -172,6 +174,11 @@ def sample(
JSON string containing a previous BCF model. This can be used to
"continue" a sampler interactively after inspecting the samples or
to run parallel chains "warm-started" from existing forest samples.
If the previous model used an internally estimated propensity score
(i.e. ``propensity_train`` was not supplied to that run), the fitted
propensity model is carried forward and re-used rather than being
re-estimated. This ensures that multi-chain warm-starts remain
consistent with the propensity scores used in the initial run.
Defaults to `None`.
previous_model_warmstart_sample_num : int, optional
Sample number from `previous_model_json` that will be used to
Expand Down Expand Up @@ -1465,6 +1472,22 @@ def sample(
)
propensity_covariate = "none"
self.internal_propensity_model = True
elif has_prev_model and previous_bcf_model.internal_propensity_model:
# Reuse the propensity model from the warm-started BCF model rather
# than re-fitting from scratch. Re-predict on the current (preprocessed)
# train and test sets so we don't rely on y_hat_train being present
# after JSON round-trip deserialization.
self.bart_propensity_model = previous_bcf_model.bart_propensity_model
propensity_train = np.expand_dims(
self.bart_propensity_model.predict(X=X_train_processed, terms="y_hat", type="mean"),
1,
)
if self.has_test:
propensity_test = np.expand_dims(
self.bart_propensity_model.predict(X=X_test_processed, terms="y_hat", type="mean"),
1,
)
self.internal_propensity_model = True
else:
self.bart_propensity_model = BARTModel()
num_gfr_propensity = 10
Expand All @@ -1480,11 +1503,9 @@ def sample(
num_mcmc=num_mcmc_propensity,
general_params={"random_seed": random_seed},
)
propensity_train = np.mean(
self.bart_propensity_model.y_hat_train, axis=1, keepdims=True
)
propensity_test = np.mean(
self.bart_propensity_model.y_hat_test, axis=1, keepdims=True
propensity_test = np.expand_dims(
self.bart_propensity_model.predict(X=X_test_processed, terms="y_hat", type="mean"),
1,
)
else:
self.bart_propensity_model.sample(
Expand All @@ -1495,9 +1516,10 @@ def sample(
num_mcmc=num_mcmc_propensity,
general_params={"random_seed": random_seed},
)
propensity_train = np.mean(
self.bart_propensity_model.y_hat_train, axis=1, keepdims=True
)
propensity_train = np.expand_dims(
self.bart_propensity_model.predict(X=X_train_processed, terms="y_hat", type="mean"),
1,
)
self.internal_propensity_model = True
else:
self.internal_propensity_model = False
Expand Down
52 changes: 52 additions & 0 deletions test/R/testthat/test-bcf.R
Original file line number Diff line number Diff line change
Expand Up @@ -997,3 +997,55 @@ test_that("BCF factor-valued treatment handling", {
regexp = "exactly 2 levels"
)
})

test_that("Warmstart BCF reuses internal propensity model", {
  skip_on_cran()

  # Simulate a confounded treatment: propensity depends on X[, 1]
  set.seed(42)
  n <- 100
  p <- 5
  X <- matrix(runif(n * p), ncol = p)
  true_propensity <- 0.25 + 0.5 * X[, 1]
  Z <- rbinom(n, 1, true_propensity)
  y <- true_propensity * 3 + X[, 2] * Z + rnorm(n, 0, 1)

  # 80/20 train-test split
  n_test <- 20
  n_train <- n - n_test
  X_train <- X[seq_len(n_train), ]
  X_test <- X[seq(n_train + 1, n), ]
  Z_train <- Z[seq_len(n_train)]
  Z_test <- Z[seq(n_train + 1, n)]
  y_train <- y[seq_len(n_train)]

  # No propensities supplied, so the first fit estimates them internally
  initial_fit <- bcf(
    X_train = X_train, Z_train = Z_train, y_train = y_train,
    X_test = X_test, Z_test = Z_test,
    num_gfr = 5, num_burnin = 0, num_mcmc = 10
  )
  expect_true(initial_fit$model_params$internal_propensity_model)

  # Train-set propensity predictions from the internally fitted BART
  initial_propensities <- predict(
    initial_fit$bart_propensity_model,
    X = X_train, terms = "y_hat", type = "mean"
  )

  # Warm-start a second sampler from the first model's serialized JSON;
  # the internal propensity model should be carried over, not re-estimated
  warm_fit <- bcf(
    X_train = X_train, Z_train = Z_train, y_train = y_train,
    X_test = X_test, Z_test = Z_test,
    num_gfr = 0, num_burnin = 0, num_mcmc = 10,
    previous_model_json = saveBCFModelToJsonString(initial_fit),
    previous_model_warmstart_sample_num = 10L
  )
  expect_true(warm_fit$model_params$internal_propensity_model)

  # Identical train-set predictions confirm the propensity model was reused
  warm_propensities <- predict(
    warm_fit$bart_propensity_model,
    X = X_train, terms = "y_hat", type = "mean"
  )
  expect_equal(initial_propensities, warm_propensities)

  # Sanity-check output dimensions of the warm-started fit
  expect_equal(dim(warm_fit$y_hat_train), c(n_train, 10))
  expect_equal(dim(warm_fit$y_hat_test), c(n_test, 10))
})
46 changes: 46 additions & 0 deletions test/python/test_bcf.py
Original file line number Diff line number Diff line change
Expand Up @@ -1068,3 +1068,49 @@ def test_internal_propensity_with_categorical_dataframe(self):
assert bcf_model.y_hat_test is not None
assert bcf_model.tau_hat_train is not None
assert bcf_model.tau_hat_test is not None

def test_warmstart_reuses_internal_propensity(self):
    """Warm-starting from a BCF model fitted without user-supplied
    propensities should reuse that model's internal propensity BART
    instead of re-fitting it from scratch."""
    # Simulate a confounded treatment: propensity depends on X[:, 0]
    rng = np.random.default_rng(7)
    n = 100
    X = rng.uniform(0, 1, (n, 5))
    true_propensity = 0.25 + 0.5 * X[:, 0]
    Z = rng.binomial(1, true_propensity, n).astype(float)
    y = true_propensity * 3 + X[:, 1] * Z + rng.normal(0, 1, n)

    # 80/20 train-test split
    n_train = 80
    X_train = X[:n_train]
    X_test = X[n_train:]
    Z_train = Z[:n_train]
    Z_test = Z[n_train:]
    y_train = y[:n_train]

    # No propensities supplied, so the first fit estimates them internally
    first_model = BCFModel()
    first_model.sample(
        X_train=X_train, Z_train=Z_train, y_train=y_train,
        X_test=X_test, Z_test=Z_test,
        num_gfr=5, num_burnin=0, num_mcmc=10,
        general_params={"random_seed": 1},
    )
    assert first_model.internal_propensity_model
    # Train-set propensity predictions via predict(), which is what the
    # warm-start path itself uses after the JSON round-trip
    first_propensities = first_model.bart_propensity_model.predict(
        X=X_train, terms="y_hat", type="mean"
    )

    # Warm-start a second sampler from the first model's JSON; the internal
    # propensity model should be carried over, not re-estimated
    warm_model = BCFModel()
    warm_model.sample(
        X_train=X_train, Z_train=Z_train, y_train=y_train,
        X_test=X_test, Z_test=Z_test,
        num_gfr=0, num_burnin=0, num_mcmc=10,
        previous_model_json=first_model.to_json(),
        previous_model_warmstart_sample_num=9,
        general_params={"random_seed": 2},
    )
    assert warm_model.internal_propensity_model
    # Identical predictions confirm the propensity model was carried over
    warm_propensities = warm_model.bart_propensity_model.predict(
        X_train, terms="y_hat", type="mean"
    )
    np.testing.assert_array_equal(first_propensities, warm_propensities)
    # Output shape sanity checks
    assert warm_model.y_hat_train.shape == (n_train, 10)
    assert warm_model.y_hat_test.shape == (n - n_train, 10)
Loading