Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
## New Features

* Added support for parametric treatment effect term in BCF [#309](https://github.com/StochasticTree/stochtree/pull/309/)
* Added support for observation-level weights passed as data arguments to BART and BCF [#333](https://github.com/StochasticTree/stochtree/pull/333)

## Bug Fixes

* Fixed multi-chain BCF bugs with the parametric intercept term in R and Python [#326](https://github.com/StochasticTree/stochtree/pull/326)
* Fixed indexing bugs for multivariate treatment BCF in Python [#326](https://github.com/StochasticTree/stochtree/pull/326)
* Convert binary factor-valued treatments to 0/1 binary numeric treatment in `bcf()` R function [#332](https://github.com/StochasticTree/stochtree/pull/332)

## Documentation and Other Maintenance

Expand Down
56 changes: 53 additions & 3 deletions R/bart.R
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,11 @@ NULL
#' We do not currently support (but plan to in the near future) test set evaluation for group labels
#' that were not in the training set.
#' @param rfx_basis_test (Optional) Test set basis for "random-slope" regression in additive random effects model.
#' @param observation_weights (Optional) Numeric vector of observation weights of length `nrow(X_train)`. Weights are
#' applied as `y_i | - ~ N(mu(X_i), sigma^2 / w_i)`, so larger weights increase an observation's influence on the fit.
#' All weights must be non-negative. Default: `NULL` (all observations equally weighted). Compatible with Gaussian
#' (continuous/identity) and probit outcome models; not compatible with cloglog link functions. Note: these are
#' referred to internally in the C++ layer as "variance weights" (`var_weights`), since they scale the residual variance.
#' @param num_gfr Number of "warm-start" iterations run using the grow-from-root algorithm (He and Hahn, 2021). Default: 5.
#' @param num_burnin Number of "burn-in" iterations of the MCMC sampler. Default: 0.
#' @param num_mcmc Number of "retained" iterations of the MCMC sampler. Default: 100.
Expand Down Expand Up @@ -189,6 +194,7 @@ bart <- function(
leaf_basis_test = NULL,
rfx_group_ids_test = NULL,
rfx_basis_test = NULL,
observation_weights = NULL,
num_gfr = 5,
num_burnin = 0,
num_mcmc = 100,
Expand Down Expand Up @@ -507,6 +513,20 @@ bart <- function(
include_mean_forest = FALSE
}

# observation_weights compatibility checks
if (!is.null(observation_weights)) {
if (link_is_cloglog) {
stop(
"observation_weights are not compatible with cloglog link functions."
)
}
if (include_variance_forest) {
warning(
"Results may be unreliable when observation_weights are deployed alongside a variance forest model."
)
}
}

# Set the variance forest priors if not set
if (include_variance_forest) {
if (is.null(a_forest)) {
Expand All @@ -531,6 +551,26 @@ bart <- function(
stop("variable_weights cannot have any negative weights")
}

# Observation weight validation
if (!is.null(observation_weights)) {
if (!is.numeric(observation_weights)) {
stop("observation_weights must be a numeric vector")
}
if (length(observation_weights) != nrow(X_train)) {
stop("length(observation_weights) must equal nrow(X_train)")
}
if (any(observation_weights < 0)) {
stop("observation_weights cannot have any negative values")
}
if (all(observation_weights == 0) && num_gfr > 0) {
stop(
"observation_weights are all zero (prior sampling mode) but num_gfr > 0. ",
"GFR warm-start is data-dependent and ill-defined with zero weights. ",
"Set num_gfr = 0 when using all-zero observation_weights."
)
}
}

# Check covariates are matrix or dataframe
if ((!is.data.frame(X_train)) && (!is.matrix(X_train))) {
stop("X_train must be a matrix or dataframe")
Expand Down Expand Up @@ -1217,13 +1257,20 @@ bart <- function(

# Data
if (leaf_regression) {
forest_dataset_train <- createForestDataset(X_train, leaf_basis_train)
forest_dataset_train <- createForestDataset(
X_train,
leaf_basis_train,
observation_weights
)
if (has_test) {
forest_dataset_test <- createForestDataset(X_test, leaf_basis_test)
}
requires_basis <- TRUE
} else {
forest_dataset_train <- createForestDataset(X_train)
forest_dataset_train <- createForestDataset(
X_train,
variance_weights = observation_weights
)
if (has_test) {
forest_dataset_test <- createForestDataset(X_test)
}
Expand Down Expand Up @@ -4460,7 +4507,10 @@ createBARTModelFromCombinedJsonString <- function(json_string_list) {
"outcome",
"outcome_model"
)
outcome_model_link <- json_object_default$get_string("link", "outcome_model")
outcome_model_link <- json_object_default$get_string(
"link",
"outcome_model"
)
} else {
outcome_model_outcome <- "continuous"
outcome_model_link <- "identity"
Expand Down
136 changes: 115 additions & 21 deletions R/bcf.R
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,11 @@ NULL
#' We do not currently support (but plan to in the near future) test set evaluation for group labels
#' that were not in the training set.
#' @param rfx_basis_test (Optional) Test set basis for "random-slope" regression in additive random effects model.
#' @param observation_weights (Optional) Numeric vector of observation weights of length `nrow(X_train)`. Weights are
#' applied as `y_i | - ~ N(mu(X_i), sigma^2 / w_i)`, so larger weights increase an observation's influence on the fit.
#' All weights must be non-negative. Default: `NULL` (all observations equally weighted). Applied to both the
#' prognostic and treatment effect forests. Compatible with Gaussian (continuous/identity) and probit outcome models;
#' not compatible with cloglog link functions.
#' @param num_gfr Number of "warm-start" iterations run using the grow-from-root algorithm (He and Hahn, 2021). Default: 5.
#' @param num_burnin Number of "burn-in" iterations of the MCMC sampler. Default: 0.
#' @param num_mcmc Number of "retained" iterations of the MCMC sampler. Default: 100.
Expand Down Expand Up @@ -237,6 +242,7 @@ bcf <- function(
propensity_test = NULL,
rfx_group_ids_test = NULL,
rfx_basis_test = NULL,
observation_weights = NULL,
num_gfr = 5,
num_burnin = 0,
num_mcmc = 100,
Expand Down Expand Up @@ -616,6 +622,36 @@ bcf <- function(
include_variance_forest = FALSE
}

# observation_weights validation and compatibility checks
if (!is.null(observation_weights)) {
if (!is.numeric(observation_weights)) {
stop("observation_weights must be a numeric vector")
}
if (length(observation_weights) != nrow(X_train)) {
stop("length(observation_weights) must equal nrow(X_train)")
}
if (any(observation_weights < 0)) {
stop("observation_weights cannot have any negative values")
}
if (all(observation_weights == 0) && num_gfr > 0) {
stop(
"observation_weights are all zero (prior sampling mode) but num_gfr > 0. ",
"GFR warm-start is data-dependent and ill-defined with zero weights. ",
"Set num_gfr = 0 when using all-zero observation_weights."
)
}
if (link_is_cloglog) {
stop(
"observation_weights are not compatible with cloglog link functions."
)
}
if (include_variance_forest) {
warning(
"Results may be unreliable when observation_weights are deployed alongside a variance forest model."
)
}
}

# Set the variance forest priors if not set
if (include_variance_forest) {
if (is.null(a_forest)) {
Expand Down Expand Up @@ -933,7 +969,10 @@ bcf <- function(
}
message(
"Z_train is a factor; converting to 0/1 using level order: ",
lvls[1], " = 0, ", lvls[2], " = 1"
lvls[1],
" = 0, ",
lvls[2],
" = 1"
)
Z_train <- as.integer(Z_train) - 1L
}
Expand All @@ -944,7 +983,10 @@ bcf <- function(
}
message(
"Z_test is a factor; converting to 0/1 using level order: ",
lvls[1], " = 0, ", lvls[2], " = 1"
lvls[1],
" = 0, ",
lvls[2],
" = 1"
)
Z_test <- as.integer(Z_test) - 1L
}
Expand Down Expand Up @@ -1681,7 +1723,11 @@ bcf <- function(
}

# Data
forest_dataset_train <- createForestDataset(X_train, tau_basis_train)
forest_dataset_train <- createForestDataset(
X_train,
tau_basis_train,
observation_weights
)
if (has_test) {
forest_dataset_test <- createForestDataset(X_test, tau_basis_test)
}
Expand Down Expand Up @@ -2360,7 +2406,9 @@ bcf <- function(
if (sample_tau_0 && !is.null(previous_tau_0_samples)) {
tau_0_old <- tau_0
# previous model stores tau_0 in original scale; convert to standardized scale
tau_0 <- as.numeric(previous_tau_0_samples[, warmstart_index] / previous_y_scale)
tau_0 <- as.numeric(
previous_tau_0_samples[, warmstart_index] / previous_y_scale
)
Z_basis_ws <- as.matrix(tau_basis_train)
outcome_train$subtract_vector(
as.numeric(Z_basis_ws %*% matrix(tau_0 - tau_0_old, ncol = 1))
Expand Down Expand Up @@ -3456,7 +3504,10 @@ predict.bcfmodel <- function(
}
warning(
"Z is a factor; recoding to 0/1 using level order: ",
lvls[1], " = 0, ", lvls[2], " = 1"
lvls[1],
" = 0, ",
lvls[2],
" = 1"
)
Z <- as.integer(Z) - 1L
}
Expand Down Expand Up @@ -4820,7 +4871,9 @@ createBCFModelFromJson <- function(json_object) {

# Version inference and presence-check helpers
.ver <- inferStochtreeJsonVersion(json_object)
has_field <- function(name) json_contains_field_cpp(json_object$json_ptr, name)
has_field <- function(name) {
json_contains_field_cpp(json_object$json_ptr, name)
}
has_subfolder_field <- function(subfolder, name) {
json_contains_field_subfolder_cpp(json_object$json_ptr, subfolder, name)
}
Expand Down Expand Up @@ -5036,11 +5089,23 @@ createBCFModelFromJson <- function(json_object) {
}
if (model_params[["adaptive_coding"]]) {
if (has_subfolder_field("parameters", "b1_samples")) {
output[["b_1_samples"]] <- json_object$get_vector("b1_samples", "parameters")
output[["b_0_samples"]] <- json_object$get_vector("b0_samples", "parameters")
output[["b_1_samples"]] <- json_object$get_vector(
"b1_samples",
"parameters"
)
output[["b_0_samples"]] <- json_object$get_vector(
"b0_samples",
"parameters"
)
} else {
output[["b_1_samples"]] <- json_object$get_vector("b_1_samples", "parameters")
output[["b_0_samples"]] <- json_object$get_vector("b_0_samples", "parameters")
output[["b_1_samples"]] <- json_object$get_vector(
"b_1_samples",
"parameters"
)
output[["b_0_samples"]] <- json_object$get_vector(
"b_0_samples",
"parameters"
)
warning(sprintf(
"JSON fields 'b_1_samples'/'b_0_samples' are deprecated; please re-save the model to use 'b1_samples'/'b0_samples' (inferred version: %s).",
.ver
Expand Down Expand Up @@ -5133,9 +5198,15 @@ createBCFModelFromCombinedJson <- function(json_object_list) {

# Version inference and presence-check helpers
.ver <- inferStochtreeJsonVersion(json_object_default)
has_field <- function(name) json_contains_field_cpp(json_object_default$json_ptr, name)
has_field <- function(name) {
json_contains_field_cpp(json_object_default$json_ptr, name)
}
has_subfolder_field <- function(subfolder, name) {
json_contains_field_subfolder_cpp(json_object_default$json_ptr, subfolder, name)
json_contains_field_subfolder_cpp(
json_object_default$json_ptr,
subfolder,
name
)
}

# Unpack the forests
Expand Down Expand Up @@ -5209,9 +5280,13 @@ createBCFModelFromCombinedJson <- function(json_object_list) {
"standardize"
)
if (has_field("sigma2_init")) {
model_params[["initial_sigma2"]] <- json_object_default$get_scalar("sigma2_init")
model_params[["initial_sigma2"]] <- json_object_default$get_scalar(
"sigma2_init"
)
} else {
model_params[["initial_sigma2"]] <- json_object_default$get_scalar("initial_sigma2")
model_params[["initial_sigma2"]] <- json_object_default$get_scalar(
"initial_sigma2"
)
warning(sprintf(
"JSON field 'initial_sigma2' is deprecated; please re-save the model to use 'sigma2_init' (inferred version: %s).",
.ver
Expand Down Expand Up @@ -5319,7 +5394,10 @@ createBCFModelFromCombinedJson <- function(json_object_list) {
"outcome",
"outcome_model"
)
outcome_model_link <- json_object_default$get_string("link", "outcome_model")
outcome_model_link <- json_object_default$get_string(
"link",
"outcome_model"
)
} else {
outcome_model_outcome <- "continuous"
outcome_model_link <- "identity"
Expand Down Expand Up @@ -5551,7 +5629,10 @@ createBCFModelFromCombinedJsonString <- function(json_string_list) {
# We don't support merging BCF models with independent propensity models
# this way at the moment
if (
json_contains_field_cpp(json_object_list[[i]]$json_ptr, "internal_propensity_model") &&
json_contains_field_cpp(
json_object_list[[i]]$json_ptr,
"internal_propensity_model"
) &&
json_object_list[[i]]$get_boolean("internal_propensity_model")
) {
stop(
Expand All @@ -5566,9 +5647,15 @@ createBCFModelFromCombinedJsonString <- function(json_string_list) {

# Version inference and presence-check helpers
.ver <- inferStochtreeJsonVersion(json_object_default)
has_field <- function(name) json_contains_field_cpp(json_object_default$json_ptr, name)
has_field <- function(name) {
json_contains_field_cpp(json_object_default$json_ptr, name)
}
has_subfolder_field <- function(subfolder, name) {
json_contains_field_subfolder_cpp(json_object_default$json_ptr, subfolder, name)
json_contains_field_subfolder_cpp(
json_object_default$json_ptr,
subfolder,
name
)
}

# Unpack the forests
Expand Down Expand Up @@ -5642,9 +5729,13 @@ createBCFModelFromCombinedJsonString <- function(json_string_list) {
"standardize"
)
if (has_field("sigma2_init")) {
model_params[["initial_sigma2"]] <- json_object_default$get_scalar("sigma2_init")
model_params[["initial_sigma2"]] <- json_object_default$get_scalar(
"sigma2_init"
)
} else {
model_params[["initial_sigma2"]] <- json_object_default$get_scalar("initial_sigma2")
model_params[["initial_sigma2"]] <- json_object_default$get_scalar(
"initial_sigma2"
)
warning(sprintf(
"JSON field 'initial_sigma2' is deprecated; please re-save the model to use 'sigma2_init' (inferred version: %s).",
.ver
Expand Down Expand Up @@ -5752,7 +5843,10 @@ createBCFModelFromCombinedJsonString <- function(json_string_list) {
"outcome",
"outcome_model"
)
outcome_model_link <- json_object_default$get_string("link", "outcome_model")
outcome_model_link <- json_object_default$get_string(
"link",
"outcome_model"
)
} else {
outcome_model_outcome <- "continuous"
outcome_model_link <- "identity"
Expand Down
7 changes: 7 additions & 0 deletions man/bart.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading
Loading