# Training
The OncoLearn training pipeline trains a multimodal cancer classification model using genomic (miRNA/mRNA), clinical, and optionally medical imaging data.
The trainer builds a GatedLateFusionClassifier from encoders specified in the experiment config:
| Encoder | Model | Input |
|---|---|---|
| Gene | RNA BERT (`ibm-research/biomed.rna.bert.110m.mlm.multitask.v1`) | miRNA/mRNA expression matrix |
| Clinical (optional) | FT-Transformer | Clinical feature vector |
| Image (optional) | FM-BCMRI 3D ViT | MRI/mammography images |
Note: RNA BERT is a gated HuggingFace model. If it cannot be downloaded (no token, or offline), the gene encoder automatically falls back to a linear projection so training can proceed without pretrained weights.
Memory: The SCBert transformer has O(n²) attention. To keep it tractable, the encoder truncates its input to the top 512 features by magnitude before the forward pass (configurable via `max_seq_len` in the encoder config).
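In effect, the truncation keeps only the highest-magnitude features per sample. A minimal NumPy sketch of that idea (the function name and shapes are illustrative, not the actual encoder code):

```python
import numpy as np

def truncate_top_k(x: np.ndarray, max_seq_len: int = 512) -> np.ndarray:
    """Keep the max_seq_len highest-magnitude features per sample (illustrative).

    x: (batch, n_features) expression matrix.
    Note: the kept features come out ordered by descending magnitude.
    """
    if x.shape[-1] <= max_seq_len:
        return x
    # Per-row indices of the largest |value| entries
    idx = np.argsort(-np.abs(x), axis=-1)[:, :max_seq_len]
    return np.take_along_axis(x, idx, axis=-1)

rng = np.random.default_rng(0)
batch = rng.normal(size=(4, 2048))  # e.g. 2048 miRNA/mRNA features
out = truncate_top_k(batch)         # -> shape (4, 512)
```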
Training is driven by PyTorch Lightning with AdamW, a cosine LR scheduler, and early stopping (patience = 10 epochs).
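The training knobs documented in this guide map onto the YAML roughly as follows. This is a sketch assembled from the option names listed on this page; the exact nesting (e.g. where `save_every_n_epochs` lives) may differ in the real configs:

```yaml
training:
  max_epochs: 50
  batch_size: 16
  accelerator: "auto"          # "cpu" forces CPU
  early_stopping_patience: 10
  save_every_n_epochs: 5       # assumption: nesting may differ
  regularization:
    gradient_clip_val: 1.0     # 0 = disabled
    l1_lambda: 0.0
```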
Experiments are specified via YAML configs under `data/configs/modeling/multimodal/`:
| Config | Variant | Data source | Labels |
|---|---|---|---|
| `tcga_brca_tabular_only.yaml` | `v2_no_imaging` | XenaBrowser | AJCC stage (gene only) |
| `tcga_brca_cbioportal_pam50.yaml` | — | cBioPortal | PAM50 subtype (gene + clinical + image) |
| `tcga_brca_cbioportal_stage.yaml` | — | cBioPortal | AJCC stage (gene + clinical + image) |
Use `--config` to pass any config directly:

```shell
oncolearn train --config data/configs/modeling/multimodal/tcga_brca_tabular_only.yaml
```

Each YAML config references a pipeline file (`data.pipeline`) that defines which datasets to load and how they are joined. See Pipeline DSL for details.
```shell
# Via CLI
oncolearn train --variant v2_no_imaging --epochs 10 --batch_size 8

# Or directly via the module
python -m oncolearn.trainer --variant v2_no_imaging --epochs 10 --batch_size 8
```

| Argument | Default | Description |
|---|---|---|
| `--config` | — | Path to a YAML config (takes precedence over `--variant`) |
| `--variant` | `v2_no_imaging` | Shorthand: `v2_no_imaging` loads `tcga_brca_tabular_only.yaml` |
| `--epochs` | 10 | Override `training.max_epochs` (shorthand only) |
| `--batch_size` | 16 | Override `training.batch_size` (shorthand only) |
The `v2_no_imaging` variant uses only the gene/miRNA encoder with XenaBrowser data. It is useful when image data is unavailable or for faster iteration.
```shell
oncolearn train --variant v2_no_imaging --epochs 20 --batch_size 32
```

For multimodal training (gene + clinical + image), pass a config directly:

```shell
# PAM50 subtype prediction
oncolearn train --config data/configs/modeling/multimodal/tcga_brca_cbioportal_pam50.yaml

# AJCC stage prediction
oncolearn train --config data/configs/modeling/multimodal/tcga_brca_cbioportal_stage.yaml
```

Multimodal training requires K-fold splits pre-generated with `oncolearn preprocess multimodal kfold`. See K-Fold Splits below.
Pick the service matching your GPU; volumes and memory limits are pre-configured in `docker-compose.yml`:
```shell
# AMD/WSL2
docker compose --profile prod-rocm-wsl run --rm prod-rocm-wsl \
  python -m oncolearn.trainer --variant v2_no_imaging --epochs 10 --batch_size 8

# AMD/native Linux
docker compose --profile prod-rocm run --rm prod-rocm \
  python -m oncolearn.trainer --variant v2_no_imaging --epochs 10 --batch_size 8

# NVIDIA
docker compose --profile prod-cuda run --rm prod-cuda \
  python -m oncolearn.trainer --variant v2_no_imaging --epochs 10 --batch_size 8
```

The service mounts the following directories automatically (see `docker-compose.yml`):

| Mount | Purpose |
|---|---|
| `data/` | Read-only input data |
| `src/` | Live code; changes take effect without rebuilding the image |
| `.hf-cache/` | Persists HuggingFace model downloads |
| `models/` | Saved model checkpoints |
| `outputs/` | Training logs and Lightning checkpoints |
See the Docker guide for all available services and build commands.
```shell
# miRNA expression
oncolearn xena download --cohorts BRCA --category mirna_seq --unzip

# Clinical / stage labels
oncolearn xena download --cohorts BRCA --category clinical --unzip

# mRNA + clinical data (cBioPortal)
oncolearn cbioportal download --cohorts BRCA

# Imaging (TCIA)
oncolearn tcia download --cohorts BRCA --yes
```

Expected directory layout:

```
data/
└── sources/
    ├── xenabrowser/
    │   └── TCGA-BRCA/
    │       ├── TCGA-BRCA.mirna.tsv
    │       └── TCGA-BRCA.clinical.tsv
    ├── cbioportal/
    │   └── TCGA-BRCA/
    └── tcia/
        └── TCGA-BRCA/
            └── TCIA_TCGA-BRCA_*/
                └── TCGA-BRCA/
```
Generate stratified K-fold patient splits before training for reproducible cross-validation.
```shell
oncolearn xena preprocess \
  --config data/configs/modeling/multimodal/tcga_brca_tabular_only.yaml \
  --n_splits 5 --seed 42
```

```shell
# PAM50 subtype labels
oncolearn preprocess multimodal kfold 5 --stratified --label pam50

# AJCC stage labels
oncolearn preprocess multimodal kfold 5 --stratified --label stage
```

The multimodal command automatically intersects gene ∩ image ∩ clinical patient sets before stratifying. See `oncolearn preprocess multimodal kfold` for all options.
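The stratified split step can be approximated with scikit-learn's `StratifiedKFold`. The patient IDs and labels below are made up for illustration; the real command also performs the modality intersection described above:

```python
from sklearn.model_selection import StratifiedKFold

# Hypothetical patient -> label data after the gene ∩ image ∩ clinical intersection
patients = [f"TCGA-XX-{i:04d}" for i in range(20)]
labels = ["LumA", "LumB", "Basal", "Her2"] * 5  # PAM50-style labels

skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
folds = list(skf.split(patients, labels))

for train_idx, test_idx in folds:
    # With 5 samples per class and 5 splits, each test fold holds one patient per class
    held_out = {labels[i] for i in test_idx}
    assert held_out == {"LumA", "LumB", "Basal", "Her2"}
```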
Then add the split folder to your config:

```yaml
data:
  splits_dir: data/configs/modeling/multimodal/splits/pam50/kfold/fold_0
```

The trainer writes per-epoch metrics to stdout via PyTorch Lightning and saves checkpoints to `outputs/<experiment_name>/`:
```
Epoch 2/50 ━━━━━━━━━━━━━━━━━ 8/8
train_loss: 1.241
val_loss: 1.439
val_acc: 0.375
```
Checkpoints saved:

- `best_model.ckpt`: best validation accuracy
- `epoch_N.ckpt`: periodic snapshots (every `save_every_n_epochs` epochs)
Early stopping triggers when `val_acc` does not improve for `early_stopping_patience` consecutive epochs (default: 10).
| Option | Description |
|---|---|
| `model.modality_dropout_prob` | Stochastically drop a modality during training (at least one is always kept) |
| `training.use_class_weights` | Inverse-frequency class weighting to handle imbalanced stages |
| `data.splits_dir` | Path to external `train.txt` / `test.txt` / `validation.txt` split files |
| `training.hpo` | Optuna hyperparameter optimisation config (`n_trials`, `search_space`) |
| `training.cross_validation` | Cross-validation over pre-generated K-fold dirs; metrics averaged across folds |
| `training.accelerator` | `"auto"` detects GPU automatically; `"cpu"` forces CPU |
| `training.regularization.gradient_clip_val` | Gradient clipping (0 = disabled) |
| `training.regularization.l1_lambda` | L1 weight regularisation strength |
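The `modality_dropout_prob` behaviour can be sketched as follows. The function name and dict layout are illustrative, not the actual OncoLearn API, and the real trainer works on batched tensors:

```python
import random

def apply_modality_dropout(features: dict, p: float, rng: random.Random) -> dict:
    """Drop each modality independently with probability p, keeping at least one.

    features: mapping like {"gene": ..., "clinical": ..., "image": ...}.
    """
    kept = {k: v for k, v in features.items() if rng.random() >= p}
    if not kept:
        # Never drop everything: fall back to one randomly chosen modality
        k = rng.choice(list(features))
        kept = {k: features[k]}
    return kept

rng = random.Random(0)
batch = {"gene": "g", "clinical": "c", "image": "i"}
out = apply_modality_dropout(batch, p=0.9, rng=rng)
assert 1 <= len(out) <= 3
```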