[tx] Add initial implementation of RayJaxBackend #1418

andrewsykim wants to merge 17 commits into NovaSky-AI:main from
Conversation
Force-pushed from f09742e to 63c6ae2
@pcmoritz looking for your high-level feedback on the approach before I test and polish the PR further.
```python
    )
    return SkyRLTrainBackend, MegatronBackendOverrides
elif backend_name == "ray-jax":
```
From a user perspective, I feel like it would be most natural to not have a separate backend, since it is just the jax backend but running on top of Ray. You probably introduced a new backend to be able to nicely handle the dependencies, right?
Another way to do this would be to introduce a ray extra, so the user can run

`uv run --extra ray --extra tinker -m skyrl.tinker.api --backend jax`

and the Ray backend will be chosen if a use_ray flag is set in the backend config.
I would have a slight preference for this, let me know what you think :)
The code organization of making a ray_jax file makes a lot of sense to me.
makes sense, I added a `use_ray` flag to the backend config and removed the separate ray-jax backend
```python
if process_id == 0:
    self.node_ip = ray.util.get_node_ip_address()
    self.port = 7777
```
I like the fact that with the Ray backend the coordinator_address doesn't need to be configured. It would be better not to hard-code the port, though. I can see why it is tricky to auto-discover the port (JAX doesn't seem to support 0 as the port). If we must hard-code it, it might be worth introducing an environment variable to override it (e.g. SKYRL_JAX_COORDINATOR_PORT).
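If the hard-coded default stays, the override could be as simple as the following sketch (the helper name and the default constant are assumptions based on the suggestion above, not the PR's actual code):

```python
import os

# Default matches the hard-coded port in the snippet above; the
# SKYRL_JAX_COORDINATOR_PORT name is the one suggested in this comment.
DEFAULT_COORDINATOR_PORT = 7777

def get_coordinator_port() -> int:
    """Return the JAX coordinator port, allowing an env-var override."""
    return int(os.environ.get("SKYRL_JAX_COORDINATOR_PORT", DEFAULT_COORDINATOR_PORT))
```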
Added a `_get_random_port` function to assign a random port; we can add the override if the random port doesn't work (though I can't see why it wouldn't)
```python
self.base_model = base_model
self.config = config.model_copy()

num_processes = self.config.num_processes or 1
```
I think it makes most sense to remove the default 1 here and make num_processes required if the Ray backend is used (since it will be multi-node / multi process in most cases).
updated to raise an exception if `num_processes` is not set
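The resulting check presumably looks something like this sketch (function and parameter names assumed from the snippet above):

```python
def resolve_num_processes(num_processes, use_ray: bool) -> int:
    # With the Ray backend, runs are multi-process / multi-node in most
    # cases, so refuse to silently default to a single process.
    if use_ray and num_processes is None:
        raise ValueError("num_processes must be set when use_ray is enabled")
    return num_processes or 1
```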
Signed-off-by: Andrew Sy Kim <andrewsy@google.com>
…er API and engine now run on driver Signed-off-by: Andrew Sy Kim <andrewsy@google.com>
…point script Signed-off-by: Andrew Sy Kim <andrewsy@google.com>
Code Review
This pull request introduces a Ray-based backend for JAX to support multi-node execution, including a new RayJaxBackend implementation and configuration options for Ray actors and placement groups. The feedback focuses on improving the robustness and performance of this new backend. Specifically, it is recommended to explicitly support the ray_jax backend name in the engine logic, validate that the number of placement group bundles matches the process count to prevent index errors, and optimize data transfer by using ray.put() to broadcast batches to workers.
```python
if backend_name == "jax":
    from skyrl.backends.jax import JaxBackend, JaxBackendConfig
    if use_ray:
```
The backend name ray_jax (or ray-jax) should be explicitly supported in the if condition to match the updated error message and the usage shown in the PR description. Currently, the logic only triggers the Ray backend if the name is exactly "jax" and the use_ray flag is set in the configuration, which contradicts the suggested command-line usage.
```diff
-if backend_name == "jax":
+if backend_name in ["jax", "ray_jax", "ray-jax"]:
     from skyrl.backends.jax import JaxBackend, JaxBackendConfig
-    if use_ray:
+    if use_ray or backend_name != "jax":
```
updated the error message
```python
bundles = self.config.ray_pg_bundles
if not bundles:
    bundles = [{"CPU": 1}] * num_processes
```
When ray_pg_bundles is provided in the configuration, its length should be validated to ensure it matches num_processes. If the user provides fewer bundles than processes, the actor initialization loop will fail with an index error when trying to access bundles[i] during placement group bundle assignment.
```diff
 bundles = self.config.ray_pg_bundles
+if bundles and len(bundles) != num_processes:
+    raise ValueError(
+        f"Number of bundles in ray_pg_bundles ({len(bundles)}) "
+        f"must match num_processes ({num_processes})"
+    )
 if not bundles:
     bundles = [{"CPU": 1}] * num_processes
```
```python
def save_sampler_checkpoint(self, output_path: AnyPath, model_id: str, persist: bool = True) -> None:
    ray.get([w.save_sampler_checkpoint.remote(output_path, model_id, persist) for w in self.workers])
```
🔴 Missing probe file write in RayJaxBackend.save_sampler_checkpoint causes concurrent writes on shared filesystems
RayJaxBackend.save_sampler_checkpoint dispatches save_sampler_checkpoint to all Ray actors without first writing a probe file. The existing JaxBackend.save_sampler_checkpoint (skyrl/backends/jax.py:1146-1150) writes a .probe file so that non-coordinator workers can detect a shared filesystem and skip redundant writes (see skyrl/utils/storage.py:29). Without this probe, all Ray actors will concurrently write to the same output_path via pack_and_upload, which performs non-atomic file I/O (skyrl/utils/storage.py:35-39), leading to file corruption on shared filesystems.
Prompt for agents
In RayJaxBackend.save_sampler_checkpoint (skyrl/backends/ray_jax.py:191-192), a probe file is not written before dispatching the save to all Ray actors. The existing JaxBackend (skyrl/backends/jax.py:1146-1150) writes a probe file at output_path.with_name(output_path.name + ".probe") so workers can detect shared filesystems and skip redundant writes (see skyrl/utils/storage.py:28-31 and the pack_and_upload context manager).
The fix should mirror what JaxBackend.save_sampler_checkpoint does: before dispatching to workers, the driver should create the parent directory and write the probe file. Something like:
```python
output_path.parent.mkdir(parents=True, exist_ok=True)
output_path.with_name(output_path.name + ".probe").write_text("write_probe")
```
Note that the RayJaxBackend driver may not have access to the same filesystem as the actors (since it's the Ray driver process), so the correct approach may need to be adapted for the Ray execution model. Consider having the rank-0 actor write the probe before the other actors proceed.
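As a sketch of the probe write the rank-0 actor could perform before the other actors save (mirroring the two lines quoted above; exposing it as a standalone `write_probe` helper is an assumption):

```python
from pathlib import Path

def write_probe(output_path: Path) -> None:
    # Mirror of the JaxBackend probe write: ensure the checkpoint's parent
    # directory exists, then drop a .probe marker next to it so workers on
    # a shared filesystem can detect the probe and skip redundant writes.
    output_path.parent.mkdir(parents=True, exist_ok=True)
    output_path.with_name(output_path.name + ".probe").write_text("write_probe")
```

In the Ray execution model, the driver would `ray.get` this call on the rank-0 actor first, then dispatch `save_sampler_checkpoint` to all actors.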
I was able to test this on TPUs using the following commands:

To start the tinker API server, engine and JAX workers:

To run RL training:
Fixes #1393

This PR introduces the initial implementation of `RayJaxBackend` and `RayJaxBackendImpl` to enable running skyrl-tx on a Ray cluster with a single `ray job submit` command. When enabled, the Tinker API and Engine run on the driver/head node, while the actual JAX backend operations are distributed across Ray actors running on worker nodes. This removes the need for manual multi-node orchestration for JAX distributed training.

Tested the changes by running the following command on my 4x4 v6e TPU cluster:
Create a train script (copied from skyrl-tx examples)
Run this training script on the driver: