diff --git a/R/bcf.R b/R/bcf.R
index f9543a28..125e2994 100644
--- a/R/bcf.R
+++ b/R/bcf.R
@@ -923,6 +923,60 @@ bcf <- function(
     X_test <- preprocessPredictionData(X_test, X_train_metadata)
   }
 
+  # Handle factor-valued treatment vectors before any numeric conversion.
+  # as.numeric() on a factor returns level indices (1, 2, ...), not 0/1, so
+  # factors must be explicitly converted to 0/1 first.
+  if (is.factor(Z_train)) {
+    lvls <- levels(Z_train)
+    if (length(lvls) != 2) {
+      stop("Factor Z_train must have exactly 2 levels for binary treatment")
+    }
+    # A factor Z_test must share this level order; otherwise a label coded 0
+    # in training would silently be coded 1 in the test set.
+    if (is.factor(Z_test) && !identical(levels(Z_test), lvls)) {
+      stop("Factor Z_train and Z_test must have identical levels")
+    }
+    message(
+      "Z_train is a factor; converting to 0/1 using level order: ",
+      lvls[1], " = 0, ", lvls[2], " = 1"
+    )
+    Z_train <- as.integer(Z_train) - 1L
+  }
+  if (!is.null(Z_test) && is.factor(Z_test)) {
+    lvls <- levels(Z_test)
+    if (length(lvls) != 2) {
+      stop("Factor Z_test must have exactly 2 levels for binary treatment")
+    }
+    message(
+      "Z_test is a factor; converting to 0/1 using level order: ",
+      lvls[1], " = 0, ", lvls[2], " = 1"
+    )
+    Z_test <- as.integer(Z_test) - 1L
+  }
+
+  # Check that all inputs are numeric before matrix conversions
+  if (!is.numeric(y_train)) {
+    stop("y_train must be numeric")
+  }
+  if (!is.numeric(Z_train)) {
+    stop("Z_train must be numeric")
+  }
+  if (!is.null(Z_test) && !is.numeric(Z_test)) {
+    stop("Z_test must be numeric")
+  }
+  if (!is.null(propensity_train) && !is.numeric(propensity_train)) {
+    stop("propensity_train must be numeric")
+  }
+  if (!is.null(propensity_test) && !is.numeric(propensity_test)) {
+    stop("propensity_test must be numeric")
+  }
+  if (!is.null(rfx_basis_train) && !is.numeric(rfx_basis_train)) {
+    stop("rfx_basis_train must be numeric")
+  }
+  if (!is.null(rfx_basis_test) && !is.numeric(rfx_basis_test)) {
+    stop("rfx_basis_test must be numeric")
+  }
+
   # Convert all input data to matrices if not already converted
   Z_col <- ifelse(is.null(dim(Z_train)), 1, ncol(Z_train))
   Z_train <- matrix(as.numeric(Z_train), ncol = Z_col)
@@ -942,6 +996,14 @@ bcf <- function(
     rfx_basis_test <- as.matrix(rfx_basis_test)
   }
 
+  # Convert y_train to a vector if passed as a one-column matrix
+  if (is.matrix(y_train)) {
+    if (ncol(y_train) > 1) {
+      stop("y_train must be a numeric vector or a one-column matrix")
+    }
+    y_train <- as.numeric(y_train)
+  }
+
   # Recode group IDs to integer vector (if passed as, for example, a vector of county names, etc...)
   has_rfx <- FALSE
   has_rfx_test <- FALSE
@@ -964,17 +1026,6 @@ bcf <- function(
     }
   }
 
-  # Check that outcome and treatment are numeric
-  if (!is.numeric(y_train)) {
-    stop("y_train must be numeric")
-  }
-  if (!is.numeric(Z_train)) {
-    stop("Z_train must be numeric")
-  }
-  if (!is.null(Z_test)) {
-    if (!is.numeric(Z_test)) stop("Z_test must be numeric")
-  }
-
   # Data consistency checks
   if ((!is.null(X_test)) && (ncol(X_test) != ncol(X_train))) {
     stop("X_train and X_test must have the same number of columns")
@@ -3402,6 +3453,19 @@ predict.bcfmodel <- function(
     stop("X must be a matrix or dataframe")
   }
 
+  # Handle factor-valued treatment before numeric conversion
+  if (is.factor(Z)) {
+    lvls <- levels(Z)
+    if (length(lvls) != 2) {
+      stop("Factor Z must have exactly 2 levels for binary treatment")
+    }
+    warning(
+      "Z is a factor; recoding to 0/1 using level order: ",
+      lvls[1], " = 0, ", lvls[2], " = 1"
+    )
+    Z <- as.integer(Z) - 1L
+  }
+
   # Convert all input data to matrices if not already converted
   if ((is.null(dim(Z))) && (!is.null(Z))) {
     Z <- as.matrix(as.numeric(Z))
diff --git a/test/R/testthat/test-bcf.R b/test/R/testthat/test-bcf.R
index 99c25316..de971440 100644
--- a/test/R/testthat/test-bcf.R
+++ b/test/R/testthat/test-bcf.R
@@ -922,3 +922,89 @@ test_that("BCF JSON serialization roundtrip covers all deserialization paths", {
   expect_equal(rowMeans(preds_rt[["y_hat"]]), y_hat_orig)
   expect_equal(rowMeans(preds_rt[["tau_hat"]]), tau_hat_orig)
 })
+
+test_that("BCF factor-valued treatment handling", {
+  skip_on_cran()
+
+  # Shared data: binary treatment DGP
+  n <- 100
+  p <- 5
+  set.seed(42)
+  X <- matrix(runif(n * p), ncol = p)
+  pi_X <- 0.4 + 0.2 * X[, 1]
+  Z_numeric <- rbinom(n, 1, pi_X)
+  tau_X <- 1 + X[, 2]
+  mu_X <- 2 * X[, 3]
+  y <- mu_X + tau_X * Z_numeric + rnorm(n, 0, 1)
+
+  # Binary factor treatment: levels "0" and "1"
+  # Verify the conversion produces 0/1 values identical to the original
+  Z_factor_binary <- factor(Z_numeric)
+  expect_equal(levels(Z_factor_binary), c("0", "1"))
+  expect_equal(as.integer(Z_factor_binary) - 1L, as.integer(Z_numeric))
+
+  # Factor treatment should run without error and emit an informative message
+  expect_message(
+    suppressWarnings(bcf(
+      X_train = X, y_train = y, Z_train = Z_factor_binary,
+      propensity_train = pi_X, num_gfr = 0, num_burnin = 5, num_mcmc = 5
+    )),
+    regexp = "Z_train is a factor"
+  )
+
+  # Logical treatment converted to factor: levels "FALSE" and "TRUE"
+  # as.factor(logical) orders levels FALSE < TRUE: "FALSE" = 0, "TRUE" = 1
+  Z_logical <- as.logical(Z_numeric)
+  Z_factor_logical <- as.factor(Z_logical)
+  expect_equal(levels(Z_factor_logical), c("FALSE", "TRUE"))
+  expect_equal(as.integer(Z_factor_logical) - 1L, as.integer(Z_numeric))
+
+  expect_message(
+    suppressWarnings(bcf(
+      X_train = X, y_train = y, Z_train = Z_factor_logical,
+      propensity_train = pi_X, num_gfr = 0, num_burnin = 5, num_mcmc = 5
+    )),
+    regexp = "Z_train is a factor"
+  )
+
+  # Factor treatment with more than 2 levels should error immediately
+  Z_factor_categorical <- factor(sample(c("A", "B", "C"), n, replace = TRUE))
+  expect_error(
+    bcf(
+      X_train = X, y_train = y, Z_train = Z_factor_categorical,
+      propensity_train = pi_X, num_gfr = 0, num_burnin = 5, num_mcmc = 5
+    ),
+    regexp = "exactly 2 levels"
+  )
+
+  # Train and test factors with conflicting level order should error rather
+  # than silently code the same label as 0 in training and 1 at test time
+  Z_factor_reversed <- factor(Z_numeric, levels = c(1, 0))
+  expect_error(
+    bcf(
+      X_train = X, y_train = y, Z_train = Z_factor_binary,
+      propensity_train = pi_X, X_test = X, Z_test = Z_factor_reversed,
+      propensity_test = pi_X, num_gfr = 0, num_burnin = 5, num_mcmc = 5
+    ),
+    regexp = "identical levels"
+  )
+
+  # predict.bcfmodel should also handle factor Z, raising a warning
+  suppressMessages(
+    bcf_model <- bcf(
+      X_train = X, y_train = y, Z_train = Z_numeric,
+      propensity_train = pi_X, num_gfr = 0, num_burnin = 5, num_mcmc = 5
+    )
+  )
+  expect_warning(
+    predict(bcf_model, X, Z_factor_binary, pi_X),
+    regexp = "Z is a factor"
+  )
+  expect_warning(
+    predict(bcf_model, X, Z_factor_logical, pi_X),
+    regexp = "Z is a factor"
+  )
+  expect_error(
+    predict(bcf_model, X, Z_factor_categorical, pi_X),
+    regexp = "exactly 2 levels"
+  )
+})