andrewscouten edited this page Mar 15, 2026 · 3 revisions

# Training Pipeline

The OncoLearn training pipeline trains a multimodal cancer classification model using genomic (miRNA/mRNA), clinical, and optionally medical imaging data.

## Overview

The trainer builds a `GatedLateFusionClassifier` from the encoders specified in the experiment config:

| Encoder | Model | Input |
| --- | --- | --- |
| Gene | RNA BERT (`ibm-research/biomed.rna.bert.110m.mlm.multitask.v1`) | miRNA / mRNA expression matrix |
| Clinical (optional) | FT-Transformer | Clinical feature vector |
| Image (optional) | FM-BCMRI 3D ViT | MRI / mammography images |

**Note:** RNA BERT is a gated HuggingFace model. If it cannot be downloaded (no token, or running offline), the gene encoder automatically falls back to a linear projection so training can proceed without pretrained weights.
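The fallback amounts to a try/except around the pretrained load. A minimal stdlib sketch, with hypothetical names (`load_gene_encoder` and the 4-in / 2-out dimensions are illustrative, not the real OncoLearn API):

```python
import random

def load_gene_encoder(load_pretrained, dim_in=4, dim_out=2):
    """Hypothetical sketch: try the gated pretrained model; on failure,
    fall back to a randomly initialised linear projection."""
    try:
        return load_pretrained()  # raises if no HF token or offline
    except Exception:
        rng = random.Random(0)
        weights = [[rng.gauss(0.0, 0.02) for _ in range(dim_in)]
                   for _ in range(dim_out)]

        def linear_projection(x):
            # plain matrix-vector product standing in for nn.Linear
            return [sum(w * xi for w, xi in zip(row, x)) for row in weights]

        return linear_projection

def gated_download():
    # simulate the gated-model failure mode
    raise OSError("401: authentication required for gated model")

encoder = load_gene_encoder(gated_download)
embedding = encoder([1.0, 0.0, 0.0, 0.0])  # fallback path: a dim_out vector
```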

**Memory:** The SCBert transformer has O(n²) attention. To keep it tractable, the encoder truncates its input to the top 512 features by magnitude before the forward pass (configurable via `max_seq_len` in the encoder config).
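The truncation step is a top-k selection by absolute value. A minimal sketch (the function name and list-based representation are illustrative):

```python
def truncate_top_k(values, max_seq_len=512):
    """Keep the max_seq_len features with the largest |value|,
    preserving their original order (sketch of the attention-memory guard)."""
    top = sorted(range(len(values)), key=lambda i: abs(values[i]),
                 reverse=True)[:max_seq_len]
    top.sort()  # restore original feature order
    return [values[i] for i in top]

truncate_top_k([0.1, -5.0, 2.0, 0.0], max_seq_len=2)  # → [-5.0, 2.0]
```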

Training is driven by PyTorch Lightning with AdamW, a cosine LR scheduler, and early stopping (patience = 10 epochs).
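Conceptually, the fusion head weights each modality embedding by a learned gate and sums them. A stdlib-only sketch of that idea (not the real `GatedLateFusionClassifier`, whose gating mechanism may differ):

```python
import math

def gated_late_fusion(embeddings, gate_scores):
    """Softmax the per-modality gate scores, then take the weighted
    sum of the modality embeddings (all assumed same dimension)."""
    exps = [math.exp(s) for s in gate_scores]
    total = sum(exps)
    gates = [e / total for e in exps]
    dim = len(embeddings[0])
    return [sum(g * emb[d] for g, emb in zip(gates, embeddings))
            for d in range(dim)]

# Two modalities with equal gate scores reduce to a plain average.
fused = gated_late_fusion([[1.0, 0.0], [0.0, 1.0]], [0.0, 0.0])  # [0.5, 0.5]
```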


## Config Files

Experiments are specified via YAML configs under `data/configs/modeling/multimodal/`:

| Config | Variant | Data source | Labels |
| --- | --- | --- | --- |
| `tcga_brca_tabular_only.yaml` | `v2_no_imaging` | XenaBrowser | AJCC stage (gene only) |
| `tcga_brca_cbioportal_pam50.yaml` | | cBioPortal | PAM50 subtype (gene + clinical + image) |
| `tcga_brca_cbioportal_stage.yaml` | | cBioPortal | AJCC stage (gene + clinical + image) |

Use `--config` to pass any config directly:

```shell
oncolearn train --config data/configs/modeling/multimodal/tcga_brca_tabular_only.yaml
```

Each YAML config references a pipeline file (`data.pipeline`) that defines which datasets to load and how they are joined. See Pipeline DSL for details.
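The overall shape of such a config might look like the fragment below. This is illustrative only: apart from `data.pipeline`, `training.max_epochs`, and `training.batch_size`, which are named on this page, the keys and the pipeline path are assumptions, not the project's actual schema.

```yaml
# Hypothetical experiment config fragment; key names beyond those
# documented on this page are illustrative, not authoritative.
data:
  pipeline: data/configs/pipelines/tcga_brca_tabular.yaml  # dataset join definition (path assumed)
training:
  max_epochs: 10
  batch_size: 16
```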


## Basic Usage

```shell
# Via CLI
oncolearn train --variant v2_no_imaging --epochs 10 --batch_size 8

# Or directly via the module
python -m oncolearn.trainer --variant v2_no_imaging --epochs 10 --batch_size 8
```

### Arguments

| Argument | Default | Description |
| --- | --- | --- |
| `--config` | (none) | Path to a YAML config (takes precedence over `--variant`) |
| `--variant` | `v2_no_imaging` | Shorthand: `v2_no_imaging` loads `tcga_brca_tabular_only.yaml` |
| `--epochs` | 10 | Override `training.max_epochs` (shorthand only) |
| `--batch_size` | 16 | Override `training.batch_size` (shorthand only) |
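The precedence rule (`--config` wins over `--variant`) can be sketched as follows; the mapping comes from the table above, and the resolver function itself is illustrative:

```python
# Variant-to-config mapping taken from this guide; resolve_config is a sketch.
VARIANT_CONFIGS = {
    "v2_no_imaging": "data/configs/modeling/multimodal/tcga_brca_tabular_only.yaml",
}

def resolve_config(config=None, variant="v2_no_imaging"):
    """An explicit --config path takes precedence over the --variant shorthand."""
    if config is not None:
        return config
    return VARIANT_CONFIGS[variant]

resolve_config()  # default variant resolves to tcga_brca_tabular_only.yaml
resolve_config(config="my_experiment.yaml")  # explicit path wins
```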

## Variants

### v2_no_imaging — Tabular only

Uses only the gene/miRNA encoder with XenaBrowser data. Useful when image data is unavailable or for faster iteration.

```shell
oncolearn train --variant v2_no_imaging --epochs 20 --batch_size 32
```

### Multimodal (cBioPortal)

For multimodal training (gene + clinical + image), pass a config directly:

```shell
# PAM50 subtype prediction
oncolearn train --config data/configs/modeling/multimodal/tcga_brca_cbioportal_pam50.yaml

# AJCC stage prediction
oncolearn train --config data/configs/modeling/multimodal/tcga_brca_cbioportal_stage.yaml
```

Multimodal training requires K-fold splits pre-generated with `oncolearn preprocess multimodal kfold`. See K-Fold Splits below.


## Running in Docker

### docker compose (recommended)

Pick the service matching your GPU — volumes and memory limits are pre-configured in `docker-compose.yml`:

```shell
# AMD / WSL2
docker compose --profile prod-rocm-wsl run --rm prod-rocm-wsl \
  python -m oncolearn.trainer --variant v2_no_imaging --epochs 10 --batch_size 8

# AMD / native Linux
docker compose --profile prod-rocm run --rm prod-rocm \
  python -m oncolearn.trainer --variant v2_no_imaging --epochs 10 --batch_size 8

# NVIDIA
docker compose --profile prod-cuda run --rm prod-cuda \
  python -m oncolearn.trainer --variant v2_no_imaging --epochs 10 --batch_size 8
```

The service mounts the following directories automatically (see `docker-compose.yml`):

| Mount | Purpose |
| --- | --- |
| `data/` | Read-only input data |
| `src/` | Live code — changes take effect without rebuilding the image |
| `.hf-cache/` | Persists HuggingFace model downloads |
| `models/` | Saved model checkpoints |
| `outputs/` | Training logs and Lightning checkpoints |

See the Docker guide for all available services and build commands.


## Data Prerequisites

### Tabular-only

```shell
# miRNA expression
oncolearn xena download --cohorts BRCA --category mirna_seq --unzip

# Clinical / stage labels
oncolearn xena download --cohorts BRCA --category clinical --unzip
```

### Multimodal (cBioPortal)

```shell
# mRNA + clinical data (cBioPortal)
oncolearn cbioportal download --cohorts BRCA

# Imaging (TCIA)
oncolearn tcia download --cohorts BRCA --yes
```

Expected directory layout:

```text
data/
└── sources/
    ├── xenabrowser/
    │   └── TCGA-BRCA/
    │       ├── TCGA-BRCA.mirna.tsv
    │       └── TCGA-BRCA.clinical.tsv
    ├── cbioportal/
    │   └── TCGA-BRCA/
    └── tcia/
        └── TCGA-BRCA/
            └── TCIA_TCGA-BRCA_*/
                └── TCGA-BRCA/
```
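A quick way to check that the tabular-only inputs are in place before launching a run. This sketch covers only the two XenaBrowser files shown in the layout above (the helper name is ours, not part of OncoLearn):

```python
from pathlib import Path

# The two files the tabular-only variant needs (from the layout above).
REQUIRED = [
    "sources/xenabrowser/TCGA-BRCA/TCGA-BRCA.mirna.tsv",
    "sources/xenabrowser/TCGA-BRCA/TCGA-BRCA.clinical.tsv",
]

def missing_inputs(data_dir="data"):
    """Return the required input files that are not present under data_dir."""
    root = Path(data_dir)
    return [rel for rel in REQUIRED if not (root / rel).exists()]

missing_inputs("data")  # empty list when both downloads above have completed
```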

## K-Fold Splits

Generate stratified K-fold patient splits before training for reproducible cross-validation.

### Tabular-only splits

```shell
oncolearn xena preprocess \
  --config data/configs/modeling/multimodal/tcga_brca_tabular_only.yaml \
  --n_splits 5 --seed 42
```

### Multimodal splits

```shell
# PAM50 subtype labels
oncolearn preprocess multimodal kfold 5 --stratified --label pam50

# AJCC stage labels
oncolearn preprocess multimodal kfold 5 --stratified --label stage
```

The multimodal command automatically intersects the gene ∩ image ∩ clinical patient sets before stratifying. See `oncolearn preprocess multimodal kfold` for all options.
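The intersect-then-stratify logic can be sketched with plain sets and a per-class round-robin assignment (illustrative only; the real command's splitting strategy may differ, e.g. using scikit-learn's `StratifiedKFold`):

```python
def stratified_kfold(labels, n_splits=5):
    """Assign patients to folds round-robin within each label class,
    so every fold gets a near-equal share of each class (sketch)."""
    folds = [[] for _ in range(n_splits)]
    by_class = {}
    for pid in sorted(labels):
        by_class.setdefault(labels[pid], []).append(pid)
    i = 0
    for members in by_class.values():
        for pid in members:
            folds[i % n_splits].append(pid)
            i += 1
    return folds

# Intersect the modality patient sets first, as the command does:
gene = {"p1", "p2", "p3", "p4"}
image = {"p1", "p2", "p3"}
clinical = {"p1", "p2", "p3", "p4"}
eligible = gene & image & clinical  # only fully-covered patients remain
folds = stratified_kfold({p: "LumA" for p in eligible}, n_splits=3)
```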

Then add the split folder to your config:

```yaml
data:
  splits_dir: data/configs/modeling/multimodal/splits/pam50/kfold/fold_0
```

## Output

The trainer writes per-epoch metrics to stdout via PyTorch Lightning and saves checkpoints to `outputs/<experiment_name>/`:

```text
Epoch 2/50  ━━━━━━━━━━━━━━━━━ 8/8
  train_loss: 1.241
  val_loss:   1.439
  val_acc:    0.375
```

Checkpoints saved:

- `best_model.ckpt` — best validation accuracy
- `epoch_N.ckpt` — periodic snapshots (every `save_every_n_epochs` epochs)

Early stopping triggers when `val_acc` does not improve for `early_stopping_patience` consecutive epochs (default: 10).
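The patience rule can be sketched as a counter over the validation history (illustrative; in practice this is PyTorch Lightning's `EarlyStopping` callback, not hand-rolled code):

```python
def early_stop_epoch(val_acc_history, patience=10):
    """Return the epoch at which early stopping would fire, or None.
    Fires once `patience` epochs pass without a new best val_acc."""
    best, best_epoch = float("-inf"), -1
    for epoch, acc in enumerate(val_acc_history):
        if acc > best:
            best, best_epoch = acc, epoch
        elif epoch - best_epoch >= patience:
            return epoch
    return None

early_stop_epoch([0.5, 0.6, 0.6, 0.6], patience=2)  # fires at epoch 3
early_stop_epoch([0.1, 0.2, 0.3], patience=2)       # still improving: None
```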


## Advanced Config Options

| Option | Description |
| --- | --- |
| `model.modality_dropout_prob` | Stochastically drop a modality during training (at least one is always kept) |
| `training.use_class_weights` | Inverse-frequency class weighting to handle imbalanced stages |
| `data.splits_dir` | Path to external `train.txt` / `test.txt` / `validation.txt` split files |
| `training.hpo` | Optuna hyperparameter optimisation config (`n_trials`, `search_space`) |
| `training.cross_validation` | Cross-validation over pre-generated K-fold dirs; metrics averaged across folds |
| `training.accelerator` | `"auto"` detects GPU automatically; `"cpu"` forces CPU |
| `training.regularization.gradient_clip_val` | Gradient clipping (0 = disabled) |
| `training.regularization.l1_lambda` | L1 weight regularisation strength |
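The modality-dropout behaviour (each modality dropped independently, but never all at once) can be sketched as follows; the helper is hypothetical, not the trainer's actual code:

```python
import random

def drop_modalities(modalities, p=0.3, rng=None):
    """Drop each modality independently with probability p; if everything
    was dropped, keep one at random so the forward pass has input (sketch)."""
    rng = rng or random.Random(0)
    kept = [m for m in modalities if rng.random() >= p]
    if not kept:
        kept = [rng.choice(modalities)]
    return kept

rng = random.Random(42)
batches = [drop_modalities(["gene", "clinical", "image"], p=0.9, rng=rng)
           for _ in range(100)]
# even at p=0.9, every batch keeps at least one modality
```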
