diff --git a/docs/source/conf.py b/docs/source/conf.py index dbd0f1b83..41b253732 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -365,6 +365,7 @@ def _modules_to_rst() -> List[types.ModuleType]: document_modules: List[types.Module] = [ streaming, streaming.base.compression, + streaming.base.coord, streaming.base.format, streaming.base.hashing, streaming.base.partition, diff --git a/setup.py b/setup.py index c70a84143..de2116c9e 100644 --- a/setup.py +++ b/setup.py @@ -58,6 +58,7 @@ 'azure-storage-blob>=12.0.0,<13', 'azure-storage-file-datalake>=12.11.0,<13', 'azure-identity>=1.13.0', + 'psutil==5.9.4', ] extra_deps = {} diff --git a/streaming/base/coord/__init__.py b/streaming/base/coord/__init__.py new file mode 100644 index 000000000..1bd1d49d9 --- /dev/null +++ b/streaming/base/coord/__init__.py @@ -0,0 +1,4 @@ +# Copyright 2022-2024 MosaicML Streaming authors +# SPDX-License-Identifier: Apache-2.0 + +"""Functionality having to do with coordination between replicas.""" diff --git a/streaming/base/coord/filesystem/__init__.py b/streaming/base/coord/filesystem/__init__.py new file mode 100644 index 000000000..5febe7967 --- /dev/null +++ b/streaming/base/coord/filesystem/__init__.py @@ -0,0 +1,9 @@ +# Copyright 2022-2024 MosaicML Streaming authors +# SPDX-License-Identifier: Apache-2.0 + +"""Coordinating using files.""" + +from streaming.base.coord.filesystem.waiting import (create_file, wait_for_creation, + wait_for_deletion) + +__all__ = ['create_file', 'wait_for_creation', 'wait_for_deletion'] diff --git a/streaming/base/coord/filesystem/waiting.py b/streaming/base/coord/filesystem/waiting.py new file mode 100644 index 000000000..daf4a8996 --- /dev/null +++ b/streaming/base/coord/filesystem/waiting.py @@ -0,0 +1,71 @@ +# Copyright 2022-2024 MosaicML Streaming authors +# SPDX-License-Identifier: Apache-2.0 + +"""Waiting on files.""" + +import os +from typing import Any, Optional + +from streaming.base.coord.waiting import wait + +__all__ = ['wait_for_creation', 'wait_for_deletion', 'create_file'] + + +def wait_for_creation( + path: str, + timeout: Optional[float] = 30, + tick: float = 0.007, + lock: Optional[Any] = None, +) -> None: + """Wait for the creation of a path on the local filesystem. + + Args: + path (str): Local path to wait on the creation of. + timeout (float, optional): How long to wait before raising an exception, in seconds. + Defaults to ``30``. + tick (float): Check interval, in seconds. Defaults to ``0.007``. + lock (Any, optional): Context manager (this is intended for locks) to be held when + checking the predicate. Defaults to ``None``. + """ + + def stop(): + return os.path.exists(path) + + wait(stop, timeout, tick, lock) + + +def wait_for_deletion( + path: str, + timeout: Optional[float] = 30, + tick: float = 0.007, + lock: Optional[Any] = None, +) -> None: + """Wait for the deletion of a path on the local filesystem. + + Args: + path (str): Local path to wait on the deletion of. + timeout (float, optional): How long to wait before raising an exception, in seconds. + Defaults to ``30``. + tick (float): Check interval, in seconds. Defaults to ``0.007``. + lock (Any, optional): Context manager (this is intended for locks) to be held when + checking the predicate. Defaults to ``None``. + """ + + def stop(): + return not os.path.exists(path) + + wait(stop, timeout, tick, lock) + + +def create_file(filename: str) -> None: + """Create a file at the given path on the local filesystem. + + Raises an exception if the path already exists. + + Args: + filename (str): Filename to create. + """ + dirname = os.path.dirname(filename) + os.makedirs(dirname, exist_ok=True) + with open(filename, 'x'): + pass diff --git a/streaming/base/coord/job/__init__.py b/streaming/base/coord/job/__init__.py new file mode 100644 index 000000000..69207a1db --- /dev/null +++ b/streaming/base/coord/job/__init__.py @@ -0,0 +1,9 @@ +# Copyright 2022-2024 MosaicML Streaming authors +# SPDX-License-Identifier: Apache-2.0 + +"""Handling for jobs, which are collections of StreamingDataset replicas with the same config.""" + +from streaming.base.coord.job.dir import JobDir +from streaming.base.coord.job.registry import JobRegistry + +__all__ = ['JobDir', 'JobRegistry'] diff --git a/streaming/base/coord/job/dir.py b/streaming/base/coord/job/dir.py new file mode 100644 index 000000000..e553183bd --- /dev/null +++ b/streaming/base/coord/job/dir.py @@ -0,0 +1,63 @@ +# Copyright 2022-2024 MosaicML Streaming authors +# SPDX-License-Identifier: Apache-2.0 + +"""A directory containing all dataset-wide filesystem state for a Streaming job.""" + +import os +from typing import Sequence + +from streaming.base.coord.job.registry import JobRegistry +from streaming.base.stream import Stream +from streaming.base.world import World + +__all__ = ['JobDir'] + + +class JobDir: + """Represents a Streaming job lease. On ``__del__``, cleans up after itself. + + When it goes out of scope naturally, this Job will delete its config dir and its hold on all + the local dirs it is streaming to. + + If this process dies badly and the destructor is not reached, the same cleanup will be done by + some future process incidentally as it registers or unregisters a Streaming job. It can tell it + died by a combination of pid and process create time. + + Args: + registry (JobRegistry): Stremaing job registry. + """ + + def __init__(self, registry: JobRegistry, streams: Sequence[Stream], world: World) -> None: + self.registry = registry + self.streams = streams + self.world = world + self.job_hash = registry.register(streams, world) + + def get_filename(self, path: str) -> str: + """Get a filename by relative path under its job dir. + + Args: + path (str): Path relative to job dir. + + Returns: + str: Filename. + """ + return os.path.join(self.registry.config_root, self.job_hash, path) + + def manual_unregister(self) -> None: + """Explicitly un-register the job ahead of its deletion. + + This is useful when you want to ensure that this job is un-registered synchronously instead + of whenever the garbage collector eventually gets around to it. + + This job must be registered when this is called. + """ + self.registry.unregister(self.job_hash, self.world, True) + + def __del__(self) -> None: + """Destructor. + + You may unregister the job explicitly ahead of time (to ensure it happens synchronously + instead of eventually). + """ + self.registry.unregister(self.job_hash, self.world, False) diff --git a/streaming/base/coord/job/entry.py b/streaming/base/coord/job/entry.py new file mode 100644 index 000000000..6ebf88a6f --- /dev/null +++ b/streaming/base/coord/job/entry.py @@ -0,0 +1,65 @@ +# Copyright 2022-2024 MosaicML Streaming authors +# SPDX-License-Identifier: Apache-2.0 + +"""An entry in a Streaming job registry file.""" + +from typing import Any, Dict, List, Optional + +from typing_extensions import Self + +__all__ = ['JobEntry'] + + +class JobEntry: + """Info about a Streaming job for local dir reuse detection purposes. + + Args: + index (int, optional): The job's index in the total list. + job_hash (str): Job hash. + stream_hashes (List[str]): Stream hashes. + stream_locals (List[str], optional): Stream locals, if available. + process_id (int): PID of local rank zero of the Streaming job. + register_time (int): Process registration time. + """ + + def __init__( + self, + *, + index: Optional[int] = None, + job_hash: str, + stream_hashes: List[str], + stream_locals: Optional[List[str]] = None, + process_id: int, + register_time: int, + ) -> None: + self.index = index + self.job_hash = job_hash + self.stream_hashes = stream_hashes + self.stream_locals = stream_locals + self.process_id = process_id + self.register_time = register_time + + @classmethod + def from_json(cls, obj: Dict[str, Any]) -> Self: + """Load from JSON. + + Args: + obj (Dict[str, Any]): Source JSON object. + + Returns: + Self: Loaded JobEntry. + """ + return cls(job_hash=obj['job_hash'], + stream_hashes=obj['stream_hashes'], + stream_locals=obj.get('stream_locals'), + process_id=obj['process_id'], + register_time=obj['register_time']) + + def to_json(self) -> Dict[str, Any]: + return { + 'job_hash': self.job_hash, + 'stream_hashes': self.stream_hashes, + # stream_locals is not saved, only their hashes. + 'process_id': self.process_id, + 'register_time': self.register_time, + } diff --git a/streaming/base/coord/job/file.py b/streaming/base/coord/job/file.py new file mode 100644 index 000000000..cbb631b53 --- /dev/null +++ b/streaming/base/coord/job/file.py @@ -0,0 +1,145 @@ +# Copyright 2022-2024 MosaicML Streaming authors +# SPDX-License-Identifier: Apache-2.0 + +"""A Streaming job registry file.""" + +import json +import os +from typing import Dict, List + +from typing_extensions import Self + +from streaming.base.coord.job.entry import JobEntry + +__all__ = ['RegistryFile'] + + +class RegistryFile: + """StreamingDataset job registry, which is backed by a JSON file. + + Args: + jobs (List[JobEntry]): List of StreamingDataset jobs. + """ + + def __init__(self, jobs: List[JobEntry]) -> None: + self.jobs = [] + self.job_hash2job = {} + self.stream_hash2job = {} + self.num_jobs = 0 + for job in jobs: + self.add(job) + + @classmethod + def read(cls, filename: str) -> Self: + if os.path.exists(filename): + try: + obj = json.load(open(filename)) + except: + os.remove(filename) + obj = {} + else: + obj = {} + jobs = obj.get('jobs') or [] + jobs = [JobEntry.from_json(job) for job in jobs] + return cls(jobs) + + def write(self, filename: str) -> None: + jobs = [job.to_json() for job in filter(bool, self.jobs)] + obj = {'jobs': jobs} + with open(filename, 'w') as out: + json.dump(obj, out) + + def __len__(self) -> int: + """Get the number of jobs registered. + + Returns: + int: Number of registered jobs. + """ + return self.num_jobs + + def add(self, job: JobEntry) -> None: + """Register a Streaming job. + + Args: + job (Job): The job. + """ + # Check that stream locals line up. + if job.stream_locals: + if len(job.stream_hashes) != len(job.stream_locals): + raise ValueError(f'If locals are provided, must have one local per stream hash, ' + + f'but got: {len(job.stream_hashes)} hashes vs ' + + f'{len(job.stream_locals)} locals.') + norm_stream_locals = job.stream_locals + else: + norm_stream_locals = [None] * len(job.stream_hashes) + + # Check dataset hash for reuse. + if job.job_hash in self.job_hash2job: + if job.stream_locals: + raise ValueError(f'Reused dataset local path(s): {job.stream_locals}.') + else: + raise ValueError(f'Reused dataset local path(s): stream hashes = ' + + f'{job.stream_hashes}, dataset hash = {job.job_hash}.') + + # Check each stream hash for reuse. + for stream_hash, norm_stream_local in zip(job.stream_hashes, norm_stream_locals): + if stream_hash in self.stream_hash2job: + if norm_stream_local: + raise ValueError('Reused stream local path: {norm_stream_local}.') + else: + raise ValueError('Reused stream local path: stream hash = {stream_hash}.') + + # Do the insertion. + job.index = len(self.jobs) + self.jobs.append(job) + self.job_hash2job[job.job_hash] = job + for stream_hash in job.stream_hashes: + self.stream_hash2job[stream_hash] = job + self.num_jobs += 1 + + def contains(self, job_hash: str) -> bool: + """Tell whether the given job_hash is registered. + + Args: + job_hash (str): Potentially registered job hash. + + Returns: + bool: Whether the job hash is registered. + """ + return job_hash in self.job_hash2job + + def remove(self, job_hash: str) -> None: + """Deregister a Streaming job. + + Args: + job_hash (str): Job hash. + """ + job = self.job_hash2job.get(job_hash) + if not job: + raise ValueError(f'Job hash not found: {job_hash}.') + + if job.index is None: + raise ValueError('Internal error in job registration: job index is missing.') + + self.jobs[job.index] = None + del self.job_hash2job[job.job_hash] + for stream_hash in job.stream_hashes: + del self.stream_hash2job[stream_hash] + self.num_jobs -= 1 + + def filter(self, pid2create_time: Dict[int, int]) -> List[str]: + """Filter our collection of Streaming jobs. + + Args: + pid2create_time (Dict[int, int]): Mapping of pid to creation time. + + Returns: + List[str]: List of hashes of removed datasets. + """ + del_job_hashes = [] + for job in filter(bool, self.jobs): + create_time = pid2create_time.get(job.process_id) + if not create_time or job.register_time < create_time: + self.remove(job.job_hash) + del_job_hashes.append(job.job_hash) + return del_job_hashes diff --git a/streaming/base/coord/job/registry.py b/streaming/base/coord/job/registry.py new file mode 100644 index 000000000..0ce748175 --- /dev/null +++ b/streaming/base/coord/job/registry.py @@ -0,0 +1,262 @@ +# Copyright 2022-2024 MosaicML Streaming authors +# SPDX-License-Identifier: Apache-2.0 + +"""A directory containing all Streaming-wide filesystem state. + +Useful for detecting collisions between different jobs' local dirs. +""" + +import os +from hashlib import sha3_224 +from shutil import rmtree +from time import time_ns +from typing import Dict, List, Optional, Sequence, Tuple + +from filelock import FileLock +from psutil import process_iter + +from streaming.base.coord.filesystem.waiting import wait_for_creation, wait_for_deletion +from streaming.base.coord.job.entry import JobEntry +from streaming.base.coord.job.file import RegistryFile +from streaming.base.stream import Stream +from streaming.base.world import World + +__all__ = ['JobRegistry'] + + +class JobRegistry: + """StreamingDataset job registry, for the purpose of detecting local dir reuse. + + This class is safe for concurrent access via a filelock. + + Args: + config_root (str): Streaming configuration root directory, used for collision detection, + filelock paths, etc. Defaults to ``/tmp/streaming``, using the equivalent temp root on + your system. + timeout (float, optional): How long to wait before raising an exception, in seconds. + Defaults to ``30``. + tick (float): Check interval, in seconds. Defaults to ``0.007``. + """ + + def __init__( + self, + config_root: str, + timeout: Optional[float] = 30, + tick: float = 0.007, + ) -> None: + self.config_root = config_root + self.timeout = timeout + self.tick = tick + + self.lock_filename = os.path.join(config_root, 'registry.lock') + self.lock = FileLock(self.lock_filename) + + self.registry_filename = os.path.join(config_root, 'registry.json') + + def _get_live_procs(self) -> Dict[int, int]: + """List the pids and creation times of every live process in the system. + + The creation times protect us from PID reuse. + + Returns: + Dict[int, int]: Mapping of pid to integer creation time. + """ + ret = {} + for obj in process_iter(['pid', 'create_time']): + ret[obj.pid] = int(obj.create_time() * 1e9) + return ret + + def _hash(self, data: bytes) -> str: + """Get a short, deterministic, fixed-length code for the given data. + + Args: + data (bytes): The data to hash. + + Returns: + str: Truncated hex digest. + """ + return sha3_224(data).hexdigest()[:8] + + def _hash_streams(self, streams: Sequence[Stream]) -> Tuple[List[str], List[str], str]: + """Get a short, opaque str key for a StreamingDataset and each of its Streams. + + This is useful for collision detection. + + Args: + streams (Sequence[Stream]): List of this StreamingDataset's Streams, which in + combination with process IDs and creation times lets us uniquely identify a + Streaming job. + + Returns: + Tuple[str, List[str], List[str]]: Triple of (normalized stream locals, stream hashes, + and dataset hash). + """ + # Get a list of the normalized locals of each Stream. + stream_locals = [] + for stream in streams: + local = os.path.join(stream.local, stream.split) + local = os.path.normpath(local) + local = os.path.abspath(local) + stream_locals.append(local) + + # Collect the locals into a deduped set. + stream_locals_set = set() + for local in stream_locals: + if local in stream_locals_set: + raise ValueError(f'Reused local path: {local}.') + stream_locals_set.add(local) + + # Verify that no local is contained within another local. + for local in stream_locals: + parts = local.split(os.path.sep)[1:] + for num_parts in range(1, len(parts) - 1): # Leftmost is '' because they start with /. + parent = os.path.sep.join(parts[:num_parts]) + if parent in stream_locals_set: + raise ValueError(f'One local path contains another local path: {parent} vs ' + + f'{local}.') + + # Hash each local. + stream_hashes = [] + for local in sorted(stream_locals): + data = local.encode('utf-8') + stream_hash = self._hash(data) + stream_hashes.append(stream_hash) + + # Hash the dataset. + text = ','.join(stream_hashes) + data = text.encode('utf-8') + job_hash = self._hash(data) + + return stream_locals, stream_hashes, job_hash + + def _make_job_dir(self, job_hash: str) -> None: + """Create a Streaming job config dir. + + Args: + job_hash: Streaming config subdir for this job. + """ + dirname = os.path.join(self.config_root, job_hash) + os.makedirs(dirname) + + def _remove_job_dir(self, job_hash: str) -> None: + """Delete a Streaming job config dir. + + Args: + job_hash: Streaming config subdir for this job. + """ + dirname = os.path.join(self.config_root, job_hash) + rmtree(dirname) + + def register(self, streams: Sequence[Stream], world: World) -> str: + """Register or look up this collection of StreamingDataset replicas. + + Called by all ranks. + + Args: + streams (Sequence[Stream]): List of this StreamingDataset's Streams, which in + combination with process IDs and creation times lets us uniquely identify a + Streaming job. + world (World): Rank-wise world state. + + Returns: + str: Subdir for this collection of StreamingDataset replicas. + """ + if not world.is_local_leader: + _, _, job_hash = self._hash_streams(streams) + dirname = os.path.join(self.config_root, job_hash) + wait_for_creation(dirname, self.timeout, self.tick, self.lock) + return job_hash + + # Collect our stream locals and hash them, resulting in a job hash. + stream_locals, stream_hashes, job_hash = self._hash_streams(streams) + + with self.lock: + # Get registration time. + register_time = time_ns() + + # Load the job database. + db = RegistryFile.read(self.registry_filename) + + # Perform liveness checks on the jobs we have registered. + pid2create_time = self._get_live_procs() + del_job_hashes = db.filter(pid2create_time) + + # Add an entry for this job. + pid = os.getpid() + create_time = pid2create_time.get(pid) + if create_time is None: + raise RuntimeError('`psutil` thinks we are dead, and yet here we are: pid {pid}.') + entry = JobEntry(job_hash=job_hash, + stream_hashes=stream_hashes, + stream_locals=stream_locals, + process_id=pid, + register_time=register_time) + db.add(entry) + + # Save the new db to disk. + db.write(self.registry_filename) + + # Add and remove job directories accordingly. + self._make_job_dir(job_hash) + map(self._remove_job_dir, del_job_hashes) + + return job_hash + + def is_registered(self, job_hash: str) -> bool: + """Tell whether the given job_hash is registered. + + Called by all ranks. + + Args: + job_hash (str): Potentially registered job hash. + + Returns: + bool: Whether the job hash is registered. + """ + dirname = os.path.join(self.config_root, job_hash) + with self.lock: + return os.path.isdir(dirname) + + def unregister(self, job_hash: str, world: World, strict: bool = True) -> None: + """Unregister this collection of StreamingDataset replicas. + + Called by all ranks. + + Args: + job_hash (str): Subdir identifying this Streaming job. + world (World): Rank-wise world state. + strict (bool): If strict, require the job to be currently registered at start. + """ + if not world.is_local_leader: + dirname = os.path.join(self.config_root, job_hash) + wait_for_deletion(dirname, self.timeout, self.tick, self.lock) + return + + with self.lock: + # Load the job database. + db = RegistryFile.read(self.registry_filename) + + # Check if the job hash is registered. + was_registered = db.contains(job_hash) + + # If strict, require the job to be registered. + if strict and not was_registered: + raise ValueError(f'Attempted to unregister job {job_hash}, but it was not ' + + f'registered.') + + # Unregister the job, if it is registered. + if was_registered: + db.remove(job_hash) + self._remove_job_dir(job_hash) + + # Perform liveness checks on the jobs we have registered. + pid2create_time = self._get_live_procs() + del_job_hashes = db.filter(pid2create_time) + + # If we unregistered the job and/or we garbage collected job(s), save the new jobs + # database back to disk. + if was_registered or del_job_hashes: + db.write(self.registry_filename) + + # Remove each directory corresponding to a job that was garbage collected. + map(self._remove_job_dir, del_job_hashes) diff --git a/streaming/base/coord/waiting.py b/streaming/base/coord/waiting.py new file mode 100644 index 000000000..92a640630 --- /dev/null +++ b/streaming/base/coord/waiting.py @@ -0,0 +1,72 @@ +# Copyright 2022-2024 MosaicML Streaming authors +# SPDX-License-Identifier: Apache-2.0 + +"""Waiting on predicates.""" + +from contextlib import nullcontext +from time import sleep, time +from typing import Any, Callable, Optional + +__all__ = ['wait'] + + +def _say_duration(duration: float) -> str: + """Pretty-print a duration. + + Args: + duration (float): The duration as a float. + + Returns: + str: The duration as a str. + """ + return f'{duration:.3f}'.rstrip('0').rstrip('.') + + +def wait( + stop: Callable[[], bool], + timeout: Optional[float] = 30, + tick: float = 0.007, + lock: Optional[Any] = None, +) -> None: + """Wait for the predicate to succeed. + + Args: + stop (Callable[[], bool]): When this check returns True, you break out of the retry loop. + timeout (float, optional): How long to wait before raising an exception, in seconds. + Defaults to ``30``. + tick (float): Check interval, in seconds. Defaults to ``0.007``. + lock (Any, optional): Context manager (this is intended for locks) to be held when + checking the predicate. Defaults to ``None``. + """ + start = time() + + if timeout is not None and timeout <= 0: + raise ValueError(f'Timeout must be positive if provided, but got: ' + + f'{_say_duration(timeout)} sec.') + + if tick <= 0: + raise ValueError(f'Tick must be positive if provided, but got: {_say_duration(tick)} sec.') + + if lock is not None: + if not hasattr(lock, '__enter__'): + raise ValueError(f'Lock must support `__enter__`, but got: {lock}.') + + if not hasattr(lock, '__exit__'): + raise ValueError(f'Lock must support `__exit__`, but got: {lock}.') + + norm_lock = lock + else: + norm_lock = nullcontext() + + while True: + with norm_lock: + if stop(): + break + + if timeout is not None: + now = time() + if timeout <= now - start: + raise RuntimeError(f'Wait timed out: timeout {_say_duration(timeout)} sec vs ' + + f'elapsed {_say_duration(now - start)} sec.') + + sleep(tick) diff --git a/streaming/base/dataset.py b/streaming/base/dataset.py index 3869244b5..283606429 100644 --- a/streaming/base/dataset.py +++ b/streaming/base/dataset.py @@ -12,9 +12,10 @@ from concurrent.futures._base import Future from enum import IntEnum from math import ceil +from multiprocessing import Process from tempfile import gettempdir from threading import Event, Lock -from time import sleep, time_ns +from time import sleep, time, time_ns from typing import Any, Dict, Iterator, Optional, Sequence, Tuple, Union import numpy as np @@ -28,6 +29,8 @@ from streaming.base.constant import (BARRIER, BARRIER_FILELOCK, CACHE_FILELOCK, CACHE_USAGE, EPOCH_DATA, EPOCH_SHAPE, NEXT_EPOCH, RESUME, SHARD_ACCESS_TIMES, SHARD_STATES, TICK) +from streaming.base.coord.job.dir import JobDir +from streaming.base.coord.job.registry import JobRegistry from streaming.base.distributed import maybe_init_dist from streaming.base.format import get_index_basename from streaming.base.sampling import get_sampling @@ -188,8 +191,13 @@ class StreamingDataset(Array, IterableDataset): * What to iterate: + * Dataset/job registry: + + * ``config_root`` + * One or more streams (you must provide either ``streams`` or ``remote``/``local``): + * ``epoch_size`` * ``streams`` * ``remote`` * ``local`` @@ -203,11 +211,16 @@ class StreamingDataset(Array, IterableDataset): * ``validate_hash`` * ``keep_zip`` - * Absolute dataset size, if streams were weighted relatively: + * How to iterate: - * ``epoch_size`` + * Epoch pre-generation: - * How to iterate: + * ``init_pregen_epoch`` + * ``inti_pregen_sample`` + * ``pregen_next_epoch`` + * ``pregen_epoch_timeout`` + * ``pregen_epoch_tick`` + * ``num_workers`` * Shard lifecycle: @@ -238,6 +251,14 @@ class StreamingDataset(Array, IterableDataset): Args: + config_root (str, optional): Streaming configuration root directory, used for collision + detection, filelock paths, etc. If ``None``, uses a ``/streaming/`` subdir under your + system's temp root. Defaults to ``None``. + epoch_size (int | str, optional): Number of samples to draw per epoch balanced + across all streams. If ``None``, takes its value from the total number of underlying + samples. Provide this field if you are weighting streams relatively to target a larger + or smaller epoch size. Defaults to ``None``. Can also take in human-readable number + abbreviations (e.g., ``"100k"``, ``"64M"``, ``"77b"``, etc). Defaults to ``None``. streams (Sequence[Stream], optional): One or more streams to stream/cache samples from, which may be upsampled or downsampled. StreamingDataset uses either ``streams`` or ``remote``/``local``. Defaults to ``None``. @@ -257,17 +278,28 @@ class StreamingDataset(Array, IterableDataset): keep_zip (bool): Whether to keep or delete the compressed form when decompressing downloaded shards. If ``False``, keep iff remote is local or no remote. Defaults to ``False``. - epoch_size (Union[int, str], optional): Number of samples to draw per epoch balanced - across all streams. If ``None``, takes its value from the total number of underlying - samples. Provide this field if you are weighting streams relatively to target a larger - or smaller epoch size. Defaults to ``None``. Can also take in human-readable number - abbreviations (e.g., ``"100k"``, ``"64M"``, ``"77b"``, etc). Defaults to ``None``. + init_pregen_epoch (int, optional): What epoch to pre-generate in the background at init + time, if any. This is useful if you do a lot of work between instantiating your + StreamingDataset and iterating it. Defaults to ``None``. + init_pregen_sample (int, optional): What sample offset into the epoch to pre-generate with + in the background at init time. If ``init_pregen_epoch`` is not set, must not be set + either. Defaults to ``None``. + pregen_next_epoch (bool): Whether to pre-generate the next epoch in the background at the + start of iter after generating or loading the current about-to-be-iterated epoch. + Defaults to ``True``. + pregen_epoch_timeout (float, optional): Timeout when waiting on this epoch to be + pre-generated. Defaults to ``float(np.arange(1, 7).prod())``, i.e. 12 minutes. + pregen_epoch_tick (float): Polling interval when waiting on this epoch to be pre-generated. + Defaults to ``0xCAFE / 1337 / 42``, i.e. about 925ms. + num_workers (int, optional): Number of workers per rank, same as PyTorch DataLoader + ``num_workers``. Required iff you are pre-generating an epoch at init time, otherwise + this information is determined automatically elsewhere. Defaults to ``None``. predownload (int, optional): Target number of samples to download per worker in advance of current sample. Workers will attempt to download ahead by this many samples during, but not before, training. Recommendation is to provide a value greater than per device batch size to ensure at-least per device batch size number of samples cached locally. If ``None``, its value is set to ``8 * batch_size``. Defaults to ``None``. - cache_limit (Union[int, str], optional): Maximum size in bytes of this StreamingDataset's + cache_limit (int | str, optional): Maximum size in bytes of this StreamingDataset's shard cache. Before downloading a shard, the least recently used resident shard(s) may be evicted (deleted from the local cache) in order to stay under the limit. Set to ``None`` to disable shard eviction. Supports integer bytes as well as string @@ -309,8 +341,13 @@ class StreamingDataset(Array, IterableDataset): if ``False``. Defaults to ``False``. """ + pregen_todos_lock_path = 'pregen_todos.lock' + pregen_todos_path = 'pregen_todos.npy' + def __init__(self, *, + config_root: Optional[str] = None, + epoch_size: Optional[Union[int, str]] = None, streams: Optional[Sequence[Stream]] = None, remote: Optional[str] = None, local: Optional[str] = None, @@ -319,7 +356,12 @@ def __init__(self, download_timeout: float = 60, validate_hash: Optional[str] = None, keep_zip: bool = False, - epoch_size: Optional[Union[int, str]] = None, + init_pregen_epoch: Optional[int] = None, + init_pregen_sample: Optional[int] = None, + pregen_next_epoch: bool = True, + pregen_epoch_timeout: Optional[float] = float(np.arange(1, 7).prod()), + pregen_epoch_tick: float = 0xCAFE / 1337 / 42, + num_workers: Optional[int] = None, predownload: Optional[int] = None, cache_limit: Optional[Union[int, str]] = None, sampling_method: str = 'balanced', @@ -506,7 +548,55 @@ def __init__(self, # Length (__len__) is the resampled epoch size divided over the number of devices. self.length = ceil(self.epoch_size / world.num_ranks) - # Register/lookup our shared memory prefix and filelock root directory. + # Args about pre-generating epochs. + if init_pregen_epoch is not None: + if init_pregen_epoch < 0: + raise ValueError(f'Init pregen epoch must be non-negative, but got: ' + + f'{init_pregen_epoch}.') + self.init_pregen_epoch = init_pregen_epoch + + if init_pregen_sample is not None: + if not (0 <= init_pregen_sample <= self.num_samples): + raise ValueError(f'Init pregen sample must be from 0 to {self.num_samples}, but ' + + f'got: {init_pregen_sample}.') + if init_pregen_epoch is not None: + self.init_pregen_sample = init_pregen_sample or 0 + else: + if init_pregen_sample is not None: + raise ValueError(f'Init pregen epoch is not set, but init pregen sample is: ' + + f'epoch {init_pregen_epoch}, sample {init_pregen_sample}.') + self.init_pregen_sample = init_pregen_sample + + self.pregen_next_epoch = pregen_next_epoch + + if pregen_epoch_timeout is not None and pregen_epoch_timeout < 0: + raise ValueError(f'Pregen epoch timeout must be non-negative if set, but got: ' + + f'{pregen_epoch_timeout}.') + self.pregen_epoch_timeout = pregen_epoch_timeout + + if pregen_epoch_tick <= 0: + raise ValueError(f'Pregen epoch tick must be positive seconds, but got: ' + + f'{pregen_epoch_tick}.') + self.pregen_epoch_tick = pregen_epoch_tick + + self.num_workers = num_workers + + # Init registry, then register/lookup this Streaming job (new style). + self.config_root = self._get_config_root(config_root) + self._test_config_root(self.config_root) + self.registry = JobRegistry(self.config_root, 42, 0.007) + self.job = JobDir(self.registry, streams, world) + + # Maybe note some epoch to pre-generate (like epoch 0, sample offset 0)? + if self.init_pregen_epoch is not None: + self._request_pregen_epoch(self.init_pregen_epoch, self.init_pregen_sample or 0) + + # Start the epoch pre-generation loop as a daemon process. + if init_pregen_epoch is not None or pregen_next_epoch: + self.process = Process(target=self._pregen_epoch_loop, daemon=True) + self.process.run() + + # Register/lookup our shared memory prefix and filelock root directory (old style). streams_local = [os.path.abspath(os.path.join(x.local, x.split)) for x in streams] streams_remote = [ os.path.join(x.remote, x.split) if x.remote is not None else None for x in streams @@ -589,13 +679,51 @@ def __init__(self, del self._shared_barrier.lock # Remote the lock that makes it unpickleable. + self._dummy = None + def __del__(self) -> None: - """Destructor, which releases its local working directories.""" + """Destructor,kill which releases its local working directories.""" if hasattr(self, '_locals_shm'): try: self._locals_shm.buf[:4] = np.int32(0).tobytes() except: pass + self.job.manual_unregister() + + @classmethod + def _test_config_root(cls, config_root: str) -> None: + """Validate that the provided config root is usable. + + If you are unable to get root or 777 perms, you may encounter problems in registering your + Streaming jobs for collision detection, getting unique interprocess filelock paths, etc. + You can sort of get around this by changing config root to a directory you control, but + this may negatively impact collision detection. + + Args: + config_root (str): Streaming configuration root directory. + """ + os.makedirs(config_root, exist_ok=True) + filename = os.path.join(config_root, 'test.txt') + try: + with open(filename, 'wb') as out: + out.write(b'') + except: + raise ValueError('Please provide a `config_root` dir that is writeable and readable.') + os.remove(filename) + + @classmethod + def _get_config_root(cls, config_root: Optional[str] = None) -> str: + """Get the Streaming configuration root directory. + + Args: + config_root (str, optional): Config root, if explicitly provided. Defaults to ``None``. + + Returns: + str: Streaming configuration root directory. + """ + if config_root is None: + config_root = os.path.join(gettempdir(), 'streaming') + return config_root @property def size(self) -> int: @@ -942,7 +1070,198 @@ def _attach_work(self) -> Tuple[NDArray[np.int64], SharedMemory, SharedMemory]: return sample_ids, shape_shm, data_shm - def _get_work(self, world: World, epoch: int, sample_in_epoch: int) -> NDArray[np.int64]: + def _locate_epoch_work(self, epoch: int, sample: int) -> str: + """Get the filename for generated epoch work given its epoch and sample offset. + + Args: + epoch (int): Which epoch. + sample (int): What sample offset. + + Returns: + str: Filename of serialized epoch work. + """ + return self.job.get_filename(f'epoch.{epoch:06}.{sample:012}.npy') + + def _serialize_epoch_work(self, work: NDArray[np.int64]) -> bytes: + """Serialize a 5-dimensional sample ID arrangement tensor to bytes. + + Args: + work (NDArray[np.int64]): Sample IDs tensor. + + Returns: + bytes: The serialized data. + """ + # Serialize to bytes prefixed with shape (we use int64 for alignment reasons). + return b''.join([ + np.int64(work.ndim).tobytes(), + np.array(work.shape, np.int64).tobytes(), + work.tobytes(), + ]) + + def _deserialize_epoch_work(self, data: bytes) -> NDArray[np.int64]: + """Deserialize a 5-dimensional sample ID arrangement tensor from bytes. + + Args: + data (bytes): The serialized data. + + Returns: + NDArray[np.int64]: Sample IDs tensor. + """ + arr = np.ndarray(shape=-1, dtype=np.int64, buffer=data) + ndim = arr[0] + shape = tuple(arr[1:1 + ndim].tolist()) + offset = (1 + ndim) * np.int64().nbytes + return np.ndarray(shape, np.int64, arr, offset) + + def _pregen_epoch(self, epoch: int, sample: int) -> None: + """Pre-generate the sample ID arrangement for some epoch. + + This is typically run in the background in a daemon process. + + Args: + epoch (int): Which epoch. + sample (int): What sample offset. + """ + if self.num_workers is None: + raise ValueError(f'You must provide DataLoader num_workers to StreamingDataset in ' + + f'order for it to be able to pre-generate the epoch at init time.') + + # Locate epoch data, e.g. "epoch.000000007.000000001000.npy". + filename = self._locate_epoch_work(epoch, sample) + + # If there is already a file there, either someone has pre-generated it already (non-empty) + # or they are in the process of pre-generating it (empty) and we are done. If no file, + # create one to claim it ourself. + try: + with open(filename, 'xb'): + pass + except: + return + + # Create the world a worker will see. + world = World() + if 1 < self.num_workers: + world.workers_per_rank = self.num_workers + world.num_workers = world.num_ranks * world.workers_per_rank + world.workers_per_node = world.ranks_per_node * world.workers_per_rank + + # Do the epoch generation (heavy). + work = generate_work(self.batching_method, self, world, epoch, sample) + + # Serialize to bytes. + data = self._serialize_epoch_work(work) + + # Write those bytes, to be picked up by the main process/thread. + tmp_filename = filename + '.tmp' + with open(tmp_filename, 'wb') as out: + out.write(data) + os.rename(tmp_filename, filename) + + def _push_back_pregen_epoch_todo(self, todo_filename: str, epoch: int, sample: int) -> None: + now = time_ns() + todo = np.array([epoch, sample, now], np.int64) + todo = np.expand_dims(todo, 0) + if os.path.exists(todo_filename): + old = np.fromfile(todo_filename, np.int64) + old = old.reshape(-1, 3) + new = np.concatenate([old, todo], 0) + else: + new = todo + new.tofile(todo_filename) + + def _pop_front_pregen_epoch_todo(self, todo_filename: str) -> Tuple[int, int, int]: + old = np.fromfile(todo_filename, np.int64) + old = old.reshape(-1, 3) + todo = old[0] + new = old[1:] + if len(new): + new.tofile(todo_filename) + else: + os.remove(todo_filename) + return tuple(todo.tolist()) + + def _request_pregen_epoch(self, epoch: int, sample: int) -> None: + lock_filename = self.job.get_filename(self.pregen_todos_lock_path) + todo_filename = self.job.get_filename(self.pregen_todos_path) + with FileLock(lock_filename): + self._push_back_pregen_epoch_todo(todo_filename, epoch, sample) + + def _each_pregen_epoch_todo(self) -> Iterator[Tuple[int, int]]: + lock_filename = self.job.get_filename(self.pregen_todos_lock_path) + todo_filename = self.job.get_filename(self.pregen_todos_path) + dirname = os.path.dirname(lock_filename) + os.makedirs(dirname, exist_ok=True) + lock = FileLock(lock_filename) + while True: + with lock: + if os.path.exists(todo_filename): + epoch, sample, _ = self._pop_front_pregen_epoch_todo(todo_filename) + yield epoch, sample + if not hasattr(self, '_dummy'): + break + sleep(0.1337) + + def _pregen_epoch_loop(self) -> None: + for epoch, sample in self._each_pregen_epoch_todo(): + self._pregen_epoch(epoch, sample) + + def _gen_epoch(self, world: World, epoch: int, sample: int) -> NDArray[np.int64]: + """Generate (or load pre-generated) the sample ID arrangement for some epoch. + + Args: + world (World): The world dimensions to generate it for. + epoch (int): Which epoch. + sample (int): What sample offset. + + Returns: + NDArray[np.int64]: 5-dim sample IDs tensor. + """ + # Get where our pre-generated epoch data would be found, if it exists. + filename = self._locate_epoch_work(epoch, sample) + + # If the file is taken, it either is populated or will be soon. If not, we have to generate + # the epoch ourself. + if os.path.exists(filename): + # Wait for the file to become populated. + then = time() + while True: + # If it's populated, break out. + stat = os.stat(filename) + if stat.st_size: + break + + # If it's not yet populated, you then check how much time we've taken. + now = time() + elapsed = now - then + if self.pregen_epoch_timeout is not None and self.pregen_epoch_timeout < elapsed: + raise ValueError(f'Timed out while waiting on epoch pre-generation: epoch ' + + f'{epoch}, sample {sample}, timeout ' + + f'{self.pregen_epoch_timeout}, elapsed {elapsed}.') + + # If we're still waiting, sleep a bit. + sleep(self.pregen_epoch_tick) + + # Deserialize the populated file. + data = open(filename, 'rb').read() + work = self._deserialize_epoch_work(data) + else: + # Claim the epoch generation work, preventing the epoch pregen process from doing it. + try: + with open(filename, 'xb'): + pass + except: + pass + + # Generate the epoch ourself. + work = generate_work(self.batching_method, self, world, epoch, sample) + + # Maybe pre-generate the next epoch in the background. + if self.pregen_next_epoch: + self._request_pregen_epoch(epoch + 1, 0) + + return work + + def _get_epoch(self, world: World, epoch: int, sample_in_epoch: int) -> NDArray[np.int64]: """Get this worker's partition of this epoch's sample space. Args: @@ -960,8 +1279,7 @@ def _get_work(self, world: World, epoch: int, sample_in_epoch: int) -> NDArray[n # Do expensive work that may use a lot of cores/memory just once, in the local leader. if world.is_local_leader: - epoch_sample_ids = generate_work(self.batching_method, self, world, epoch, - sample_in_epoch) + epoch_sample_ids = self._gen_epoch(world, epoch, sample_in_epoch) shape_shm, data_shm = self._share_work(epoch_sample_ids) self._shared_barrier(world.workers_per_node) else: @@ -1418,7 +1736,7 @@ def __iter__(self) -> Iterator[Dict[str, Any]]: epoch, sample_in_epoch = self._resume_incr_epoch(world) # Get this worker's partition of samples to process. - sample_ids = self._get_work(world, epoch, sample_in_epoch) + sample_ids = self._get_epoch(world, epoch, sample_in_epoch) if not len(sample_ids): # Resumed at end of epoch, out of samples. return diff --git a/tests/test_streaming.py b/tests/test_streaming.py index 7ef98dfec..fd8aebc56 100644 --- a/tests/test_streaming.py +++ b/tests/test_streaming.py @@ -861,6 +861,9 @@ def test_multiple_dataset_instantiation(local_remote_dir: Any, shuffle_seed: tup assert len(set(train_sample_order)) == len(set(val_sample_order)), 'Duplicate samples' +@pytest.mark.skip('We could be resuming with shard files not all be in their final phases, so ' + + 'the directory could still change on the fly even if there is no remote, so ' + + 'we cannot reuse local even in this case.') def test_same_local_no_remote(local_remote_dir: Tuple[str, str]): local_0, _ = local_remote_dir convert_to_mds(out_root=local_0, @@ -893,5 +896,5 @@ def test_same_local_diff_remote(local_remote_dir: Tuple[str, str]): # Build StreamingDataset _ = StreamingDataset(local=local_0, remote=remote_0, batch_size=4, num_canonical_nodes=1) # Build StreamingDataset - with pytest.raises(ValueError, match='Reused local directory.*vs.*Provide a different one.'): + with pytest.raises(ValueError): _ = StreamingDataset(local=local_0, remote=remote_1, batch_size=2, num_canonical_nodes=1)