diff --git a/R/bart.R b/R/bart.R index 63aa06da..71107d0b 100644 --- a/R/bart.R +++ b/R/bart.R @@ -1058,7 +1058,9 @@ bart <- function( # Handle standardization, prior calibration, and initialization of forest # differently for binary and continuous outcomes if (link_is_probit) { - # Compute a probit-scale offset and fix scale to 1 + # Probit-scale intercept: center the forest on the population-average latent mean. + # The forest predicts mu(X) and y_bar_train is added back at prediction time. + # The latent z sampling uses y_bar_train to set the correct truncated normal mean and to center z before the residual update. y_bar_train <- qnorm(mean_cpp(as.numeric(y_train))) y_std_train <- 1 standardize <- FALSE @@ -1591,6 +1593,10 @@ bart <- function( if (include_mean_forest) { if (link_is_probit) { # Sample latent probit variable, z | - + # outcome_pred is the centered forest prediction (not including y_bar_train). + # The truncated normal mean is outcome_pred + y_bar_train (the full eta on the probit scale). + # The residual stored is z - y_bar_train - outcome_pred so the forest sees a + # zero-centered signal and the prior shrinkage toward 0 is well-calibrated. outcome_pred <- active_forest_mean$predict( forest_dataset_train ) @@ -1601,15 +1607,16 @@ bart <- function( ) outcome_pred <- outcome_pred + rfx_pred } - mu0 <- outcome_pred[y_train == 0] - mu1 <- outcome_pred[y_train == 1] + eta_pred <- outcome_pred + y_bar_train + mu0 <- eta_pred[y_train == 0] + mu1 <- eta_pred[y_train == 1] u0 <- runif(sum(y_train == 0), 0, pnorm(0 - mu0)) u1 <- runif(sum(y_train == 1), pnorm(0 - mu1), 1) resid_train[y_train == 0] <- mu0 + qnorm(u0) resid_train[y_train == 1] <- mu1 + qnorm(u1) - # Update outcome - outcome_train$update_data(resid_train - outcome_pred) + # Update outcome: center z by y_bar_train before passing to forest + outcome_train$update_data(resid_train - y_bar_train - outcome_pred) } # Sample mean forest @@ -2127,15 +2134,18 @@ bart <- function( ) outcome_pred <- outcome_pred + rfx_pred } - mu0 <- outcome_pred[y_train == 0] - mu1 <- outcome_pred[y_train == 1] + eta_pred <- outcome_pred + y_bar_train + mu0 <- eta_pred[y_train == 0] + mu1 <- eta_pred[y_train == 1] u0 <- runif(sum(y_train == 0), 0, pnorm(0 - mu0)) u1 <- runif(sum(y_train == 1), pnorm(0 - mu1), 1) resid_train[y_train == 0] <- mu0 + qnorm(u0) resid_train[y_train == 1] <- mu1 + qnorm(u1) - # Update outcome - outcome_train$update_data(resid_train - outcome_pred) + # Update outcome: center z by y_bar_train before passing to forest + outcome_train$update_data( + resid_train - y_bar_train - outcome_pred + ) } forest_model_mean$sample_one_iteration( diff --git a/R/bcf.R b/R/bcf.R index b7af4f6f..0ab76098 100644 --- a/R/bcf.R +++ b/R/bcf.R @@ -1475,7 +1475,9 @@ bcf <- function( # Handle standardization, prior calibration, and initialization of forest # differently for binary and continuous outcomes if (link_is_probit) { - # Compute a probit-scale offset and fix scale to 1 + # Probit-scale intercept: center the forest on the population-average latent mean. + # The forest predicts mu(X) and y_bar_train is added back at prediction time. + # The latent z sampling uses y_bar_train to set the correct truncated normal mean and to center z before the residual update. y_bar_train <- qnorm(mean_cpp(as.numeric(y_train))) y_std_train <- 1 @@ -1948,6 +1950,10 @@ bcf <- function( if (link_is_probit) { # Sample latent probit variable, z | - + # outcome_pred is the centered forest prediction (not including y_bar_train). + # The truncated normal mean is outcome_pred + y_bar_train (the full eta on the probit scale). + # The residual stored is z - y_bar_train - outcome_pred so the forest sees a + # zero-centered signal and the prior shrinkage toward 0 is well-calibrated. mu_forest_pred <- active_forest_mu$predict(forest_dataset_train) tau_forest_pred <- active_forest_tau$predict( forest_dataset_train @@ -1960,15 +1966,16 @@ bcf <- function( ) outcome_pred <- outcome_pred + rfx_pred } - mu0 <- outcome_pred[y_train == 0] - mu1 <- outcome_pred[y_train == 1] + eta_pred <- outcome_pred + y_bar_train + mu0 <- eta_pred[y_train == 0] + mu1 <- eta_pred[y_train == 1] u0 <- runif(sum(y_train == 0), 0, pnorm(0 - mu0)) u1 <- runif(sum(y_train == 1), pnorm(0 - mu1), 1) resid_train[y_train == 0] <- mu0 + qnorm(u0) resid_train[y_train == 1] <- mu1 + qnorm(u1) - # Update outcome - outcome_train$update_data(resid_train - outcome_pred) + # Update outcome: center z by y_bar_train before passing to forests + outcome_train$update_data(resid_train - y_bar_train - outcome_pred) } # Sample the prognostic forest @@ -2028,7 +2035,14 @@ bcf <- function( Z_basis_mat <- as.matrix(tau_basis_train) # tau(X) * basis contribution per observation tau_x_full <- rowSums(Z_basis_mat * as.matrix(tau_x_raw_tau0)) - partial_resid_tau0 <- resid_train - + # For probit, resid_train holds the full-scale latent z; center it so that + # tau_0 does not absorb the probit intercept y_bar_train. + resid_for_tau0 <- if (link_is_probit) { + resid_train - y_bar_train + } else { + resid_train + } + partial_resid_tau0 <- resid_for_tau0 - as.numeric(mu_x_raw_tau0) - tau_x_full if (has_rfx) { @@ -2087,7 +2101,14 @@ bcf <- function( tau_x_raw_train <- active_forest_tau$predict_raw( forest_dataset_train ) - partial_resid_mu_train <- resid_train - mu_x_raw_train + # For probit, resid_train holds full-scale z; center it so b_0/b_1 do not + # absorb the probit intercept y_bar_train. + resid_for_coding <- if (link_is_probit) { + resid_train - y_bar_train + } else { + resid_train + } + partial_resid_mu_train <- resid_for_coding - mu_x_raw_train if (has_rfx) { rfx_preds_train <- rfx_model$predict( rfx_dataset_train, @@ -2698,15 +2719,16 @@ bcf <- function( ) outcome_pred <- outcome_pred + rfx_pred } - mu0 <- outcome_pred[y_train == 0] - mu1 <- outcome_pred[y_train == 1] + eta_pred <- outcome_pred + y_bar_train + mu0 <- eta_pred[y_train == 0] + mu1 <- eta_pred[y_train == 1] u0 <- runif(sum(y_train == 0), 0, pnorm(0 - mu0)) u1 <- runif(sum(y_train == 1), pnorm(0 - mu1), 1) resid_train[y_train == 0] <- mu0 + qnorm(u0) resid_train[y_train == 1] <- mu1 + qnorm(u1) - # Update outcome - outcome_train$update_data(resid_train - outcome_pred) + # Update outcome: center z by y_bar_train before passing to forests + outcome_train$update_data(resid_train - y_bar_train - outcome_pred) } # Sample the prognostic forest @@ -2768,7 +2790,14 @@ bcf <- function( Z_basis_mat <- as.matrix(tau_basis_train) # tau(X) * basis contribution per observation tau_x_full <- rowSums(Z_basis_mat * as.matrix(tau_x_raw_tau0)) - partial_resid_tau0 <- resid_train - + # For probit, resid_train holds the full-scale latent z; center it so that + # tau_0 does not absorb the probit intercept y_bar_train. + resid_for_tau0 <- if (link_is_probit) { + resid_train - y_bar_train + } else { + resid_train + } + partial_resid_tau0 <- resid_for_tau0 - as.numeric(mu_x_raw_tau0) - tau_x_full if (has_rfx) { @@ -2827,7 +2856,14 @@ bcf <- function( tau_x_raw_train <- active_forest_tau$predict_raw( forest_dataset_train ) - partial_resid_mu_train <- resid_train - mu_x_raw_train + # For probit, resid_train holds full-scale z; center it so b_0/b_1 do not + # absorb the probit intercept y_bar_train. + resid_for_coding <- if (link_is_probit) { + resid_train - y_bar_train + } else { + resid_train + } + partial_resid_mu_train <- resid_for_coding - mu_x_raw_train if (has_rfx) { rfx_preds_train <- rfx_model$predict( rfx_dataset_train, diff --git a/stochtree/bart.py b/stochtree/bart.py index 0b7e6343..d9451c8c 100644 --- a/stochtree/bart.py +++ b/stochtree/bart.py @@ -1560,8 +1560,10 @@ def sample( if self.has_rfx: rfx_pred = rfx_model.predict(rfx_dataset_train, rfx_tracker) outcome_pred = outcome_pred + rfx_pred - mu0 = outcome_pred[y_train[:, 0] == 0] - mu1 = outcome_pred[y_train[:, 0] == 1] + # Full probit-scale predictor: forest learns z - y_bar, so add y_bar back + eta_pred = outcome_pred + self.y_bar + mu0 = eta_pred[y_train[:, 0] == 0] + mu1 = eta_pred[y_train[:, 0] == 1] n0 = np.sum(y_train[:, 0] == 0) n1 = np.sum(y_train[:, 0] == 1) u0 = self.rng.uniform( @@ -1577,8 +1579,8 @@ def sample( resid_train[y_train[:, 0] == 0, 0] = mu0 + norm.ppf(u0) resid_train[y_train[:, 0] == 1, 0] = mu1 + norm.ppf(u1) - # Update outcome - new_outcome = np.squeeze(resid_train) - outcome_pred + # Update outcome: center z by y_bar before passing to forest + new_outcome = np.squeeze(resid_train) - self.y_bar - outcome_pred residual_train.update_data(new_outcome) # Sample the mean forest @@ -1885,8 +1887,10 @@ def sample( rfx_dataset_train, rfx_tracker ) outcome_pred = outcome_pred + rfx_pred - mu0 = outcome_pred[y_train[:, 0] == 0] - mu1 = outcome_pred[y_train[:, 0] == 1] + # Full probit-scale predictor: forest learns z - y_bar, so add y_bar back + eta_pred = outcome_pred + self.y_bar + mu0 = eta_pred[y_train[:, 0] == 0] + mu1 = eta_pred[y_train[:, 0] == 1] n0 = np.sum(y_train[:, 0] == 0) n1 = np.sum(y_train[:, 0] == 1) u0 = self.rng.uniform( @@ -1902,8 +1906,8 @@ def sample( resid_train[y_train[:, 0] == 0, 0] = mu0 + norm.ppf(u0) resid_train[y_train[:, 0] == 1, 0] = mu1 + norm.ppf(u1) - # Update outcome - new_outcome = np.squeeze(resid_train) - outcome_pred + # Update outcome: center z by y_bar before passing to forest + new_outcome = np.squeeze(resid_train) - self.y_bar - outcome_pred residual_train.update_data(new_outcome) # Sample the mean forest diff --git a/stochtree/bcf.py b/stochtree/bcf.py index 1d80d562..ac04bda9 100644 --- a/stochtree/bcf.py +++ b/stochtree/bcf.py @@ -2121,14 +2121,19 @@ def sample( if link_is_probit: # Sample latent probit variable z | - + # outcome_pred is the centered forest prediction (not including y_bar_train). + # The truncated normal mean is outcome_pred + y_bar_train (the full eta on the probit scale). + # The residual stored is z - y_bar_train - outcome_pred so the forest sees a + # zero-centered signal and the prior shrinkage toward 0 is well-calibrated. forest_pred_mu = active_forest_mu.predict(forest_dataset_train) forest_pred_tau = active_forest_tau.predict(forest_dataset_train) outcome_pred = forest_pred_mu + forest_pred_tau if self.has_rfx: rfx_pred = rfx_model.predict(rfx_dataset_train, rfx_tracker) outcome_pred = outcome_pred + rfx_pred - mu0 = outcome_pred[y_train[:, 0] == 0] - mu1 = outcome_pred[y_train[:, 0] == 1] + eta_pred = outcome_pred + self.y_bar + mu0 = eta_pred[y_train[:, 0] == 0] + mu1 = eta_pred[y_train[:, 0] == 1] n0 = np.sum(y_train[:, 0] == 0) n1 = np.sum(y_train[:, 0] == 1) u0 = self.rng.uniform( @@ -2144,8 +2149,8 @@ def sample( resid_train[y_train[:, 0] == 0, 0] = mu0 + norm.ppf(u0) resid_train[y_train[:, 0] == 1, 0] = mu1 + norm.ppf(u1) - # Update outcome - new_outcome = np.squeeze(resid_train) - outcome_pred + # Update outcome: center z by y_bar before passing to forests + new_outcome = (np.squeeze(resid_train) - self.y_bar) - outcome_pred residual_train.update_data(new_outcome) # Sample the prognostic forest @@ -2195,7 +2200,9 @@ def sample( Z_basis = tau_basis_train.reshape(-1, 1) if tau_basis_train.ndim == 1 else tau_basis_train tau_x_raw_2d = tau_x_raw_tau0.reshape(self.n_train, -1) tau_x_full = np.sum(Z_basis * tau_x_raw_2d, axis=1) - partial_resid_tau0 = np.squeeze(resid_train) - mu_x_tau0 - tau_x_full + # Center z by y_bar so tau_0 does not absorb the probit intercept + resid_for_tau0 = (np.squeeze(resid_train) - self.y_bar) if link_is_probit else np.squeeze(resid_train) + partial_resid_tau0 = resid_for_tau0 - mu_x_tau0 - tau_x_full if self.has_rfx: partial_resid_tau0 = partial_resid_tau0 - np.squeeze( rfx_model.predict(rfx_dataset_train, rfx_tracker) @@ -2234,7 +2241,9 @@ def sample( tau_x = np.squeeze( active_forest_tau.predict_raw(forest_dataset_train) ) - partial_resid_train = np.squeeze(resid_train - mu_x) + # Center z by y_bar so coding regression does not absorb the probit intercept + resid_for_coding = (resid_train - self.y_bar) if link_is_probit else resid_train + partial_resid_train = np.squeeze(resid_for_coding - mu_x) if self.has_rfx: rfx_pred = np.squeeze( rfx_model.predict(rfx_dataset_train, rfx_tracker) @@ -2697,8 +2706,10 @@ def sample( if self.has_rfx: rfx_pred = rfx_model.predict(rfx_dataset_train, rfx_tracker) outcome_pred = outcome_pred + rfx_pred - mu0 = outcome_pred[y_train[:, 0] == 0] - mu1 = outcome_pred[y_train[:, 0] == 1] + # Full probit-scale predictor: forests learn z - y_bar, so add y_bar back + eta_pred = outcome_pred + self.y_bar + mu0 = eta_pred[y_train[:, 0] == 0] + mu1 = eta_pred[y_train[:, 0] == 1] n0 = np.sum(y_train[:, 0] == 0) n1 = np.sum(y_train[:, 0] == 1) u0 = self.rng.uniform( @@ -2714,8 +2725,8 @@ def sample( resid_train[y_train[:, 0] == 0, 0] = mu0 + norm.ppf(u0) resid_train[y_train[:, 0] == 1, 0] = mu1 + norm.ppf(u1) - # Update outcome - new_outcome = np.squeeze(resid_train) - outcome_pred + # Update outcome: center z by y_bar before passing to forests + new_outcome = np.squeeze(resid_train) - self.y_bar - outcome_pred residual_train.update_data(new_outcome) # Sample the prognostic forest @@ -2765,7 +2776,9 @@ def sample( Z_basis = tau_basis_train.reshape(-1, 1) if tau_basis_train.ndim == 1 else tau_basis_train tau_x_raw_2d = tau_x_raw_tau0.reshape(self.n_train, -1) tau_x_full = np.sum(Z_basis * tau_x_raw_2d, axis=1) - partial_resid_tau0 = np.squeeze(resid_train) - mu_x_tau0 - tau_x_full + # Center z by y_bar so tau_0 does not absorb the probit intercept + resid_for_tau0 = (np.squeeze(resid_train) - self.y_bar) if link_is_probit else np.squeeze(resid_train) + partial_resid_tau0 = resid_for_tau0 - mu_x_tau0 - tau_x_full if self.has_rfx: partial_resid_tau0 = partial_resid_tau0 - np.squeeze( rfx_model.predict(rfx_dataset_train, rfx_tracker) @@ -2804,7 +2817,9 @@ def sample( tau_x = np.squeeze( active_forest_tau.predict_raw(forest_dataset_train) ) - partial_resid_train = np.squeeze(resid_train - mu_x) + # Center z by y_bar so coding regression does not absorb the probit intercept + resid_for_coding = (resid_train - self.y_bar) if link_is_probit else resid_train + partial_resid_train = np.squeeze(resid_for_coding - mu_x) if self.has_rfx: rfx_pred = np.squeeze( rfx_model.predict(rfx_dataset_train, rfx_tracker)