Fix validation skipped when IterableDataset exhausts early#21552

Open
avocardio wants to merge 2 commits into Lightning-AI:master from avocardio:fix/iterable-dataset-validation-on-exhaustion

Conversation


@avocardio avocardio commented Feb 25, 2026

Fixes #19624

Problem

If an IterableDataset implements __len__ but yields fewer batches than expected (common with webdataset, DALI, or any streaming dataset where shard boundaries / drop_last / worker splitting cause the actual count to differ from the estimate), validation never runs — not just for that epoch, but for every subsequent epoch too.

The root cause is in _TrainingEpochLoop.run(). When _DataFetcher.__next__() hits the end of the underlying iterator, it sets done = True and re-raises StopIteration. The except StopIteration: break in the training loop then skips on_advance_end(), which is where the validation check lives. Since no training batch was actually processed on that final iteration (the fetch itself failed), skipping the per-batch bookkeeping is fine — but the end-of-epoch validation should still fire.
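The control-flow shape described above can be condensed into a minimal sketch. This is not Lightning's actual source; `run_epoch` and its parameters are illustrative stand-ins for `_TrainingEpochLoop.run()` and the fetcher, with the validation condition reduced to a length comparison:

```python
def run_epoch(data_iter, expected_len, validate):
    """Simplified sketch of the buggy loop shape (hypothetical names).

    The validation check lives in the on_advance_end-equivalent block and
    fires only when batch_idx + 1 == expected_len. If the iterator
    exhausts before that, the `break` skips the block and validation
    never runs.
    """
    batch_idx = 0
    ran_validation = False
    it = iter(data_iter)
    while True:
        try:
            batch = next(it)  # raises StopIteration when the data runs out
        except StopIteration:
            break  # exits before the on_advance_end block below
        # --- on_advance_end equivalent: end-of-epoch validation check ---
        if batch_idx + 1 == expected_len:
            validate()
            ran_validation = True
        batch_idx += 1
    return ran_validation
```

With an iterator that yields 8 batches against a reported length of 10, `run_epoch` returns `False`: the check that would trigger validation is never reached.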

Fix

After the while loop exits, check if the fetcher was exhausted and run validation if appropriate. The check mirrors the existing logic in on_advance_end but only triggers the validation part, since no batch was processed.
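As a rough illustration of that post-loop check (again a simplified sketch with hypothetical names, not the actual patch):

```python
def run_epoch_fixed(data_iter, expected_len, validate):
    """Same sketch as before, plus the post-loop validation check
    described above (hypothetical names, simplified condition)."""
    batch_idx = 0
    ran_validation = False
    exhausted = False
    it = iter(data_iter)
    while True:
        try:
            batch = next(it)
        except StopIteration:
            exhausted = True  # fetcher ran dry before the reported length
            break
        if batch_idx + 1 == expected_len:
            validate()
            ran_validation = True
        batch_idx += 1
    # Post-loop: the iterator exhausted early, so the end-of-epoch
    # validation check inside the loop was skipped — run it now.
    # (In the real fix this is where is_last_batch would be set to True.)
    if exhausted and not ran_validation:
        validate()
        ran_validation = True
    return ran_validation
```

Note the guard on `ran_validation`: if the dataset yields exactly the reported count, validation already fired inside the loop and is not run a second time.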

I ran into this while training a ViT-B/32 model on webdataset shards. Validation was silently skipped every epoch until I traced it to this codepath. After the fix, validation fires reliably — confirmed across 20+ epochs on two separate multi-GPU runs.

Test

Added a regression test with a minimal IterableDataset that reports len=10 but only yields 8 samples.
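The shape of that dataset looks roughly like this (a plain-Python stand-in; the actual regression test would subclass `torch.utils.data.IterableDataset`):

```python
class MisreportingIterable:
    """Minimal stand-in for the regression test's dataset: __len__
    reports 10, but iteration yields only 8 samples — the mismatch
    that previously caused validation to be skipped."""

    def __len__(self):
        return 10  # the optimistic estimate

    def __iter__(self):
        return iter(range(8))  # what the stream actually produces
```

Any trainer that trusts `len()` to know when the last batch arrives will mispredict the epoch boundary for this dataset.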


📚 Documentation preview 📚: https://pytorch-lightning--21552.org.readthedocs.build/en/21552/

avocardio and others added 2 commits February 25, 2026 18:47
…length

When an IterableDataset reports a length via __len__ but produces fewer
batches (due to shard boundaries, rounding, or drop_last=True with
multiple workers), StopIteration is raised in _DataFetcher.__next__
before fetched >= length. This StopIteration propagates to the training
epoch loop's run() method, where `except StopIteration: break` exits
the loop — skipping on_advance_end() and the validation check it
contains.

The fix adds a post-loop validation check: when the data fetcher is
done (StopIteration was caught) and validation should run at the epoch
boundary, we set is_last_batch=True and run the validation check that
was skipped.

Fixes Lightning-AI#19624
@github-actions github-actions bot added the pl Generic label for PyTorch Lightning package label Feb 25, 2026

Development

Successfully merging this pull request may close these issues.

IterableDataset with CORRECT length causes validation loop to be skipped
