Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 19 additions & 9 deletions R/bart.R
Original file line number Diff line number Diff line change
Expand Up @@ -1058,7 +1058,9 @@ bart <- function(
# Handle standardization, prior calibration, and initialization of forest
# differently for binary and continuous outcomes
if (link_is_probit) {
# Compute a probit-scale offset and fix scale to 1
# Probit-scale intercept: center the forest on the population-average latent mean.
# The forest predicts mu(X) and y_bar_train is added back at prediction time.
# The latent z sampling uses y_bar_train to set the correct truncated normal mean and to center z before the residual update.
y_bar_train <- qnorm(mean_cpp(as.numeric(y_train)))
y_std_train <- 1
standardize <- FALSE
Expand Down Expand Up @@ -1591,6 +1593,10 @@ bart <- function(
if (include_mean_forest) {
if (link_is_probit) {
# Sample latent probit variable, z | -
# outcome_pred is the centered forest prediction (not including y_bar_train).
# The truncated normal mean is outcome_pred + y_bar_train (the full eta on the probit scale).
# The residual stored is z - y_bar_train - outcome_pred so the forest sees a
# zero-centered signal and the prior shrinkage toward 0 is well-calibrated.
outcome_pred <- active_forest_mean$predict(
forest_dataset_train
)
Expand All @@ -1601,15 +1607,16 @@ bart <- function(
)
outcome_pred <- outcome_pred + rfx_pred
}
mu0 <- outcome_pred[y_train == 0]
mu1 <- outcome_pred[y_train == 1]
eta_pred <- outcome_pred + y_bar_train
mu0 <- eta_pred[y_train == 0]
mu1 <- eta_pred[y_train == 1]
u0 <- runif(sum(y_train == 0), 0, pnorm(0 - mu0))
u1 <- runif(sum(y_train == 1), pnorm(0 - mu1), 1)
resid_train[y_train == 0] <- mu0 + qnorm(u0)
resid_train[y_train == 1] <- mu1 + qnorm(u1)

# Update outcome
outcome_train$update_data(resid_train - outcome_pred)
# Update outcome: center z by y_bar_train before passing to forest
outcome_train$update_data(resid_train - y_bar_train - outcome_pred)
}

# Sample mean forest
Expand Down Expand Up @@ -2127,15 +2134,18 @@ bart <- function(
)
outcome_pred <- outcome_pred + rfx_pred
}
mu0 <- outcome_pred[y_train == 0]
mu1 <- outcome_pred[y_train == 1]
eta_pred <- outcome_pred + y_bar_train
mu0 <- eta_pred[y_train == 0]
mu1 <- eta_pred[y_train == 1]
u0 <- runif(sum(y_train == 0), 0, pnorm(0 - mu0))
u1 <- runif(sum(y_train == 1), pnorm(0 - mu1), 1)
resid_train[y_train == 0] <- mu0 + qnorm(u0)
resid_train[y_train == 1] <- mu1 + qnorm(u1)

# Update outcome
outcome_train$update_data(resid_train - outcome_pred)
# Update outcome: center z by y_bar_train before passing to forest
outcome_train$update_data(
resid_train - y_bar_train - outcome_pred
)
}

forest_model_mean$sample_one_iteration(
Expand Down
62 changes: 49 additions & 13 deletions R/bcf.R
Original file line number Diff line number Diff line change
Expand Up @@ -1475,7 +1475,9 @@ bcf <- function(
# Handle standardization, prior calibration, and initialization of forest
# differently for binary and continuous outcomes
if (link_is_probit) {
# Compute a probit-scale offset and fix scale to 1
# Probit-scale intercept: center the forests on the population-average latent mean.
# The forests predict mu(X) + tau(X) * Z and y_bar_train is added back at prediction time.
# The latent z sampling uses y_bar_train to set the correct truncated normal mean and to center z before the residual update.
y_bar_train <- qnorm(mean_cpp(as.numeric(y_train)))
y_std_train <- 1

Expand Down Expand Up @@ -1948,6 +1950,10 @@ bcf <- function(

if (link_is_probit) {
# Sample latent probit variable, z | -
# outcome_pred is the centered forest prediction (not including y_bar_train).
# The truncated normal mean is outcome_pred + y_bar_train (the full eta on the probit scale).
# The residual stored is z - y_bar_train - outcome_pred so the forests see a
# zero-centered signal and the prior shrinkage toward 0 is well-calibrated.
mu_forest_pred <- active_forest_mu$predict(forest_dataset_train)
tau_forest_pred <- active_forest_tau$predict(
forest_dataset_train
Expand All @@ -1960,15 +1966,16 @@ bcf <- function(
)
outcome_pred <- outcome_pred + rfx_pred
}
mu0 <- outcome_pred[y_train == 0]
mu1 <- outcome_pred[y_train == 1]
eta_pred <- outcome_pred + y_bar_train
mu0 <- eta_pred[y_train == 0]
mu1 <- eta_pred[y_train == 1]
u0 <- runif(sum(y_train == 0), 0, pnorm(0 - mu0))
u1 <- runif(sum(y_train == 1), pnorm(0 - mu1), 1)
resid_train[y_train == 0] <- mu0 + qnorm(u0)
resid_train[y_train == 1] <- mu1 + qnorm(u1)

# Update outcome
outcome_train$update_data(resid_train - outcome_pred)
# Update outcome: center z by y_bar_train before passing to forests
outcome_train$update_data(resid_train - y_bar_train - outcome_pred)
}

# Sample the prognostic forest
Expand Down Expand Up @@ -2028,7 +2035,14 @@ bcf <- function(
Z_basis_mat <- as.matrix(tau_basis_train)
# tau(X) * basis contribution per observation
tau_x_full <- rowSums(Z_basis_mat * as.matrix(tau_x_raw_tau0))
partial_resid_tau0 <- resid_train -
# For probit, resid_train holds the full-scale latent z; center it so that
# tau_0 does not absorb the probit intercept y_bar_train.
resid_for_tau0 <- if (link_is_probit) {
resid_train - y_bar_train
} else {
resid_train
}
partial_resid_tau0 <- resid_for_tau0 -
as.numeric(mu_x_raw_tau0) -
tau_x_full
if (has_rfx) {
Expand Down Expand Up @@ -2087,7 +2101,14 @@ bcf <- function(
tau_x_raw_train <- active_forest_tau$predict_raw(
forest_dataset_train
)
partial_resid_mu_train <- resid_train - mu_x_raw_train
# For probit, resid_train holds full-scale z; center it so b_0/b_1 do not
# absorb the probit intercept y_bar_train.
resid_for_coding <- if (link_is_probit) {
resid_train - y_bar_train
} else {
resid_train
}
partial_resid_mu_train <- resid_for_coding - mu_x_raw_train
if (has_rfx) {
rfx_preds_train <- rfx_model$predict(
rfx_dataset_train,
Expand Down Expand Up @@ -2698,15 +2719,16 @@ bcf <- function(
)
outcome_pred <- outcome_pred + rfx_pred
}
mu0 <- outcome_pred[y_train == 0]
mu1 <- outcome_pred[y_train == 1]
eta_pred <- outcome_pred + y_bar_train
mu0 <- eta_pred[y_train == 0]
mu1 <- eta_pred[y_train == 1]
u0 <- runif(sum(y_train == 0), 0, pnorm(0 - mu0))
u1 <- runif(sum(y_train == 1), pnorm(0 - mu1), 1)
resid_train[y_train == 0] <- mu0 + qnorm(u0)
resid_train[y_train == 1] <- mu1 + qnorm(u1)

# Update outcome
outcome_train$update_data(resid_train - outcome_pred)
# Update outcome: center z by y_bar_train before passing to forests
outcome_train$update_data(resid_train - y_bar_train - outcome_pred)
}

# Sample the prognostic forest
Expand Down Expand Up @@ -2768,7 +2790,14 @@ bcf <- function(
Z_basis_mat <- as.matrix(tau_basis_train)
# tau(X) * basis contribution per observation
tau_x_full <- rowSums(Z_basis_mat * as.matrix(tau_x_raw_tau0))
partial_resid_tau0 <- resid_train -
# For probit, resid_train holds the full-scale latent z; center it so that
# tau_0 does not absorb the probit intercept y_bar_train.
resid_for_tau0 <- if (link_is_probit) {
resid_train - y_bar_train
} else {
resid_train
}
partial_resid_tau0 <- resid_for_tau0 -
as.numeric(mu_x_raw_tau0) -
tau_x_full
if (has_rfx) {
Expand Down Expand Up @@ -2827,7 +2856,14 @@ bcf <- function(
tau_x_raw_train <- active_forest_tau$predict_raw(
forest_dataset_train
)
partial_resid_mu_train <- resid_train - mu_x_raw_train
# For probit, resid_train holds full-scale z; center it so b_0/b_1 do not
# absorb the probit intercept y_bar_train.
resid_for_coding <- if (link_is_probit) {
resid_train - y_bar_train
} else {
resid_train
}
partial_resid_mu_train <- resid_for_coding - mu_x_raw_train
if (has_rfx) {
rfx_preds_train <- rfx_model$predict(
rfx_dataset_train,
Expand Down
20 changes: 12 additions & 8 deletions stochtree/bart.py
Original file line number Diff line number Diff line change
Expand Up @@ -1560,8 +1560,10 @@ def sample(
if self.has_rfx:
rfx_pred = rfx_model.predict(rfx_dataset_train, rfx_tracker)
outcome_pred = outcome_pred + rfx_pred
mu0 = outcome_pred[y_train[:, 0] == 0]
mu1 = outcome_pred[y_train[:, 0] == 1]
# Full probit-scale predictor: forest learns z - y_bar, so add y_bar back
eta_pred = outcome_pred + self.y_bar
mu0 = eta_pred[y_train[:, 0] == 0]
mu1 = eta_pred[y_train[:, 0] == 1]
n0 = np.sum(y_train[:, 0] == 0)
n1 = np.sum(y_train[:, 0] == 1)
u0 = self.rng.uniform(
Expand All @@ -1577,8 +1579,8 @@ def sample(
resid_train[y_train[:, 0] == 0, 0] = mu0 + norm.ppf(u0)
resid_train[y_train[:, 0] == 1, 0] = mu1 + norm.ppf(u1)

# Update outcome
new_outcome = np.squeeze(resid_train) - outcome_pred
# Update outcome: center z by y_bar before passing to forest
new_outcome = np.squeeze(resid_train) - self.y_bar - outcome_pred
residual_train.update_data(new_outcome)

# Sample the mean forest
Expand Down Expand Up @@ -1885,8 +1887,10 @@ def sample(
rfx_dataset_train, rfx_tracker
)
outcome_pred = outcome_pred + rfx_pred
mu0 = outcome_pred[y_train[:, 0] == 0]
mu1 = outcome_pred[y_train[:, 0] == 1]
# Full probit-scale predictor: forest learns z - y_bar, so add y_bar back
eta_pred = outcome_pred + self.y_bar
mu0 = eta_pred[y_train[:, 0] == 0]
mu1 = eta_pred[y_train[:, 0] == 1]
n0 = np.sum(y_train[:, 0] == 0)
n1 = np.sum(y_train[:, 0] == 1)
u0 = self.rng.uniform(
Expand All @@ -1902,8 +1906,8 @@ def sample(
resid_train[y_train[:, 0] == 0, 0] = mu0 + norm.ppf(u0)
resid_train[y_train[:, 0] == 1, 0] = mu1 + norm.ppf(u1)

# Update outcome
new_outcome = np.squeeze(resid_train) - outcome_pred
# Update outcome: center z by y_bar before passing to forest
new_outcome = np.squeeze(resid_train) - self.y_bar - outcome_pred
residual_train.update_data(new_outcome)

# Sample the mean forest
Expand Down
39 changes: 27 additions & 12 deletions stochtree/bcf.py
Original file line number Diff line number Diff line change
Expand Up @@ -2121,14 +2121,19 @@ def sample(

if link_is_probit:
# Sample latent probit variable z | -
# outcome_pred is the centered forest prediction (not including self.y_bar).
# The truncated normal mean is outcome_pred + self.y_bar (the full eta on the probit scale).
# The residual stored is z - self.y_bar - outcome_pred so the forests see a
# zero-centered signal and the prior shrinkage toward 0 is well-calibrated.
forest_pred_mu = active_forest_mu.predict(forest_dataset_train)
forest_pred_tau = active_forest_tau.predict(forest_dataset_train)
outcome_pred = forest_pred_mu + forest_pred_tau
if self.has_rfx:
rfx_pred = rfx_model.predict(rfx_dataset_train, rfx_tracker)
outcome_pred = outcome_pred + rfx_pred
mu0 = outcome_pred[y_train[:, 0] == 0]
mu1 = outcome_pred[y_train[:, 0] == 1]
eta_pred = outcome_pred + self.y_bar
mu0 = eta_pred[y_train[:, 0] == 0]
mu1 = eta_pred[y_train[:, 0] == 1]
n0 = np.sum(y_train[:, 0] == 0)
n1 = np.sum(y_train[:, 0] == 1)
u0 = self.rng.uniform(
Expand All @@ -2144,8 +2149,8 @@ def sample(
resid_train[y_train[:, 0] == 0, 0] = mu0 + norm.ppf(u0)
resid_train[y_train[:, 0] == 1, 0] = mu1 + norm.ppf(u1)

# Update outcome
new_outcome = np.squeeze(resid_train) - outcome_pred
# Update outcome: center z by y_bar before passing to forests
new_outcome = (np.squeeze(resid_train) - self.y_bar) - outcome_pred
residual_train.update_data(new_outcome)

# Sample the prognostic forest
Expand Down Expand Up @@ -2195,7 +2200,9 @@ def sample(
Z_basis = tau_basis_train.reshape(-1, 1) if tau_basis_train.ndim == 1 else tau_basis_train
tau_x_raw_2d = tau_x_raw_tau0.reshape(self.n_train, -1)
tau_x_full = np.sum(Z_basis * tau_x_raw_2d, axis=1)
partial_resid_tau0 = np.squeeze(resid_train) - mu_x_tau0 - tau_x_full
# Center z by y_bar so tau_0 does not absorb the probit intercept
resid_for_tau0 = (np.squeeze(resid_train) - self.y_bar) if link_is_probit else np.squeeze(resid_train)
partial_resid_tau0 = resid_for_tau0 - mu_x_tau0 - tau_x_full
if self.has_rfx:
partial_resid_tau0 = partial_resid_tau0 - np.squeeze(
rfx_model.predict(rfx_dataset_train, rfx_tracker)
Expand Down Expand Up @@ -2234,7 +2241,9 @@ def sample(
tau_x = np.squeeze(
active_forest_tau.predict_raw(forest_dataset_train)
)
partial_resid_train = np.squeeze(resid_train - mu_x)
# Center z by y_bar so coding regression does not absorb the probit intercept
resid_for_coding = (resid_train - self.y_bar) if link_is_probit else resid_train
partial_resid_train = np.squeeze(resid_for_coding - mu_x)
if self.has_rfx:
rfx_pred = np.squeeze(
rfx_model.predict(rfx_dataset_train, rfx_tracker)
Expand Down Expand Up @@ -2697,8 +2706,10 @@ def sample(
if self.has_rfx:
rfx_pred = rfx_model.predict(rfx_dataset_train, rfx_tracker)
outcome_pred = outcome_pred + rfx_pred
mu0 = outcome_pred[y_train[:, 0] == 0]
mu1 = outcome_pred[y_train[:, 0] == 1]
# Full probit-scale predictor: forests learn z - y_bar, so add y_bar back
eta_pred = outcome_pred + self.y_bar
mu0 = eta_pred[y_train[:, 0] == 0]
mu1 = eta_pred[y_train[:, 0] == 1]
n0 = np.sum(y_train[:, 0] == 0)
n1 = np.sum(y_train[:, 0] == 1)
u0 = self.rng.uniform(
Expand All @@ -2714,8 +2725,8 @@ def sample(
resid_train[y_train[:, 0] == 0, 0] = mu0 + norm.ppf(u0)
resid_train[y_train[:, 0] == 1, 0] = mu1 + norm.ppf(u1)

# Update outcome
new_outcome = np.squeeze(resid_train) - outcome_pred
# Update outcome: center z by y_bar before passing to forests
new_outcome = np.squeeze(resid_train) - self.y_bar - outcome_pred
residual_train.update_data(new_outcome)

# Sample the prognostic forest
Expand Down Expand Up @@ -2765,7 +2776,9 @@ def sample(
Z_basis = tau_basis_train.reshape(-1, 1) if tau_basis_train.ndim == 1 else tau_basis_train
tau_x_raw_2d = tau_x_raw_tau0.reshape(self.n_train, -1)
tau_x_full = np.sum(Z_basis * tau_x_raw_2d, axis=1)
partial_resid_tau0 = np.squeeze(resid_train) - mu_x_tau0 - tau_x_full
# Center z by y_bar so tau_0 does not absorb the probit intercept
resid_for_tau0 = (np.squeeze(resid_train) - self.y_bar) if link_is_probit else np.squeeze(resid_train)
partial_resid_tau0 = resid_for_tau0 - mu_x_tau0 - tau_x_full
if self.has_rfx:
partial_resid_tau0 = partial_resid_tau0 - np.squeeze(
rfx_model.predict(rfx_dataset_train, rfx_tracker)
Expand Down Expand Up @@ -2804,7 +2817,9 @@ def sample(
tau_x = np.squeeze(
active_forest_tau.predict_raw(forest_dataset_train)
)
partial_resid_train = np.squeeze(resid_train - mu_x)
# Center z by y_bar so coding regression does not absorb the probit intercept
resid_for_coding = (resid_train - self.y_bar) if link_is_probit else resid_train
partial_resid_train = np.squeeze(resid_for_coding - mu_x)
if self.has_rfx:
rfx_pred = np.squeeze(
rfx_model.predict(rfx_dataset_train, rfx_tracker)
Expand Down
Loading