Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 52 additions & 32 deletions gigl/common/utils/compute/random.py
Original file line number Diff line number Diff line change
@@ -1,49 +1,69 @@
"""
Matches the ``set_seed(seed, deterministic=False)`` shape used by
Hugging Face Transformers, MMEngine, and Accelerate; follows the recipe
at https://pytorch.org/docs/stable/notes/randomness.html.
"""

import os
import random
from typing import Final

import numpy as np
import torch

from gigl.common.logger import Logger

logger = Logger()

_DEFAULT_SEED: Final[int] = 42 # Answer to the Ultimate Question.
# Required on CUDA >= 10.2 when use_deterministic_algorithms(True) is set,
# otherwise cuBLAS matmuls raise RuntimeError. ":4096:8" trades ~24 MiB of
# extra cuBLAS workspace for keeping perf reasonable vs ":16:8".
_CUBLAS_WORKSPACE_CONFIG: Final[str] = ":4096:8"

def make_compute_deterministic_and_set_seed(
seed: int = 42, # Answer to the Ultimate Question of Life, The Universe, and Everything
should_consider_numpy=True,
should_consider_torch=False,
should_consider_tensorflow=False,
):
logger.info(
"""
Ensure data loading is also deterministic and you are using deterministic algorithms
for relevant frameworks, otherwise nondeterminism will persist
"""
)

# Setting PYTHONHASHSEED doesn't seem like it actually does anything
# See: https://stackoverflow.com/questions/30585108/disable-hash-randomization-from-within-python-program
# os.environ["PYTHONHASHSEED"] = "0"
def seed_everything(
seed: int = _DEFAULT_SEED,
should_enable_expensive_deterministic_compute: bool = False,
) -> None:
"""Seed Python / NumPy / PyTorch RNGs, optionally enforce deterministic torch ops.

random.seed(seed)
What gets seeded:

- ``random.seed(seed)`` — Python stdlib.
- ``np.random.seed(seed)`` — NumPy global RNG.
- ``torch.manual_seed(seed)`` — CPU **and all CUDA devices**
(``torch.manual_seed`` calls ``torch.cuda.manual_seed_all`` internally.
Also covers PyTorch Geometric.

if should_consider_numpy:
import numpy as np
When ``should_enable_expensive_deterministic_compute=True`` (opt-in; default False because it costs
throughput and should not be enabled for training or for production inference - can be used for debugging purposes.

np.random.seed(seed)
- Important: Graph Sampling currently do not follow determism outlined here.
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: determism

Example:
>>> seed_everything(42)
42

if should_consider_torch:
import torch
import torch.backends.cudnn
Args:
seed: RNG seed.
deterministic: If True, also enforces bitwise-deterministic torch
ops (cudnn flags, ``use_deterministic_algorithms``,
``CUBLAS_WORKSPACE_CONFIG``). Default False — most training
pipelines want seeded RNGs without paying the throughput cost.

torch.manual_seed(seed)
"""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we also do torch.cuda.manual_seed_all(seed) as is done here?

if should_enable_expensive_deterministic_compute:
os.environ["CUBLAS_WORKSPACE_CONFIG"] = _CUBLAS_WORKSPACE_CONFIG
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.use_deterministic_algorithms(True)

if should_consider_tensorflow:
import tensorflow as tf

tf.random.set_seed(seed)
os.environ["TF_DETERMINISTIC_OPS"] = "1"
os.environ["TF_CUDNN_DETERMINISTIC"] = "1"
tf.config.threading.set_inter_op_parallelism_threads(1)
tf.config.threading.set_intra_op_parallelism_threads(1)
logger.warning(
f"seed_everything: seeded python/numpy/torch with seed={seed}; "
f"expensive deterministic algorithms ON; "
f"throughput will degrade"
)
else:
logger.info(f"seed_everything: seeded python/numpy/torch with seed={seed}")