Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions src/grelu/lightning/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -676,6 +676,24 @@ def tune_on_dataset(
"""
# Move train data parameters
self.base_data_params = self.data_params.copy()

# Warn if fine-tuning chromosomes overlap with pretraining chromosomes
base_intervals = self.base_data_params.get("train", {}).get("intervals")
if (
base_intervals is not None
and hasattr(train_dataset, "intervals")
and train_dataset.intervals is not None
):
base_chroms = set(base_intervals["chrom"])
new_chroms = set(train_dataset.intervals["chrom"])
overlap = sorted(base_chroms & new_chroms)
if overlap:
warnings.warn(
f"Fine-tuning dataset contains {len(overlap)} chromosome(s) "
f"({', '.join(overlap)}) that overlap with the pretrained "
"model's training data. This may lead to data leakage."
)

self.data_params = {}

# Make new model head
Expand Down
34 changes: 34 additions & 0 deletions tests/test_lightning.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import warnings

import numpy as np
import pandas as pd
import torch
Expand Down Expand Up @@ -344,6 +346,38 @@ def test_lightning_model_finetune():
assert multitask_reg_model.model_params["n_tasks"] == 1


def test_lightning_model_finetune_chrom_overlap_warning():
    """Check that chromosome overlap during fine-tuning triggers one warning.

    The model's stored training parameters are seeded with chr1 intervals,
    then ``tune_on_dataset`` is called. Exactly one captured warning must
    contain "overlap", and it must mention both "chr1" and "data leakage".
    """
    pretrained = generate_model(task="regression", loss="poisson", n_tasks=2)

    # Pretend the model was originally trained on two chr1 intervals.
    pretrain_intervals = {"chrom": ["chr1", "chr1"], "start": [0, 100], "end": [2, 102]}
    pretrained.data_params["train"] = {
        "intervals": pretrain_intervals,
        "seq_len": 2,
    }

    # interval_dataset is expected to contain chr1 intervals as well, so the
    # fine-tuning chromosomes overlap with the simulated pretraining ones.
    with warnings.catch_warnings(record=True) as caught:
        # "always" stops the default once-per-location filter from hiding it.
        warnings.simplefilter("always")
        pretrained.tune_on_dataset(
            interval_dataset, interval_dataset, final_pool_func="avg"
        )

    # Unrelated warnings raised during tuning are ignored; only the overlap
    # message is inspected.
    overlap_messages = [
        str(record.message) for record in caught if "overlap" in str(record.message)
    ]
    assert len(overlap_messages) == 1
    message = overlap_messages[0]
    assert "chr1" in message
    assert "data leakage" in message


def test_lightning_model_finetune_no_chrom_warning():
    """Check that fine-tuning on a non-interval dataset emits no overlap warning.

    Any other warnings raised during ``tune_on_dataset`` are tolerated; only
    messages containing "overlap" count as failures.
    """
    net = generate_model(task="regression", loss="poisson", n_tasks=2)

    # ldataset is expected to be sequence-based (no intervals), so there is
    # nothing to compare against the pretraining chromosomes.
    with warnings.catch_warnings(record=True) as caught:
        # Record every warning so a spurious overlap message cannot be hidden
        # by the default filters.
        warnings.simplefilter("always")
        net.tune_on_dataset(ldataset, ldataset, final_pool_func="avg")

    overlap_messages = [
        str(record.message) for record in caught if "overlap" in str(record.message)
    ]
    assert len(overlap_messages) == 0


def test_lightning_model_ensemble():
# Make individual models
model0 = generate_model(task="binary", loss="bce", n_tasks=2)
Expand Down