Merged
11 changes: 7 additions & 4 deletions recml/core/data/iterator.py
@@ -21,6 +21,7 @@
 from etils import epath
 import numpy as np
 import tensorflow as tf
+import jax


 Iterator = clu_data.DatasetIterator
@@ -69,15 +70,17 @@ def _maybe_to_numpy(
         return x
       if hasattr(x, "_numpy"):
         numpy = x._numpy()  # pylint: disable=protected-access
-      else:
+      elif hasattr(x, "numpy"):
         numpy = x.numpy()
+      else:
+        return x

       if isinstance(numpy, np.ndarray):
         # `numpy` shares the same underlying buffer as the `x` Tensor.
         # Tensors are expected to be immutable, so we disable writes.
         numpy.setflags(write=False)
       return numpy

-    return tf.nest.map_structure(_maybe_to_numpy, batch)
+    return jax.tree.map(_maybe_to_numpy, batch)

   @property
   def element_spec(self) -> clu_data.ElementSpec:
@@ -109,7 +112,7 @@ def _to_element_spec(
         )
       return clu_data.ArraySpec(dtype=x.dtype, shape=tuple(x.shape))

-    element_spec = tf.nest.map_structure(_to_element_spec, batch)
+    element_spec = jax.tree.map(_to_element_spec, batch)
     self._element_spec = element_spec
     return element_spec

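Note on the change above (not part of the diff): jax.tree.map walks the same nested containers (dicts, lists, tuples) as tf.nest.map_structure and applies the conversion to every leaf, so the iterator's behaviour is preserved; the one difference to keep in mind is that jax.tree.map treats None as an empty subtree rather than as a leaf. A minimal sketch of the resulting conversion path, with batch_to_numpy and the example batch invented for illustration:

import jax
import numpy as np
import tensorflow as tf


def batch_to_numpy(batch):
  """Converts every tf.Tensor leaf in `batch` to a read-only np.ndarray."""

  def _maybe_to_numpy(x):
    if isinstance(x, np.ndarray):
      return x
    out = x.numpy()  # shares the tensor's buffer where possible
    if isinstance(out, np.ndarray):
      out.setflags(write=False)  # tensors are immutable, so block writes
    return out

  return jax.tree.map(_maybe_to_numpy, batch)


example = {"dense": tf.ones((2, 3)), "labels": (tf.zeros((2,)), tf.ones((2,)))}
numpy_batch = batch_to_numpy(example)  # same nesting, read-only numpy leaves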
2 changes: 1 addition & 1 deletion recml/core/ops/embedding_ops.py
@@ -38,7 +38,7 @@ class SparsecoreParams:
   """Embedding parameters."""

   feature_specs: Nested[FeatureSpec]
-  mesh: jax.sharding.Mesh | jax.sharding.AbstractMesh
+  mesh: jax.sharding.Mesh
   data_axes: Sequence[str | None]
   embedding_axes: Sequence[str | None]
   sharding_strategy: str
8 changes: 8 additions & 0 deletions recml/core/training/core.py
@@ -15,6 +15,7 @@

 import abc
 from collections.abc import Mapping, Sequence
+import contextlib
 import dataclasses
 import enum
 from typing import Any, Generic, TypeVar
@@ -24,6 +25,13 @@
 from recml.core.data import iterator
 import tensorflow as tf

+# Patch jax.spmd_mode if it doesn't exist (removed in newer JAX versions).
+if not hasattr(jax, "spmd_mode"):
+  @contextlib.contextmanager
+  def _spmd_mode(*args, **kwargs):
+    del args, kwargs
+    yield
+  jax.spmd_mode = _spmd_mode

 # pylint: disable=logging-fstring-interpolation

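Note on the shim above (not part of the diff): jax.spmd_mode was a context manager in older JAX releases, and the no-op replacement keeps any remaining call sites importable and runnable on versions where it was removed. A self-contained sketch of the same pattern; the call site at the bottom is hypothetical:

import contextlib

import jax

# Same shim the diff installs at import time of recml.core.training.core.
if not hasattr(jax, "spmd_mode"):

  @contextlib.contextmanager
  def _spmd_mode(*args, **kwargs):  # accepts and ignores the old arguments
    del args, kwargs
    yield

  jax.spmd_mode = _spmd_mode

# Hypothetical legacy call site: works whether spmd_mode is native or shimmed.
with jax.spmd_mode("allow_all"):
  x = jax.numpy.zeros((8,))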
32 changes: 14 additions & 18 deletions recml/core/training/partitioning.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
 """Utilities for partitioning."""

 import abc
@@ -22,7 +23,6 @@
 import jax
 import numpy as np

-
 PyTree = Any
 State = Any
 CreateStateFn = Callable[[PyTree], State]
@@ -67,7 +67,8 @@ class DataParallelPartitioner(Partitioner):
   """Data parallel partitioner."""

   def __init__(self, data_axis: str = "batch"):
-    self.mesh = jax.make_mesh((jax.device_count(),), (data_axis,))
+    devices = jax.devices()
+    self.mesh = jax.sharding.Mesh(devices, (data_axis,))
     self.data_sharding = jax.sharding.NamedSharding(
         self.mesh, jax.sharding.PartitionSpec(data_axis)
     )
@@ -107,7 +108,7 @@ def _shard(x: np.ndarray) -> jax.Array:
   def partition_init(
       self, init_fn: CreateStateFn, *, abstract_batch: PyTree | None = None
   ) -> CreateStateFn:
-    with jax.sharding.use_mesh(self.mesh):
+    with self.mesh:
       if abstract_batch is not None:
         abstract_state = jax.eval_shape(init_fn, abstract_batch)
         specs = nn.get_partition_spec(abstract_state)
@@ -117,7 +118,7 @@ def partition_init(
         init_fn = jax.jit(init_fn, out_shardings=self.state_sharding)

     def _wrapped_init(batch: PyTree) -> State:
-      with jax.sharding.use_mesh(self.mesh):
+      with self.mesh:
         state = init_fn(batch)
         state = _maybe_unbox_state(state)
       return state
@@ -130,15 +131,15 @@ def partition_step(self, fn: StepFn, *, training: bool = False) -> StepFn:
       jit_kws["out_shardings"] = (self.state_sharding, None)
       jit_kws["donate_argnums"] = (1,)

-    with jax.sharding.use_mesh(self.mesh):
+    with self.mesh:
       step_fn = jax.jit(
           fn,
           in_shardings=(self.data_sharding, self.state_sharding),
           **jit_kws,
       )

     def _wrapped_step(batch: PyTree, state: State) -> Any:
-      with jax.sharding.use_mesh(self.mesh):
+      with self.mesh:
         return step_fn(batch, state)

     return _wrapped_step
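Note on the context-manager change above (not part of the diff): a jax.sharding.Mesh is itself usable as a context manager, so `with self.mesh:` plays the role that jax.sharding.use_mesh(...) plays on newer JAX releases, making the mesh active around the jitted calls. A minimal sketch under that assumption; the doubling function is arbitrary:

import jax
import numpy as np

mesh = jax.sharding.Mesh(np.array(jax.devices()), ("batch",))
sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec("batch"))

with mesh:  # older spelling of `with jax.sharding.use_mesh(mesh):`
  step = jax.jit(lambda x: x * 2, in_shardings=sharding, out_shardings=sharding)
  out = step(np.arange(jax.device_count() * 2, dtype=np.float32))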
@@ -190,7 +191,7 @@ def __init__(
     if axis_sizes[0] == -1:
       axis_sizes[0] = len(devices) // math.prod(axis_sizes[1:])

-    self.mesh = jax.make_mesh(axis_sizes, axis_names, devices=devices)
+    self.mesh = jax.sharding.Mesh(devices, axis_names)
     self.rules = rules
     self.aot_compile = aot_compile
     self.options = options
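Note on the constructor change above (not part of the diff): jax.make_mesh accepted axis sizes and arranged the devices itself, whereas jax.sharding.Mesh expects its devices argument to already be shaped like the mesh, which is why a reshaped device array is now threaded through here and in the tests below. A small sketch of that assumption for a two-axis mesh:

import jax
import numpy as np

axis_names = ("data", "model")
# Pre-arrange the devices into the (data, model) grid the mesh should have.
devices = np.array(jax.devices()).reshape(-1, 1)
mesh = jax.sharding.Mesh(devices, axis_names)
print(dict(mesh.shape))  # {'data': jax.device_count(), 'model': 1}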
@@ -213,12 +214,6 @@ def __init__(
       self.abstract_batch = None
       self.abstract_state = None

-  @property
-  def mesh_context_manager(
-      self,
-  ) -> Callable[[jax.sharding.Mesh], ContextManager[None]]:
-    return jax.sharding.use_mesh
-
   def shard_inputs(self, inputs: PyTree) -> PyTree:
     def _shard(x: np.ndarray) -> jax.Array:
       return jax.make_array_from_process_local_data(self.data_sharding, x)
@@ -234,7 +229,7 @@ def partition_init(
           " model parallel partitioner."
       )

-    with self.mesh_context_manager(self.mesh):
+    with self.mesh:
       abstract_state = jax.eval_shape(init_fn, abstract_batch)
       specs = nn.get_partition_spec(abstract_state)

@@ -247,7 +242,7 @@ def partition_init(
     compiled_init_fn = jax.jit(init_fn, out_shardings=state_sharding)

     def _init(batch: PyTree) -> State:
-      with self.mesh_context_manager(self.mesh):
+      with self.mesh:
         state = compiled_init_fn(batch)
         state = _maybe_unbox_state(state)
       return state
@@ -265,7 +260,8 @@ def partition_step(self, fn: StepFn, *, training: bool = False) -> StepFn:
     else:
       jit_kws["out_shardings"] = None

-    with self.mesh_context_manager(self.mesh):
+
+    with self.mesh:
       step_fn = jax.jit(
           fn,
           in_shardings=(self.data_sharding, self.state_sharding),
@@ -286,7 +282,7 @@ def partition_step(self, fn: StepFn, *, training: bool = False) -> StepFn:
       )

     def _step(batch: PyTree, state: State) -> Any:
-      with self.mesh_context_manager(self.mesh):
+      with self.mesh:
         return step_fn(batch, state)

     return _step
@@ -302,4 +298,4 @@ def _maybe_unbox(x: Any) -> Any:
       _maybe_unbox,
       x,
       is_leaf=lambda k: isinstance(k, nn.Partitioned),
-  )
+  )
9 changes: 7 additions & 2 deletions recml/core/training/partitioning_test.py
@@ -40,7 +40,8 @@ def test_data_parallelism(
       self, partitioner_cls: type[partitioning.Partitioner]
   ):
     if partitioner_cls is partitioning.ModelParallelPartitioner:
-      kwargs = {"axes": [("data", -1), ("model", 1)], "dp_axes": 1}
+      devs = np.array(jax.devices()).reshape(-1, 1)
+      kwargs = {"axes": [("data", -1), ("model", 1)], "dp_axes": 1, "devices": devs}
     else:
       kwargs = {}
     partitioner = partitioner_cls(**kwargs)
@@ -112,8 +113,12 @@ def _eval_step(
     )

   def test_model_parallelism(self):
+    devs = np.array(jax.devices()).reshape(1, -1)
+
     partitioner = partitioning.ModelParallelPartitioner(
-        axes=[("data", 1), ("model", jax.device_count())], dp_axes=1
+        axes=[("data", 1), ("model", jax.device_count())],
+        dp_axes=1,
+        devices=devs
     )

     inputs = np.zeros((128, 16), dtype=np.float32)
6 changes: 6 additions & 0 deletions recml/examples/dlrm_experiment.py
@@ -19,6 +19,12 @@
 import dataclasses
 from typing import Generic, Literal, TypeVar

+import sys
+import os
+# Add the RecML folder to the system path
+sys.path.append(os.path.join(os.getcwd(), "../../../RecML"))
+os.environ["KERAS_BACKEND"] = "jax"
+
 from etils import epy
 import fiddle as fdl
 import flax.linen as nn
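Note on the path and backend setup above (not part of the diff): KERAS_BACKEND is only read the first time keras is imported, so the assignment must run before any direct or transitive keras import, and the "../../../RecML" entry is resolved against the current working directory, so it only takes effect when the script is launched from the expected location. A minimal sketch of that ordering assumption:

import os
import sys

# Must run before anything imports keras, or the backend choice is ignored.
os.environ["KERAS_BACKEND"] = "jax"
# Resolved against os.getcwd(), not against this file's location.
sys.path.append(os.path.join(os.getcwd(), "../../../RecML"))

import keras

assert keras.backend.backend() == "jax"

Anchoring the path on os.path.dirname(os.path.abspath(__file__)) instead of os.getcwd() would make it independent of where the script is launched from.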
10 changes: 8 additions & 2 deletions recml/examples/dlrm_experiment_test.py
@@ -13,6 +13,12 @@
 # limitations under the License.
 """Tests for the DLRM experiment."""

+import sys
+import os
+# Add the RecML folder to the system path
+sys.path.append(os.path.join(os.getcwd(), "../../../RecML"))
+os.environ["KERAS_BACKEND"] = "jax"
+
 from absl.testing import absltest
 import fiddle as fdl
 from fiddle import selectors
@@ -32,8 +38,8 @@ def test_dlrm_experiment(self):

     experiment = dlrm_experiment.experiment()

-    experiment.task.train_data.global_batch_size = 4
-    experiment.task.eval_data.global_batch_size = 4
+    experiment.task.train_data.global_batch_size = 128
+    experiment.task.eval_data.global_batch_size = 128
     experiment.trainer.train_steps = 12
     experiment.trainer.steps_per_loop = 4
     experiment.trainer.steps_per_eval = 4
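Note on the overrides above (not part of the diff): the experiment object is a fiddle configuration, so the test mutates plain attributes before anything is built. A minimal sketch of the same override pattern on a made-up dataclass config (DataConfig is hypothetical):

import dataclasses

import fiddle as fdl


@dataclasses.dataclass
class DataConfig:
  global_batch_size: int = 4


cfg = fdl.Config(DataConfig)
cfg.global_batch_size = 128  # same style of override as in the test above
data = fdl.build(cfg)
assert data.global_batch_size == 128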