diff --git a/src/grelu/lightning/__init__.py b/src/grelu/lightning/__init__.py
index 7159fde..1ae19e8 100644
--- a/src/grelu/lightning/__init__.py
+++ b/src/grelu/lightning/__init__.py
@@ -676,6 +676,30 @@ def tune_on_dataset(
         """
         # Move train data parameters
         self.base_data_params = self.data_params.copy()
+
+        # Warn if fine-tuning chromosomes overlap with pretraining chromosomes.
+        # The check is advisory only, so it is skipped whenever either side
+        # lacks interval metadata or a "chrom" entry instead of raising.
+        base_intervals = self.base_data_params.get("train", {}).get("intervals")
+        new_intervals = getattr(train_dataset, "intervals", None)
+        if (
+            base_intervals is not None
+            and new_intervals is not None
+            and "chrom" in base_intervals
+            and "chrom" in new_intervals
+        ):
+            # Labels are stringified after intersecting so that non-string
+            # chromosome names cannot break the sort or the join below.
+            shared = set(base_intervals["chrom"]) & set(new_intervals["chrom"])
+            overlap = sorted(str(chrom) for chrom in shared)
+            if overlap:
+                warnings.warn(
+                    f"Fine-tuning dataset contains {len(overlap)} chromosome(s) "
+                    f"({', '.join(overlap)}) that overlap with the pretrained "
+                    "model's training data. This may lead to data leakage.",
+                    stacklevel=2,
+                )
+
         self.data_params = {}
 
         # Make new model head
diff --git a/tests/test_lightning.py b/tests/test_lightning.py
index d734ba7..7bc9e19 100644
--- a/tests/test_lightning.py
+++ b/tests/test_lightning.py
@@ -1,3 +1,5 @@
+import warnings
+
 import numpy as np
 import pandas as pd
 import torch
@@ -344,6 +346,41 @@ def test_lightning_model_finetune():
     assert multitask_reg_model.model_params["n_tasks"] == 1
 
 
+def test_lightning_model_finetune_chrom_overlap_warning():
+    """Warn when fine-tuning chromosomes overlap with pretraining chromosomes."""
+    model = generate_model(task="regression", loss="poisson", n_tasks=2)
+    # Simulate a pretrained model that was trained on chr1
+    model.data_params["train"] = {
+        "intervals": {"chrom": ["chr1", "chr1"], "start": [0, 100], "end": [2, 102]},
+        "seq_len": 2,
+    }
+
+    # Fine-tune with interval_dataset which also uses chr1
+    with warnings.catch_warnings(record=True) as w:
+        warnings.simplefilter("always")
+        model.tune_on_dataset(
+            interval_dataset, interval_dataset, final_pool_func="avg"
+        )
+    overlap_warnings = [x for x in w if "overlap" in str(x.message)]
+    assert len(overlap_warnings) == 1
+    assert "chr1" in str(overlap_warnings[0].message)
+    assert "data leakage" in str(overlap_warnings[0].message)
+
+
+def test_lightning_model_finetune_no_chrom_warning():
+    """No warning when fine-tuning with non-interval datasets."""
+    model = generate_model(task="regression", loss="poisson", n_tasks=2)
+    # ldataset is string-based (no intervals), so no warning should fire
+    # NOTE(review): this model has no simulated pretraining intervals, so the
+    # check may exit before it ever inspects the dataset - consider seeding
+    # model.data_params["train"] as in the overlap test above. TODO confirm.
+    with warnings.catch_warnings(record=True) as w:
+        warnings.simplefilter("always")
+        model.tune_on_dataset(ldataset, ldataset, final_pool_func="avg")
+    overlap_warnings = [x for x in w if "overlap" in str(x.message)]
+    assert len(overlap_warnings) == 0
+
+
 def test_lightning_model_ensemble():
     # Make individual models
     model0 = generate_model(task="binary", loss="bce", n_tasks=2)