diff --git a/recml/core/data/iterator.py b/recml/core/data/iterator.py index f86c922..e296f44 100644 --- a/recml/core/data/iterator.py +++ b/recml/core/data/iterator.py @@ -21,6 +21,7 @@ from etils import epath import numpy as np import tensorflow as tf +import jax Iterator = clu_data.DatasetIterator @@ -69,15 +70,17 @@ def _maybe_to_numpy( return x if hasattr(x, "_numpy"): numpy = x._numpy() # pylint: disable=protected-access - else: + elif hasattr(x, "numpy"): numpy = x.numpy() + else: + return x + if isinstance(numpy, np.ndarray): - # `numpy` shares the same underlying buffer as the `x` Tensor. # Tensors are expected to be immutable, so we disable writes. numpy.setflags(write=False) return numpy - return tf.nest.map_structure(_maybe_to_numpy, batch) + return jax.tree.map(_maybe_to_numpy, batch) @property def element_spec(self) -> clu_data.ElementSpec: @@ -109,7 +112,7 @@ def _to_element_spec( ) return clu_data.ArraySpec(dtype=x.dtype, shape=tuple(x.shape)) - element_spec = tf.nest.map_structure(_to_element_spec, batch) + element_spec = jax.tree.map(_to_element_spec, batch) self._element_spec = element_spec return element_spec diff --git a/recml/core/ops/embedding_ops.py b/recml/core/ops/embedding_ops.py index a1de4f0..f9e17bc 100644 --- a/recml/core/ops/embedding_ops.py +++ b/recml/core/ops/embedding_ops.py @@ -38,7 +38,7 @@ class SparsecoreParams: """Embedding parameters.""" feature_specs: Nested[FeatureSpec] - mesh: jax.sharding.Mesh | jax.sharding.AbstractMesh + mesh: jax.sharding.Mesh data_axes: Sequence[str | None] embedding_axes: Sequence[str | None] sharding_strategy: str diff --git a/recml/core/training/core.py b/recml/core/training/core.py index 51b3451..58fc83e 100644 --- a/recml/core/training/core.py +++ b/recml/core/training/core.py @@ -15,6 +15,7 @@ import abc from collections.abc import Mapping, Sequence +import contextlib import dataclasses import enum from typing import Any, Generic, TypeVar @@ -24,6 +25,13 @@ from recml.core.data import iterator import tensorflow as tf +# Patch jax.spmd_mode if it doesn't exist (removed in newer JAX versions). +if not hasattr(jax, "spmd_mode"): + @contextlib.contextmanager + def _spmd_mode(*args, **kwargs): + del args, kwargs + yield + jax.spmd_mode = _spmd_mode # pylint: disable=logging-fstring-interpolation diff --git a/recml/core/training/partitioning.py b/recml/core/training/partitioning.py index 4dc3b76..3ba5740 100644 --- a/recml/core/training/partitioning.py +++ b/recml/core/training/partitioning.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+ """Utilities for partitioning.""" import abc @@ -22,7 +23,6 @@ import jax import numpy as np - PyTree = Any State = Any CreateStateFn = Callable[[PyTree], State] @@ -67,7 +67,8 @@ class DataParallelPartitioner(Partitioner): """Data parallel partitioner.""" def __init__(self, data_axis: str = "batch"): - self.mesh = jax.make_mesh((jax.device_count(),), (data_axis,)) + devices = jax.devices() + self.mesh = jax.sharding.Mesh(devices, (data_axis,)) self.data_sharding = jax.sharding.NamedSharding( self.mesh, jax.sharding.PartitionSpec(data_axis) ) @@ -107,7 +108,7 @@ def _shard(x: np.ndarray) -> jax.Array: def partition_init( self, init_fn: CreateStateFn, *, abstract_batch: PyTree | None = None ) -> CreateStateFn: - with jax.sharding.use_mesh(self.mesh): + with self.mesh: if abstract_batch is not None: abstract_state = jax.eval_shape(init_fn, abstract_batch) specs = nn.get_partition_spec(abstract_state) @@ -117,7 +118,7 @@ def partition_init( init_fn = jax.jit(init_fn, out_shardings=self.state_sharding) def _wrapped_init(batch: PyTree) -> State: - with jax.sharding.use_mesh(self.mesh): + with self.mesh: state = init_fn(batch) state = _maybe_unbox_state(state) return state @@ -130,7 +131,7 @@ def partition_step(self, fn: StepFn, *, training: bool = False) -> StepFn: jit_kws["out_shardings"] = (self.state_sharding, None) jit_kws["donate_argnums"] = (1,) - with jax.sharding.use_mesh(self.mesh): + with self.mesh: step_fn = jax.jit( fn, in_shardings=(self.data_sharding, self.state_sharding), @@ -138,7 +139,7 @@ def partition_step(self, fn: StepFn, *, training: bool = False) -> StepFn: ) def _wrapped_step(batch: PyTree, state: State) -> Any: - with jax.sharding.use_mesh(self.mesh): + with self.mesh: return step_fn(batch, state) return _wrapped_step @@ -190,7 +191,7 @@ def __init__( if axis_sizes[0] == -1: axis_sizes[0] = len(devices) // math.prod(axis_sizes[1:]) - self.mesh = jax.make_mesh(axis_sizes, axis_names, devices=devices) + self.mesh = jax.sharding.Mesh(devices, axis_names) self.rules = rules self.aot_compile = aot_compile self.options = options @@ -213,12 +214,6 @@ def __init__( self.abstract_batch = None self.abstract_state = None - @property - def mesh_context_manager( - self, - ) -> Callable[[jax.sharding.Mesh], ContextManager[None]]: - return jax.sharding.use_mesh - def shard_inputs(self, inputs: PyTree) -> PyTree: def _shard(x: np.ndarray) -> jax.Array: return jax.make_array_from_process_local_data(self.data_sharding, x) @@ -234,7 +229,7 @@ def partition_init( " model parallel partitioner." 
) - with self.mesh_context_manager(self.mesh): + with self.mesh: abstract_state = jax.eval_shape(init_fn, abstract_batch) specs = nn.get_partition_spec(abstract_state) @@ -247,7 +242,7 @@ def partition_init( compiled_init_fn = jax.jit(init_fn, out_shardings=state_sharding) def _init(batch: PyTree) -> State: - with self.mesh_context_manager(self.mesh): + with self.mesh: state = compiled_init_fn(batch) state = _maybe_unbox_state(state) return state @@ -265,7 +260,8 @@ def partition_step(self, fn: StepFn, *, training: bool = False) -> StepFn: else: jit_kws["out_shardings"] = None - with self.mesh_context_manager(self.mesh): + + with self.mesh: step_fn = jax.jit( fn, in_shardings=(self.data_sharding, self.state_sharding), @@ -286,7 +282,7 @@ def partition_step(self, fn: StepFn, *, training: bool = False) -> StepFn: ) def _step(batch: PyTree, state: State) -> Any: - with self.mesh_context_manager(self.mesh): + with self.mesh: return step_fn(batch, state) return _step @@ -302,4 +298,4 @@ def _maybe_unbox(x: Any) -> Any: _maybe_unbox, x, is_leaf=lambda k: isinstance(k, nn.Partitioned), - ) + ) \ No newline at end of file diff --git a/recml/core/training/partitioning_test.py b/recml/core/training/partitioning_test.py index 5fa95c6..ca6901b 100644 --- a/recml/core/training/partitioning_test.py +++ b/recml/core/training/partitioning_test.py @@ -40,7 +40,8 @@ def test_data_parallelism( self, partitioner_cls: type[partitioning.Partitioner] ): if partitioner_cls is partitioning.ModelParallelPartitioner: - kwargs = {"axes": [("data", -1), ("model", 1)], "dp_axes": 1} + devs = np.array(jax.devices()).reshape(-1, 1) + kwargs = {"axes": [("data", -1), ("model", 1)], "dp_axes": 1, "devices": devs} else: kwargs = {} partitioner = partitioner_cls(**kwargs) @@ -112,8 +113,12 @@ def _eval_step( ) def test_model_parallelism(self): + devs = np.array(jax.devices()).reshape(1, -1) + partitioner = partitioning.ModelParallelPartitioner( - axes=[("data", 1), ("model", jax.device_count())], dp_axes=1 + axes=[("data", 1), ("model", jax.device_count())], + dp_axes=1, + devices=devs ) inputs = np.zeros((128, 16), dtype=np.float32) diff --git a/recml/examples/dlrm_experiment.py b/recml/examples/dlrm_experiment.py index 36da20f..eeda133 100644 --- a/recml/examples/dlrm_experiment.py +++ b/recml/examples/dlrm_experiment.py @@ -19,6 +19,12 @@ import dataclasses from typing import Generic, Literal, TypeVar +import sys +import os +# Add the RecML folder to the system path +sys.path.append(os.path.join(os.getcwd(), "../../../RecML")) +os.environ["KERAS_BACKEND"] = "jax" + from etils import epy import fiddle as fdl import flax.linen as nn diff --git a/recml/examples/dlrm_experiment_test.py b/recml/examples/dlrm_experiment_test.py index d4b44c0..07902a4 100644 --- a/recml/examples/dlrm_experiment_test.py +++ b/recml/examples/dlrm_experiment_test.py @@ -13,6 +13,12 @@ # limitations under the License. 
"""Tests for the DLRM experiment.""" +import sys +import os +# Add the RecML folder to the system path +sys.path.append(os.path.join(os.getcwd(), "../../../RecML")) +os.environ["KERAS_BACKEND"] = "jax" + from absl.testing import absltest import fiddle as fdl from fiddle import selectors @@ -32,8 +38,8 @@ def test_dlrm_experiment(self): experiment = dlrm_experiment.experiment() - experiment.task.train_data.global_batch_size = 4 - experiment.task.eval_data.global_batch_size = 4 + experiment.task.train_data.global_batch_size = 128 + experiment.task.eval_data.global_batch_size = 128 experiment.trainer.train_steps = 12 experiment.trainer.steps_per_loop = 4 experiment.trainer.steps_per_eval = 4 diff --git a/recml/examples/train_hstu_jax.py b/recml/examples/train_hstu_jax.py new file mode 100644 index 0000000..d77876b --- /dev/null +++ b/recml/examples/train_hstu_jax.py @@ -0,0 +1,263 @@ +"""HSTU Experiment Configuration using Fiddle and RecML with JaxTrainer""" + +import dataclasses +from typing import Mapping, Tuple +import sys +import os + +os.environ["KERAS_BACKEND"] = "jax" + +import fiddle as fdl +import jax +import jax.numpy as jnp +import keras +import optax +import tensorflow as tf +import clu.metrics as clu_metrics +from absl import app +from absl import flags +from absl import logging + +# Add the RecML folder to the system path +sys.path.append(os.path.join(os.getcwd(), "../../../RecML")) + +# RecML Imports +from recml.core.training import core +from recml.core.training import jax_trainer +from recml.core.training import partitioning +from recml.layers.keras import hstu +import recml + +# Define command-line flags +FLAGS = flags.FLAGS + +flags.DEFINE_string("train_path", None, "Path (or pattern) to training data") +flags.DEFINE_string("eval_path", None, "Path (or glob pattern) to evaluation data") + +flags.DEFINE_string("model_dir", "/tmp/hstu_model_jax", "Where to save the model") +flags.DEFINE_integer("vocab_size", 5_000_000, "Vocabulary size") +flags.DEFINE_integer("train_steps", 2000, "Total training steps") + +# Mark flags as required +flags.mark_flag_as_required("train_path") +flags.mark_flag_as_required("eval_path") + +@dataclasses.dataclass +class HSTUModelConfig: + """Configuration for the HSTU model architecture""" + vocab_size: int = 5_000_000 + max_sequence_length: int = 50 + model_dim: int = 64 + num_heads: int = 4 + num_layers: int = 4 + dropout: float = 0.5 + learning_rate: float = 1e-3 + +class TFRecordDataFactory(recml.Factory[tf.data.Dataset]): + """Reusable Data Factory for TFRecord datasets""" + + path: str + batch_size: int + max_sequence_length: int + feature_key: str = "sequence" + target_key: str = "target" + is_training: bool = True + + def make(self) -> tf.data.Dataset: + """Builds the tf.data.Dataset""" + if not self.path: + logging.warning("No path provided for dataset factory") + return tf.data.Dataset.empty() + + dataset = tf.data.Dataset.list_files(self.path) + dataset = tf.data.TFRecordDataset(dataset, num_parallel_reads=tf.data.AUTOTUNE) + + def _parse_fn(serialized_example): + features = { + self.feature_key: tf.io.VarLenFeature(tf.int64), + self.target_key: tf.io.FixedLenFeature([1], tf.int64), + } + parsed = tf.io.parse_single_example(serialized_example, features) + + seq = tf.sparse.to_dense(parsed[self.feature_key]) + padding_needed = self.max_sequence_length - tf.shape(seq)[0] + seq = tf.pad(seq, [[0, padding_needed]]) + seq = tf.ensure_shape(seq, [self.max_sequence_length]) + seq = tf.cast(seq, tf.int32) + + target = 
tf.cast(parsed[self.target_key], tf.int32) + return seq, target + + dataset = dataset.map(_parse_fn, num_parallel_calls=tf.data.AUTOTUNE) + + if self.is_training: + dataset = dataset.repeat() + + return dataset.batch(self.batch_size, drop_remainder=True).prefetch(tf.data.AUTOTUNE) + +class HSTUTask(jax_trainer.JaxTask): + """JaxTask for HSTU model""" + + def __init__( + self, + model_config: HSTUModelConfig, + train_data_factory: recml.Factory[tf.data.Dataset], + eval_data_factory: recml.Factory[tf.data.Dataset], + ): + self.config = model_config + self.train_data_factory = train_data_factory + self.eval_data_factory = eval_data_factory + + def create_datasets(self) -> Tuple[tf.data.Dataset, tf.data.Dataset]: + return self.train_data_factory.make(), self.eval_data_factory.make() + + def _create_model(self) -> keras.Model: + inputs = keras.Input( + shape=(self.config.max_sequence_length,), dtype="int32", name="input_ids" + ) + padding_mask = keras.ops.cast(keras.ops.not_equal(inputs, 0), "int32") + + hstu_layer = hstu.HSTU( + vocab_size=self.config.vocab_size, + max_positions=self.config.max_sequence_length, + model_dim=self.config.model_dim, + num_heads=self.config.num_heads, + num_layers=self.config.num_layers, + dropout=self.config.dropout, + ) + + logits = hstu_layer(inputs, padding_mask=padding_mask) + + def get_last_token_logits(args): + seq_logits, mask = args + lengths = keras.ops.sum(keras.ops.cast(mask, "int32"), axis=1) + last_indices = lengths - 1 + indices = keras.ops.expand_dims(keras.ops.expand_dims(last_indices, -1), -1) + return keras.ops.squeeze(keras.ops.take_along_axis(seq_logits, indices, axis=1), axis=1) + + output_logits = keras.layers.Lambda(get_last_token_logits)([logits, padding_mask]) + output_logits = keras.layers.Activation("linear", dtype="float32")(output_logits) + + model = keras.Model(inputs=inputs, outputs=output_logits) + return model + + def create_state(self, batch, rng) -> jax_trainer.KerasState: + inputs, _ = batch + model = self._create_model() + # Build the model to initialize variables + model.build(inputs.shape) + + optimizer = optax.adam(learning_rate=self.config.learning_rate) + return jax_trainer.KerasState.create(model=model, tx=optimizer) + + def train_step( + self, batch, state: jax_trainer.KerasState, rng: jax.Array + ) -> Tuple[jax_trainer.KerasState, Mapping[str, clu_metrics.Metric]]: + inputs, targets = batch + + def loss_fn(tvars): + logits, _ = state.model.stateless_call(tvars, state.ntvars, inputs) + loss = optax.softmax_cross_entropy_with_integer_labels( + logits, jnp.squeeze(targets) + ) + return jnp.mean(loss), logits + + grad_fn = jax.value_and_grad(loss_fn, has_aux=True) + (loss, logits), grads = grad_fn(state.tvars) + state = state.update(grads=grads) + + metrics = self._compute_metrics(loss, logits, targets) + return state, metrics + + def eval_step( + self, batch, state: jax_trainer.KerasState + ) -> Mapping[str, clu_metrics.Metric]: + inputs, targets = batch + logits, _ = state.model.stateless_call(state.tvars, state.ntvars, inputs) + loss = optax.softmax_cross_entropy_with_integer_labels( + logits, jnp.squeeze(targets) + ) + loss = jnp.mean(loss) + return self._compute_metrics(loss, logits, targets) + + def _compute_metrics(self, loss, logits, targets): + targets = jnp.squeeze(targets) + metrics = {"loss": clu_metrics.Average.from_model_output(loss)} + + return metrics + +def experiment() -> fdl.Config[recml.Experiment]: + """Defines the experiment structure using Fiddle configs""" + + max_seq_len = 50 + batch_size = 64 + + 
model_cfg = fdl.Config( + HSTUModelConfig, + vocab_size=5_000_000, + max_sequence_length=max_seq_len, + model_dim=64, + num_layers=4, + dropout=0.5 + ) + + train_data = fdl.Config( + TFRecordDataFactory, + path="", # Placeholder + batch_size=batch_size, + max_sequence_length=max_seq_len, + is_training=True + ) + + eval_data = fdl.Config( + TFRecordDataFactory, + path="", # Placeholder + batch_size=batch_size, + max_sequence_length=max_seq_len, + is_training=False + ) + + task = fdl.Config( + HSTUTask, + model_config=model_cfg, + train_data_factory=train_data, + eval_data_factory=eval_data + ) + + trainer = fdl.Config( + jax_trainer.JaxTrainer, + partitioner=fdl.Config(partitioning.DataParallelPartitioner), + model_dir="/tmp/default_dir", # Placeholder + train_steps=2000, + steps_per_eval=10, + steps_per_loop=10, + ) + + return fdl.Config(recml.Experiment, task=task, trainer=trainer) + +def main(_): + # Ensure JAX uses the correct backend + logging.info(f"JAX Backend: {jax.default_backend()}") + + config = experiment() + + logging.info(f"Setting Train Path to: {FLAGS.train_path}") + config.task.train_data_factory.path = FLAGS.train_path + + logging.info(f"Setting Eval Path to: {FLAGS.eval_path}") + config.task.eval_data_factory.path = FLAGS.eval_path + + config.task.model_config.vocab_size = FLAGS.vocab_size + + logging.info(f"Setting Model Dir to: {FLAGS.model_dir}") + config.trainer.model_dir = FLAGS.model_dir + config.trainer.train_steps = FLAGS.train_steps + + expt = fdl.build(config) + + logging.info("Starting experiment execution...") + core.run_experiment(expt, core.Experiment.Mode.TRAIN_AND_EVAL) + + +if __name__ == "__main__": + app.run(main) \ No newline at end of file diff --git a/recml/examples/train_hstu_keras.py b/recml/examples/train_hstu_keras.py new file mode 100644 index 0000000..4685908 --- /dev/null +++ b/recml/examples/train_hstu_keras.py @@ -0,0 +1,222 @@ +"""HSTU Experiment Configuration using Fiddle and RecML with KerasTrainer""" + +import dataclasses +from typing import Optional +import sys +import os + +import fiddle as fdl +import keras +import tensorflow as tf +from absl import app +from absl import flags +from absl import logging + +# Add the RecML folder to the system path +sys.path.append(os.path.join(os.getcwd(), "../../../RecML")) + +from recml.core.training import core +from recml.core.training import keras_trainer +from recml.layers.keras import hstu +import recml +import jax +print(jax.devices()) + +# Define command-line flags +FLAGS = flags.FLAGS + +flags.DEFINE_string("train_path", None, "Path (or pattern) to training data") +flags.DEFINE_string("eval_path", None, "Path (or glob pattern) to evaluation data") + +flags.DEFINE_string("model_dir", "/tmp/hstu_model", "Where to save the model") +flags.DEFINE_integer("vocab_size", 5_000_000, "Vocabulary size") +flags.DEFINE_integer("train_steps", 2000, "Total training steps") + +# Mark flags as required +flags.mark_flag_as_required("train_path") +flags.mark_flag_as_required("eval_path") + +@dataclasses.dataclass +class HSTUModelConfig: + """Configuration for the HSTU model architecture""" + vocab_size: int = 5_000_000 + max_sequence_length: int = 50 + model_dim: int = 64 + num_heads: int = 4 + num_layers: int = 4 + dropout: float = 0.5 + learning_rate: float = 1e-3 + +class TFRecordDataFactory(recml.Factory[tf.data.Dataset]): + """Reusable Data Factory for TFRecord datasets""" + + path: str + batch_size: int + max_sequence_length: int + feature_key: str = "sequence" + target_key: str = "target" + is_training: 
bool = True + + def make(self) -> tf.data.Dataset: + """Builds the tf.data.Dataset""" + if not self.path: + logging.warning("No path provided for dataset factory") + return tf.data.Dataset.empty() + + dataset = tf.data.Dataset.list_files(self.path) + dataset = tf.data.TFRecordDataset(dataset, num_parallel_reads=tf.data.AUTOTUNE) + + def _parse_fn(serialized_example): + features = { + self.feature_key: tf.io.VarLenFeature(tf.int64), + self.target_key: tf.io.FixedLenFeature([1], tf.int64), + } + parsed = tf.io.parse_single_example(serialized_example, features) + + seq = tf.sparse.to_dense(parsed[self.feature_key]) + padding_needed = self.max_sequence_length - tf.shape(seq)[0] + seq = tf.pad(seq, [[0, padding_needed]]) + seq = tf.ensure_shape(seq, [self.max_sequence_length]) + seq = tf.cast(seq, tf.int32) + + target = tf.cast(parsed[self.target_key], tf.int32) + return seq, target + + dataset = dataset.map(_parse_fn, num_parallel_calls=tf.data.AUTOTUNE) + + if self.is_training: + dataset = dataset.repeat() + + return dataset.batch(self.batch_size, drop_remainder=True).prefetch(tf.data.AUTOTUNE) + +class HSTUTask(keras_trainer.KerasTask): + """KerasTask that receives its dependencies via injection""" + + def __init__( + self, + model_config: HSTUModelConfig, + train_data_factory: recml.Factory[tf.data.Dataset], + eval_data_factory: recml.Factory[tf.data.Dataset], + ): + self.config = model_config + self.train_data_factory = train_data_factory + self.eval_data_factory = eval_data_factory + + def create_dataset(self, training: bool) -> tf.data.Dataset: + if training: + return self.train_data_factory.make() + return self.eval_data_factory.make() + + def create_model(self) -> keras.Model: + inputs = keras.Input( + shape=(self.config.max_sequence_length,), dtype="int32", name="input_ids" + ) + padding_mask = keras.ops.cast(keras.ops.not_equal(inputs, 0), "int32") + + hstu_layer = hstu.HSTU( + vocab_size=self.config.vocab_size, + max_positions=self.config.max_sequence_length, + model_dim=self.config.model_dim, + num_heads=self.config.num_heads, + num_layers=self.config.num_layers, + dropout=self.config.dropout, + ) + + logits = hstu_layer(inputs, padding_mask=padding_mask) + + def get_last_token_logits(args): + seq_logits, mask = args + lengths = tf.reduce_sum(tf.cast(mask, tf.int32), axis=1) + last_indices = lengths - 1 + return tf.gather(seq_logits, last_indices, batch_dims=1) + + output_logits = keras.layers.Lambda(get_last_token_logits)([logits, padding_mask]) + output_logits = keras.layers.Activation("linear", dtype="float32")(output_logits) + + model = keras.Model(inputs=inputs, outputs=output_logits) + + model.compile( + optimizer=keras.optimizers.Adam(learning_rate=self.config.learning_rate), + loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), + metrics=[ + keras.metrics.SparseTopKCategoricalAccuracy(k=10, name="HR_10"), + keras.metrics.SparseTopKCategoricalAccuracy(k=50, name="HR_50"), + keras.metrics.SparseTopKCategoricalAccuracy(k=200, name="HR_200"), + ], + ) + return model + +def experiment() -> fdl.Config[recml.Experiment]: + """Defines the experiment structure using Fiddle configs""" + + max_seq_len = 50 + batch_size = 128 + + model_cfg = fdl.Config( + HSTUModelConfig, + vocab_size=5_000_000, + max_sequence_length=max_seq_len, + model_dim=64, + num_layers=4, + dropout=0.5 + ) + + train_data = fdl.Config( + TFRecordDataFactory, + path="", # Placeholder + batch_size=batch_size, + max_sequence_length=max_seq_len, + is_training=True + ) + + eval_data = fdl.Config( + 
TFRecordDataFactory, + path="", # Placeholder + batch_size=batch_size, + max_sequence_length=max_seq_len, + is_training=False + ) + + task = fdl.Config( + HSTUTask, + model_config=model_cfg, + train_data_factory=train_data, + eval_data_factory=eval_data + ) + + trainer = fdl.Config( + keras_trainer.KerasTrainer, + model_dir="/tmp/default_dir", # Placeholder + train_steps=2000, + steps_per_eval=10, + steps_per_loop=10, + ) + + return fdl.Config(recml.Experiment, task=task, trainer=trainer) + +def main(_): + keras.mixed_precision.set_global_policy("mixed_bfloat16") + logging.info("Mixed precision policy set to mixed_bfloat16") + + config = experiment() + + logging.info(f"Setting Train Path to: {FLAGS.train_path}") + config.task.train_data_factory.path = FLAGS.train_path + + logging.info(f"Setting Eval Path to: {FLAGS.eval_path}") + config.task.eval_data_factory.path = FLAGS.eval_path + + config.task.model_config.vocab_size = FLAGS.vocab_size + + logging.info(f"Setting Model Dir to: {FLAGS.model_dir}") + config.trainer.model_dir = FLAGS.model_dir + config.trainer.train_steps = FLAGS.train_steps + + expt = fdl.build(config) + + logging.info("Starting experiment execution...") + core.run_experiment(expt, core.Experiment.Mode.TRAIN_AND_EVAL) + + +if __name__ == "__main__": + app.run(main) \ No newline at end of file diff --git a/recml/layers/linen/sparsecore.py b/recml/layers/linen/sparsecore.py index a908ab8..6496b07 100644 --- a/recml/layers/linen/sparsecore.py +++ b/recml/layers/linen/sparsecore.py @@ -28,10 +28,9 @@ from recml.core.ops import embedding_ops import tensorflow as tf - with epy.lazy_imports(): # pylint: disable=g-import-not-at-top - from jax_tpu_embedding.sparsecore.lib.flax import embed + from jax_tpu_embedding.sparsecore.lib.flax.linen import embed from jax_tpu_embedding.sparsecore.lib.nn import embedding from jax_tpu_embedding.sparsecore.lib.nn import embedding_spec from jax_tpu_embedding.sparsecore.lib.nn import table_stacking @@ -175,7 +174,6 @@ def __call__(self, inputs: Mapping[str, jax.Array]) -> jax.Array: dataclasses.field(default=lambda n, bs: bs) ) - # Optional device information. local_device_count: int = dataclasses.field( default_factory=jax.local_device_count ) @@ -367,18 +365,9 @@ class SparsecoreEmbed(nn.Module): """ sparsecore_config: SparsecoreConfig - mesh: jax.sharding.Mesh | jax.sharding.AbstractMesh | None = None - - def get_mesh(self) -> jax.sharding.Mesh | jax.sharding.AbstractMesh: - if self.mesh is not None: - return self.mesh - abstract_mesh = jax.sharding.get_abstract_mesh() - if not abstract_mesh.shape_tuple: - raise ValueError( - 'No abstract mesh shape was set with `jax.sharding.use_mesh`. Make' - ' sure to set the mesh when calling the sparsecore module.' 
- ) - return abstract_mesh + mesh: jax.sharding.Mesh = dataclasses.field( + default_factory=lambda: jax.sharding.Mesh(jax.devices(), ('batch',)) + ) def get_sharding_axis( self, mesh: jax.sharding.Mesh | jax.sharding.AbstractMesh @@ -388,25 +377,25 @@ def get_sharding_axis( return self.sparsecore_config.sharding_axis def setup(self): - mesh = self.get_mesh() - sharding_axis_name = self.get_sharding_axis(mesh) + sharding_axis_name = self.get_sharding_axis(self.mesh) + initializer = functools.partial( embedding.init_embedding_variables, table_specs=embedding.get_table_specs( self.sparsecore_config.feature_specs ), global_sharding=jax.sharding.NamedSharding( - mesh, jax.sharding.PartitionSpec(sharding_axis_name, None) + self.mesh, jax.sharding.PartitionSpec(sharding_axis_name, None) ), num_sparsecore_per_device=self.sparsecore_config.num_sc_per_device, # We need to by-pass the mesh check to allow using an abstract mesh. - bypass_mesh_check=isinstance(mesh, jax.sharding.AbstractMesh), + bypass_mesh_check=isinstance(self.mesh, jax.sharding.AbstractMesh), ) self.embedding_table = self.param( name=EMBEDDING_PARAM_NAME, init_fn=embed.with_sparsecore_layout( - initializer, (sharding_axis_name,), mesh # type: ignore + initializer, (sharding_axis_name,), self.mesh # type: ignore ), ) @@ -423,12 +412,13 @@ def __call__( Returns: The activations structure with the same structure as specs. """ - mesh = self.get_mesh() - sharding_axis_name = self.get_sharding_axis(mesh) + # mesh = self.get_mesh() + sharding_axis_name = self.get_sharding_axis(self.mesh) + activations = embedding_ops.sparsecore_lookup( embedding_ops.SparsecoreParams( feature_specs=self.sparsecore_config.feature_specs, - mesh=mesh, + mesh=self.mesh, data_axes=(sharding_axis_name,), embedding_axes=(sharding_axis_name, None), sharding_strategy=self.sparsecore_config.sharding_strategy, diff --git a/requirements.txt b/requirements.txt index 580d6c9..998ee15 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,6 @@ absl-py==2.2.2 +aiofiles==25.1.0 +array-record==0.8.3 astroid==3.3.9 astunparse==1.6.3 attrs==25.3.0 @@ -16,7 +18,7 @@ etils==1.12.2 fiddle==0.3.0 filelock==3.18.0 flatbuffers==25.2.10 -flax==0.10.5 +flax==0.12.2 fsspec==2025.3.2 gast==0.6.0 google-pasta==0.2.0 @@ -31,18 +33,22 @@ immutabledict==4.2.1 importlib-resources==6.5.2 iniconfig==2.1.0 isort==6.0.1 -jax==0.6.0 -jaxlib==0.6.0 +jax==0.8.2 +jax-tpu-embedding==0.1.0.dev20251208 +jaxlib==0.8.2 jaxtyping==0.3.1 -jinja2==3.1.6 +Jinja2==3.1.6 kagglehub==0.3.11 keras==3.9.2 keras-hub==0.20.0 libclang==18.1.1 libcst==1.7.0 -markdown==3.8 +libtpu==0.0.32 +# libtpu-nightly is usually installed directly via URL, but pinning it helps tracking +# libtpu-nightly==0.1.dev20240617+default +Markdown==3.8 markdown-it-py==3.0.0 -markupsafe==3.0.2 +MarkupSafe==3.0.2 mccabe==0.7.0 mdurl==0.1.2 ml-collections==1.1.0 @@ -54,23 +60,37 @@ nest-asyncio==1.6.0 networkx==3.4.2 nodeenv==1.9.1 numpy==2.1.3 +nvidia-cublas-cu12==12.4.5.8 +nvidia-cuda-cupti-cu12==12.4.127 +nvidia-cuda-nvrtc-cu12==12.4.127 +nvidia-cuda-runtime-cu12==12.4.127 +nvidia-cudnn-cu12==9.1.0.70 +nvidia-cufft-cu12==11.2.1.3 +nvidia-curand-cu12==10.3.5.147 +nvidia-cusolver-cu12==11.6.1.9 +nvidia-cusparse-cu12==12.3.1.170 +nvidia-cusparselt-cu12==0.6.2 +nvidia-nccl-cu12==2.21.5 +nvidia-nvjitlink-cu12==12.4.127 +nvidia-nvtx-cu12==12.4.127 opt-einsum==3.4.0 optax==0.2.4 optree==0.15.0 -orbax-checkpoint==0.11.12 +orbax-checkpoint==0.11.31 packaging==24.2 platformdirs==4.3.7 pluggy==1.5.0 +portpicker==1.6.0 pre-commit==4.2.0 
 promise==2.3
-protobuf==5.29.4
+# protobuf==6.33.4
 psutil==7.0.0
 pyarrow==19.0.1
-pygments==2.19.1
+Pygments==2.19.1
 pylint==3.3.6
 pytest==8.3.5
 pytest-env==1.1.5
-pyyaml==6.0.2
+PyYAML==6.0.2
 regex==2024.11.6
 requests==2.32.3
 rich==14.0.0
@@ -84,9 +104,10 @@ tensorboard==2.19.0
 tensorboard-data-server==0.7.2
 tensorflow==2.19.0
 tensorflow-datasets==4.9.8
+tensorflow-io-gcs-filesystem==0.37.1
 tensorflow-metadata==1.17.1
 tensorflow-text==2.19.0
-tensorstore==0.1.73
+tensorstore==0.1.80
 termcolor==3.0.1
 toml==0.10.2
 tomlkit==0.13.2
@@ -94,11 +115,12 @@ toolz==1.0.0
 torch==2.6.0
 tqdm==4.67.1
 treescope==0.1.9
-typing-extensions==4.13.2
+triton==3.2.0
+typing_extensions==4.13.2
 urllib3==2.4.0
 virtualenv==20.30.0
 wadler-lindig==0.1.5
-werkzeug==3.1.3
+Werkzeug==3.1.3
 wheel==0.45.1
 wrapt==1.17.2
-zipp==3.21.0
+zipp==3.21.0
\ No newline at end of file
diff --git a/training.md b/training.md
new file mode 100644
index 0000000..cb82670
--- /dev/null
+++ b/training.md
@@ -0,0 +1,81 @@
+# Model Training Guide
+
+This guide explains how to set up the environment and train the HSTU/DLRM models on Cloud TPU v6.
+
+## Option 1: Virtual Environment (Recommended for Dev)
+
+If you are developing on a TPU VM directly, use a virtual environment to avoid conflicts with the system-level Python packages.
+
+### 1. Prerequisites
+Ensure you have **Python 3.11+** installed.
+```bash
+python3 --version
+```
+
+### 2. Create and Activate Virtual Environment
+Run the following from the root of the repository:
+```bash
+# Create the venv
+python3 -m venv venv
+
+# Activate it
+source venv/bin/activate
+```
+
+### 3. Install Dependencies
+```bash
+pip install -r requirements.txt
+```
+We need to force a newer version of Protobuf (6.31.1 or later) to ensure compatibility with the TPU stack. Run this exactly as shown:
+```bash
+pip install "protobuf>=6.31.1" --no-deps
+```
+The `--no-deps` flag is required to prevent pip from downgrading it to satisfy the strict version pins of other libraries.
+
+### 4. Run the DLRM Training
+```bash
+python recml/examples/dlrm_experiment_test.py
+```
+
+## Option 2: Docker (Recommended for Production)
+
+If you prefer not to manage a virtual environment or want to deploy this as a container, you can build a Docker image.
+
+### 1. Build the Image
+Create a file named `Dockerfile` in the root of the repository:
+
+```dockerfile
+# Use an official Python 3.11 runtime as a parent image
+FROM python:3.11-slim
+
+# Set the working directory
+WORKDIR /app
+
+# Copy the current directory contents into the container
+COPY . /app
+
+# Install system tools if needed (e.g., git)
+RUN apt-get update && apt-get install -y git && rm -rf /var/lib/apt/lists/*
+
+# Install dependencies
+RUN pip install --upgrade pip
+RUN pip install -r requirements.txt
+
+# Force install the specific protobuf version
+RUN pip install "protobuf>=6.31.1" --no-deps
+
+# Default command to run the training script
+CMD ["python", "recml/examples/dlrm_experiment_test.py"]
+```
+
+You can use this Dockerfile to run the DLRM experiment from this repository in your own environment.
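+
+As a quick sanity check, the image can be built and run from the repository root; the `recml-dlrm` tag below is only an illustrative name:
+
+```bash
+# Build the image from the repository root (where the Dockerfile lives)
+docker build -t recml-dlrm .
+
+# Run the default command (the DLRM experiment test) in a throwaway container
+docker run --rm recml-dlrm
+```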