Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion streaming/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,14 @@
import streaming.vision as vision
from streaming._version import __version__ # noqa: F401
from streaming.base import (CSVWriter, JSONWriter, LocalDataset, MDSWriter, Stream,
StreamingDataLoader, StreamingDataset, TSVWriter, XSVWriter)
StreamingDataLoader, MegatronStreamingDataLoader, StreamingDataset, MegatronStreamingDataset, TSVWriter, XSVWriter)

__all__ = [
'StreamingDataLoader',
'MegatronStreamingDataLoader',
'Stream',
'StreamingDataset',
'MegatronStreamingDataset',
'CSVWriter',
'JSONWriter',
'MDSWriter',
Expand Down
5 changes: 3 additions & 2 deletions streaming/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@

"""MosaicML Streaming Datasets for cloud-native model training."""

from streaming.base.dataloader import StreamingDataLoader
from streaming.base.dataloader import StreamingDataLoader, MegatronStreamingDataLoader
from streaming.base.dataset import StreamingDataset
from streaming.base.megatron_dataset import MegatronStreamingDataset
from streaming.base.format import CSVWriter, JSONWriter, MDSWriter, TSVWriter, XSVWriter
from streaming.base.local import LocalDataset
from streaming.base.stream import Stream

__all__ = [
'StreamingDataLoader', 'Stream', 'StreamingDataset', 'CSVWriter', 'JSONWriter', 'LocalDataset',
'StreamingDataLoader', 'MegatronStreamingDataLoader', 'Stream', 'StreamingDataset', 'MegatronStreamingDataset', 'CSVWriter', 'JSONWriter', 'LocalDataset',
'MDSWriter', 'TSVWriter', 'XSVWriter'
]
24 changes: 24 additions & 0 deletions streaming/base/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from transformers.tokenization_utils_base import BatchEncoding

from streaming.base.dataset import StreamingDataset
from streaming.base.megatron_dataset import MegatronStreamingDataset, DPWorld
from streaming.base.world import World


Expand Down Expand Up @@ -99,3 +100,26 @@ def __del__(self) -> None:
"""Terminate the workers during cleanup."""
if self._iterator is not None:
self._iterator._shutdown_workers() # type: ignore [reportGeneralTypeIssues]


class MegatronStreamingDataLoader(StreamingDataLoader):
    """A streaming data loader that allows for resumable iteration with MegatronStreamingDataset.

    Checkpoint state is expressed in terms of the data-parallel world (``DPWorld``)
    rather than the global torch.distributed world, matching how
    ``MegatronStreamingDataset`` partitions samples.

    Args:
        *args: Positional arguments forwarded to :class:`StreamingDataLoader`. If
            ``dataset`` is not passed as a keyword, the first positional argument
            must be a :class:`MegatronStreamingDataset`.
        **kwargs: Keyword arguments forwarded to :class:`StreamingDataLoader`.

    Raises:
        ValueError: If the dataset is not a :class:`MegatronStreamingDataset`.
    """

    def __init__(self, *args, **kwargs) -> None:  # pyright: ignore
        # Locate the dataset explicitly. The previous `dataset or args[0]` form
        # raised IndexError when the loader was built with keyword arguments only,
        # and incorrectly fell through to args[0] for a falsy (len == 0) dataset.
        if 'dataset' in kwargs:
            dataset = kwargs['dataset']
        elif args:
            dataset = args[0]
        else:
            dataset = None
        if not isinstance(dataset, MegatronStreamingDataset):
            raise ValueError('MegatronStreamingDataLoader requires a MegatronStreamingDataset.')
        super().__init__(*args, **kwargs)

    def state_dict(self):
        """Get a dict containing training state for checkpointing.

        Returns:
            The dataset's state dict, with the locally yielded sample count scaled
            up to the data-parallel world size.
        """
        world = DPWorld.detect()
        # NOTE(review): assumes every data-parallel rank has yielded the same
        # number of samples -- confirm against the sampler's drop/pad policy.
        num_samples = self.num_samples_yielded * world.num_ranks
        return self.dataset.state_dict(num_samples, False)
24 changes: 15 additions & 9 deletions streaming/base/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from streaming.base.constant import (BARRIER, BARRIER_FILELOCK, CACHE_FILELOCK, CACHE_USAGE,
DEFAULT_TIMEOUT, EPOCH_DATA, EPOCH_SHAPE, NEXT_EPOCH, RESUME,
SHARD_ACCESS_TIMES, SHARD_STATES, TICK)
from streaming.base.distributed import maybe_init_dist
from streaming.base.distributed import maybe_init_dist, barrier
from streaming.base.format import get_index_basename
from streaming.base.registry_utils import construct_from_registry
from streaming.base.sampling import get_sampling
Expand Down Expand Up @@ -366,12 +366,8 @@ def __init__(self,
# * `parallel_` is who we think we are for iterating purposes, where groups of process
# must act the same if `replication` is specified.
# This can enable tensor or sequence parallelism.
world = World.detect()
self._unique_rank_world = world
if replication is not None:
self._parallel_rank_world = world.replicate(replication)
else:
self._parallel_rank_world = world.copy()
self._unique_rank_world = self._create_unique_rank_world()
self._parallel_rank_world = self._create_parallel_rank_world()
self._unique_worker_world: World
self._parallel_worker_world: World

Expand Down Expand Up @@ -538,6 +534,7 @@ def __init__(self,
streams_remote = [
os.path.join(x.remote, x.split) if x.remote is not None else None for x in streams
]

self._shm_prefix_int, self._locals_shm = get_shm_prefix(streams_local, streams_remote,
self._unique_rank_world)
self._filelock_root = gettempdir()
Expand Down Expand Up @@ -597,8 +594,7 @@ def __init__(self,
self._shard_states[shard_id] = _ShardState.LOCAL if size else _ShardState.REMOTE
self._shard_access_times[shard_id] = time_ns()

if dist.is_available() and dist.is_initialized():
dist.barrier()
barrier()

if destroy_dist:
dist.destroy_process_group()
Expand Down Expand Up @@ -677,6 +673,15 @@ def __len__(self) -> int:
int: Dataset length.
"""
return self.length

def _create_unique_rank_world(self) -> World:
    """Create the world describing this process's unique rank (overridable hook for subclasses)."""
    return World.detect()

def _create_parallel_rank_world(self) -> World:
    """Create the world used for iteration, collapsing replicated ranks if requested."""
    base = self._unique_rank_world
    return base.copy() if self.replication is None else base.replicate(self.replication)

def _set_shuffle_block_size(self, world: World):
"""Set the shuffle block size value."""
Expand Down Expand Up @@ -1499,6 +1504,7 @@ def __iter__(self) -> Iterator[dict[str, Any]]:

# Get this worker's partition of samples to process.
sample_ids = self._get_work(epoch, sample_in_epoch)

if not len(sample_ids): # Resumed at end of epoch, out of samples.
return

Expand Down
22 changes: 20 additions & 2 deletions streaming/base/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@
from torch import Tensor
from torch import distributed as dist

import logging
logger = logging.getLogger(__name__)


__all__ = [
'all_gather', 'barrier', 'broadcast', 'get_rank', 'get_local_rank', 'get_local_world_size',
'get_world_size'
Expand Down Expand Up @@ -58,8 +62,22 @@ def get_local_world_size() -> int:

def barrier() -> None:
    """Synchronize processes.

    When the optional megatron dataset-building group is available, only the
    dataset-builder ranks are synchronized; otherwise all ranks are. The barrier
    is a no-op when torch.distributed is unavailable or not initialized.
    """
    log = logging.getLogger(__name__)
    try:
        # Optional dependency: only present when running under megatron.
        from streaming.base.megatron_dataset_utils import get_dataset_building_group
        dataset_building_group = get_dataset_building_group()
    except ImportError:
        # Replaced debug print with logging so library users can control verbosity.
        log.debug('megatron get_dataset_building_group unavailable; using global barrier')
        dataset_building_group = None

    if dataset_building_group is None:
        log.warning(
            'dataset_building_group is None, cannot barrier on megatron dataset builder ranks. '
            'This would lead to deadlocks if all ranks are not building and iterating datasets '
            'at the same time.')
        if dist.is_available() and dist.is_initialized():
            dist.barrier()
    # Bug fix: the group barrier previously ran without checking that the process
    # group was initialized, unlike the global-barrier branch.
    elif dist.is_available() and dist.is_initialized():
        dist.barrier(group=dataset_building_group)



def broadcast(tensor: Tensor, src: int) -> None:
Expand Down
59 changes: 59 additions & 0 deletions streaming/base/megatron_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
from streaming.base.dataset import StreamingDataset
from streaming.base.world import World

import torch
import os

__all__ = ['MegatronStreamingDataset']


class PerNodeWorld(World):
    """A ``World`` describing only the dataset-builder ranks on the local node.

    ``num_nodes`` is fixed at 1: each node's builder ranks form their own
    independent world for shard-download bookkeeping.
    """

    @classmethod
    def detect(cls):
        """Detect this process's per-node world from megatron rank metadata."""
        from streaming.base.megatron_dataset_utils import (get_dataset_builder_ranks_by_node,
                                                           get_parallel_rank_info)
        all_ranks_per_node = get_dataset_builder_ranks_by_node()
        parallel_rank_info = get_parallel_rank_info()
        # Which physical node this process lives on.
        node = torch.distributed.get_rank() // int(os.environ['LOCAL_WORLD_SIZE'])
        # NOTE(review): assumes all_ranks_per_node[node] maps parallel-rank info
        # to a node-local rank index -- confirm in megatron_dataset_utils.
        rank = all_ranks_per_node[node][parallel_rank_info]
        ranks_for_node = len(all_ranks_per_node[node])
        num_nodes = 1
        worker_of_rank, workers_per_rank = cls._get_worker_info()
        worker = rank * workers_per_rank + worker_of_rank
        return cls(num_nodes, ranks_for_node, workers_per_rank, worker)

    def replicate(self, replication):
        """Unsupported: replication conflicts with per-node rank bookkeeping."""
        raise NotImplementedError('PerNodeWorld does not support replicate()')

class DPWorld(World):
    """A ``World`` spanning megatron's data-parallel group.

    ``num_nodes`` is fixed at 1: the whole data-parallel group is treated as a
    single flat world for sample partitioning.
    """

    @classmethod
    def detect(cls):
        """Detect this process's data-parallel world via megatron's ``mpu``."""
        from megatron.core import mpu
        rank = mpu.get_data_parallel_rank()
        ranks_per_node = mpu.get_data_parallel_world_size()
        num_nodes = 1
        worker_of_rank, workers_per_rank = cls._get_worker_info()
        worker = rank * workers_per_rank + worker_of_rank
        return cls(num_nodes, ranks_per_node, workers_per_rank, worker)

    def replicate(self, replication):
        """Unsupported: replication is handled by megatron's parallel groups."""
        raise NotImplementedError('DPWorld does not support replicate()')


class MegatronStreamingDataset(StreamingDataset):
    """A dataset class compatible with Megatron-LM's data loading.

    This dataset can be used in place of Megatron's ``BlendedMegatronDataset``.
    It extends :class:`StreamingDataset` so that world detection follows
    Megatron-LM's distributed data loading for N-D parallelism: unique-rank
    bookkeeping is per node (``PerNodeWorld``) and sample partitioning follows
    the data-parallel group (``DPWorld``).

    Args:
        *args: Positional arguments forwarded to :class:`StreamingDataset`.
        **kwargs: Keyword arguments forwarded to :class:`StreamingDataset`.
            ``replication`` must not be set.

    Raises:
        ValueError: If ``replication`` is passed.
    """

    def __init__(self, *args, **kwargs):
        # Validate with a real exception rather than `assert`, which is
        # stripped when Python runs with -O.
        if kwargs.pop('replication', None) is not None:
            raise ValueError('MegatronStreamingDataset does not support replication. '
                             'Please remove the `replication` argument.')
        super().__init__(*args, **kwargs)

    def _create_unique_rank_world(self) -> World:
        return PerNodeWorld.detect()

    def _create_parallel_rank_world(self) -> World:
        return DPWorld.detect()
Loading