Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 70 additions & 11 deletions R/bcf.R
Original file line number Diff line number Diff line change
Expand Up @@ -923,6 +923,55 @@ bcf <- function(
X_test <- preprocessPredictionData(X_test, X_train_metadata)
}

# Handle factor-valued treatment vectors before any numeric conversion.
# as.numeric() on a factor returns level indices (1, 2, ...), not 0/1, so
# factors must be explicitly converted to 0/1 first.
if (is.factor(Z_train)) {
lvls <- levels(Z_train)
if (length(lvls) != 2) {
stop("Factor Z_train must have exactly 2 levels for binary treatment")
}
message(
"Z_train is a factor; converting to 0/1 using level order: ",
lvls[1], " = 0, ", lvls[2], " = 1"
)
Z_train <- as.integer(Z_train) - 1L
}
if (!is.null(Z_test) && is.factor(Z_test)) {
lvls <- levels(Z_test)
if (length(lvls) != 2) {
stop("Factor Z_test must have exactly 2 levels for binary treatment")
}
message(
"Z_test is a factor; converting to 0/1 using level order: ",
lvls[1], " = 0, ", lvls[2], " = 1"
)
Z_test <- as.integer(Z_test) - 1L
}

# Check that all inputs are numeric before matrix conversions
if (!is.numeric(y_train)) {
stop("y_train must be numeric")
}
if (!is.numeric(Z_train)) {
stop("Z_train must be numeric")
}
if (!is.null(Z_test) && !is.numeric(Z_test)) {
stop("Z_test must be numeric")
}
if (!is.null(propensity_train) && !is.numeric(propensity_train)) {
stop("propensity_train must be numeric")
}
if (!is.null(propensity_test) && !is.numeric(propensity_test)) {
stop("propensity_test must be numeric")
}
if (!is.null(rfx_basis_train) && !is.numeric(rfx_basis_train)) {
stop("rfx_basis_train must be numeric")
}
if (!is.null(rfx_basis_test) && !is.numeric(rfx_basis_test)) {
stop("rfx_basis_test must be numeric")
}

# Convert all input data to matrices if not already converted
Z_col <- ifelse(is.null(dim(Z_train)), 1, ncol(Z_train))
Z_train <- matrix(as.numeric(Z_train), ncol = Z_col)
Expand All @@ -942,6 +991,14 @@ bcf <- function(
rfx_basis_test <- as.matrix(rfx_basis_test)
}

# Convert y_train to a vector if passed as a one-column matrix
if (is.matrix(y_train)) {
if (ncol(y_train) > 1) {
stop("y_train must be a numeric vector or a one-column matrix")
}
y_train <- as.numeric(y_train)
}

# Recode group IDs to an integer vector (if passed as, for example, a vector of county names)
has_rfx <- FALSE
has_rfx_test <- FALSE
Expand All @@ -964,17 +1021,6 @@ bcf <- function(
}
}

# Check that outcome and treatment are numeric
if (!is.numeric(y_train)) {
stop("y_train must be numeric")
}
if (!is.numeric(Z_train)) {
stop("Z_train must be numeric")
}
if (!is.null(Z_test)) {
if (!is.numeric(Z_test)) stop("Z_test must be numeric")
}

# Data consistency checks
if ((!is.null(X_test)) && (ncol(X_test) != ncol(X_train))) {
stop("X_train and X_test must have the same number of columns")
Expand Down Expand Up @@ -3402,6 +3448,19 @@ predict.bcfmodel <- function(
stop("X must be a matrix or dataframe")
}

# Handle factor-valued treatment before numeric conversion
if (is.factor(Z)) {
lvls <- levels(Z)
if (length(lvls) != 2) {
stop("Factor Z must have exactly 2 levels for binary treatment")
}
warning(
"Z is a factor; recoding to 0/1 using level order: ",
lvls[1], " = 0, ", lvls[2], " = 1"
)
Z <- as.integer(Z) - 1L
}

# Convert all input data to matrices if not already converted
if ((is.null(dim(Z))) && (!is.null(Z))) {
Z <- as.matrix(as.numeric(Z))
Expand Down
75 changes: 75 additions & 0 deletions test/R/testthat/test-bcf.R
Original file line number Diff line number Diff line change
Expand Up @@ -922,3 +922,78 @@ test_that("BCF JSON serialization roundtrip covers all deserialization paths", {
expect_equal(rowMeans(preds_rt[["y_hat"]]), y_hat_orig)
expect_equal(rowMeans(preds_rt[["tau_hat"]]), tau_hat_orig)
})

test_that("BCF factor-valued treatment handling", {
  # Contract under test: a two-level factor treatment is recoded to 0/1 by
  # level order (first level = 0, second level = 1). bcf() announces the
  # recoding with message(), predict() announces it with warning(), and a
  # factor with any other number of levels is rejected with an error.
  skip_on_cran()

  # Shared data: binary treatment DGP
  n <- 100
  p <- 5
  set.seed(42)
  X <- matrix(runif(n * p), ncol = p)
  pi_X <- 0.4 + 0.2 * X[, 1]
  Z_numeric <- rbinom(n, 1, pi_X)
  tau_X <- 1 + X[, 2]
  mu_X <- 2 * X[, 3]
  y <- mu_X + tau_X * Z_numeric + rnorm(n, 0, 1)

  # Binary factor treatment: levels "0" and "1"
  # Verify the conversion produces 0/1 values identical to the original
  Z_factor_binary <- factor(Z_numeric)
  expect_equal(levels(Z_factor_binary), c("0", "1"))
  expect_equal(as.integer(Z_factor_binary) - 1L, as.integer(Z_numeric))

  # Factor treatment should run without error and emit an informative message
  expect_message(
    suppressWarnings(bcf(
      X_train = X, y_train = y, Z_train = Z_factor_binary,
      propensity_train = pi_X, num_gfr = 0, num_burnin = 5, num_mcmc = 5
    )),
    regexp = "Z_train is a factor"
  )

  # Logical treatment converted to factor: levels "FALSE" and "TRUE"
  # Factor levels are sorted, so "FALSE" maps to 0 and "TRUE" maps to 1
  Z_logical <- as.logical(Z_numeric)
  Z_factor_logical <- as.factor(Z_logical)
  expect_equal(levels(Z_factor_logical), c("FALSE", "TRUE"))
  expect_equal(as.integer(Z_factor_logical) - 1L, as.integer(Z_numeric))

  expect_message(
    suppressWarnings(bcf(
      X_train = X, y_train = y, Z_train = Z_factor_logical,
      propensity_train = pi_X, num_gfr = 0, num_burnin = 5, num_mcmc = 5
    )),
    regexp = "Z_train is a factor"
  )

  # Factor treatment with more than 2 levels should error immediately.
  # The levels are declared explicitly so the fixture always has three levels,
  # regardless of which labels the random draw happens to produce.
  Z_factor_categorical <- factor(
    sample(c("A", "B", "C"), n, replace = TRUE),
    levels = c("A", "B", "C")
  )
  expect_equal(nlevels(Z_factor_categorical), 3L)
  expect_error(
    bcf(
      X_train = X, y_train = y, Z_train = Z_factor_categorical,
      propensity_train = pi_X, num_gfr = 0, num_burnin = 5, num_mcmc = 5
    ),
    regexp = "exactly 2 levels"
  )

  # A factor-valued Z_test goes through its own recoding branch in bcf(), so
  # exercise it separately with a numeric Z_train (only the Z_test message is
  # expected in that case).
  expect_message(
    suppressWarnings(bcf(
      X_train = X, y_train = y, Z_train = Z_numeric,
      propensity_train = pi_X,
      X_test = X, Z_test = Z_factor_binary, propensity_test = pi_X,
      num_gfr = 0, num_burnin = 5, num_mcmc = 5
    )),
    regexp = "Z_test is a factor"
  )
  expect_error(
    bcf(
      X_train = X, y_train = y, Z_train = Z_numeric,
      propensity_train = pi_X,
      X_test = X, Z_test = Z_factor_categorical, propensity_test = pi_X,
      num_gfr = 0, num_burnin = 5, num_mcmc = 5
    ),
    regexp = "Factor Z_test must have exactly 2 levels"
  )

  # predict.bcfmodel should also handle factor Z, raising a warning
  suppressMessages(
    bcf_model <- bcf(
      X_train = X, y_train = y, Z_train = Z_numeric,
      propensity_train = pi_X, num_gfr = 0, num_burnin = 5, num_mcmc = 5
    )
  )
  expect_warning(
    predict(bcf_model, X, Z_factor_binary, pi_X),
    regexp = "Z is a factor"
  )
  expect_warning(
    predict(bcf_model, X, Z_factor_logical, pi_X),
    regexp = "Z is a factor"
  )
  expect_error(
    predict(bcf_model, X, Z_factor_categorical, pi_X),
    regexp = "exactly 2 levels"
  )
})
Loading