
PyTests for distributed workflows for PyTorch #199

@cedriclim1

Description


Problem

There were a few issues in electronmicroscopy/quantem-tutorials#4 with odd DDP errors across Python versions. Python 3.14 changes how multiprocessing starts worker processes (see https://docs.python.org/3/whatsnew/3.14.html#multiprocessing): on Linux, the default start method moves from fork to forkserver. As a result, when constructing a DataLoader it appears necessary to pass an explicit multiprocessing_context (https://docs.pytorch.org/docs/stable/data.html):

train_dataloader = DataLoader(
    train_dataset,  # type: ignore[reportArgumentType] -- torch Datasets need not define __len__, but this still works
    batch_size=batch_size,
    num_workers=num_workers,
    sampler=train_sampler,
    shuffle=shuffle,
    pin_memory=pin_mem,
    drop_last=True,
    persistent_workers=persist,
    multiprocessing_context="spawn",  # <-- the relevant argument
)

Without passing this, you can get errors like the following:

[Screenshot: traceback ending in an NCCL error]

This is slightly confusing since it presents itself as an NCCL error, but passing a multiprocessing_context to the DataLoader seems to resolve the issue.
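The likely mechanism: forked DataLoader workers can inherit already-initialized CUDA/NCCL state from the parent, which is why the failure surfaces as an NCCL error, while "spawn" starts workers in a fresh interpreter and is available on every platform. A minimal sketch of picking an explicit context rather than relying on the shifting platform default (the helper name is illustrative, not from the tutorial):

```python
import multiprocessing as mp

def dataloader_mp_context() -> str:
    """Pick an explicit start method for DataLoader workers.

    Relying on the platform default is fragile: Python 3.14 moves the
    Linux default from "fork" to "forkserver", and forked workers can
    inherit initialized CUDA/NCCL state. "spawn" starts a clean
    interpreter and is available everywhere.
    (Illustrative helper -- not part of the tutorial code.)
    """
    methods = mp.get_all_start_methods()
    return "spawn" if "spawn" in methods else mp.get_start_method()
```

The returned string can then be passed as multiprocessing_context=dataloader_mp_context() when building the DataLoader.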

Proposed Solution

We should probably write some pytests to check compatibility of DDP workflows on HPC clusters. This is somewhat awkward to test, since the pytests would have to run on NERSC inside an interactive job. I'm not yet sure what the best approach is for pytests in this case, but I'll give it more thought and open a PR at some point.
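One lightweight piece that doesn't need a NERSC allocation: unit-test the DataLoader configuration itself, so a regression in the spawn-context handling is caught before any DDP run. A hedged sketch, where build_loader_kwargs is a hypothetical helper standing in for however the tutorial assembles its DataLoader arguments:

```python
def build_loader_kwargs(num_workers: int) -> dict:
    """Hypothetical helper: assemble DataLoader kwargs for DDP training."""
    kwargs = {"num_workers": num_workers, "drop_last": True}
    if num_workers > 0:
        # Force "spawn" so worker startup behaves identically before and
        # after the Python 3.14 default-start-method change.
        kwargs["multiprocessing_context"] = "spawn"
    return kwargs

def test_spawn_context_requested_for_workers():
    assert build_loader_kwargs(4)["multiprocessing_context"] == "spawn"

def test_no_context_needed_without_workers():
    # With num_workers=0 no subprocesses are launched, so no context is set.
    assert "multiprocessing_context" not in build_loader_kwargs(0)
```

The actual cluster-side DDP tests would still need to run under an interactive job, but a pure-Python check like this can run in ordinary CI.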
