[FEATURE] Expose observation weights in high-level BART/BCF interface #329
Description
Summary
Add an `observation_weights` argument to `bart()` (R) and `BARTModel.sample()` (Python), and to the corresponding BCF interfaces.
Motivation
Observation weights (also called case weights) allow individual observations to contribute differently to the likelihood. They have at least three practical use cases:
- Survey-weighted regression — weight observations by inverse sampling probability
- Importance resampling — upweight rare subpopulations or reweight to a target distribution
- BART prior sampling — passing all-zero weights tells the sampler to ignore the data entirely, producing draws from the tree structure prior and leaf priors (see #200 and RFC 0003)
Implementation notes
Observation weights are already fully supported in the stochtree C++ core. The `ForestDataset` class stores weights as `var_weights_`, and the leaf models use them in weighted sufficient statistics (y_i | · ~ N(μ(X_i), σ²/w_i)). The low-level R wrapper already exposes `AddVarianceWeights()`. No C++ changes are required.
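To make the weighted sufficient statistics concrete, here is a minimal pure-Python sketch of a conjugate Gaussian leaf update under variance weights. This is illustrative only, not stochtree's actual implementation; the function name and the prior N(0, tau2) are assumptions for the example.

```python
# Illustrative sketch (not stochtree's actual code): a Gaussian leaf with
# y_i | mu ~ N(mu, sigma2 / w_i) and prior mu ~ N(0, tau2).
# The weighted sufficient statistics are sum(w_i) and sum(w_i * y_i).

def weighted_leaf_posterior(y, w, sigma2=1.0, tau2=1.0):
    """Posterior (mean, variance) of a leaf mean under variance weights."""
    sum_w = sum(w)                                   # weighted effective sample size
    sum_wy = sum(wi * yi for wi, yi in zip(w, y))    # weighted outcome sum
    post_prec = 1.0 / tau2 + sum_w / sigma2
    post_var = 1.0 / post_prec
    post_mean = post_var * (sum_wy / sigma2)
    return post_mean, post_var

y = [1.0, 2.0, 3.0]

# Uniform weights of 1 reproduce the unweighted update.
m1, v1 = weighted_leaf_posterior(y, [1.0, 1.0, 1.0])

# All-zero weights zero out both sufficient statistics, so the "posterior"
# collapses to the prior N(0, tau2) -- the mechanism behind prior sampling
# via zero weights (#200, RFC 0003).
m0, v0 = weighted_leaf_posterior(y, [0.0, 0.0, 0.0])
print(m0, v0)  # 0.0 1.0
```

Note how the zero-weight case recovers the prior exactly, which is why no special-casing is needed in the sampler itself.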
The work is to thread `observation_weights` through the high-level R (`bart()`, `bcf()`) and Python (`BARTModel.sample()`, `BCFModel.sample()`) interfaces and pass them to the `ForestDataset` construction step.
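One way the high-level wrappers might validate weights before handing them to the dataset construction step is sketched below. The helper name, error messages, and None-means-uniform convention are assumptions for illustration, not part of stochtree's API.

```python
# Hypothetical validation helper for the Python wrapper; the name and
# behavior are illustrative, not stochtree's actual code.

def validate_observation_weights(observation_weights, n_train):
    """Check weights are numeric, length n_train, and all >= 0.

    Returns the weights as a list of floats, or None if not supplied
    (in which case the sampler defaults to uniform weights).
    """
    if observation_weights is None:
        return None
    w = [float(x) for x in observation_weights]  # raises on non-numeric input
    if len(w) != n_train:
        raise ValueError(
            f"observation_weights has length {len(w)}, expected {n_train}"
        )
    if any(x < 0 for x in w):
        raise ValueError("observation_weights must be non-negative")
    return w

# Zero weights are allowed (the sampler ignores those rows); negatives are not.
print(validate_observation_weights([1.0, 0.0, 2.5], n_train=3))  # [1.0, 0.0, 2.5]
```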
Scope
- Add an `observation_weights = NULL` (R) / `observation_weights=None` (Python) argument to `bart()`, `bcf()`, `BARTModel.sample()`, `BCFModel.sample()`
- Validate: numeric, length `nrow(X_train)`, all values `>= 0`
- Pass to `ForestDataset` construction (both mu and tau forest datasets in BCF)
- Weights are not serialized (they are a property of the training data, not the model)
- Add unit tests covering: uniform weights (should match no-weights), zero weights (should ignore data), and a weighted regression example
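The three test cases above can be sketched against a stand-in estimator (a weighted mean) to show the shape of the assertions; the same checks would run against the real sampler once `observation_weights` is wired through.

```python
# Sketch of the three unit-test cases from the scope list, exercised
# against a toy weighted-mean estimator rather than the real sampler.

def weighted_mean(y, w):
    sw = sum(w)
    return sum(wi * yi for wi, yi in zip(w, y)) / sw if sw > 0 else None

y = [2.0, 4.0, 6.0, 100.0]

# 1. Uniform weights should match the unweighted estimate.
assert weighted_mean(y, [1.0] * 4) == sum(y) / 4

# 2. Zero weights should make the estimator ignore those rows:
#    zeroing out the outlier leaves the mean of the first three values.
assert weighted_mean(y, [1.0, 1.0, 1.0, 0.0]) == 4.0

# 3. A weighted regression example: upweighting a subpopulation
#    pulls the estimate toward it.
assert weighted_mean(y, [1.0, 1.0, 1.0, 3.0]) > weighted_mean(y, [1.0] * 4)
```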
Reference
RFC 0003: https://github.com/StochasticTree/rfcs/blob/main/0003-bart-prior-sampling.md