
[tx] Add initial implementation of RayJaxBackend#1418

Open
andrewsykim wants to merge 17 commits into NovaSky-AI:main from andrewsykim:ray-tx

Conversation

@andrewsykim (Contributor) commented Mar 31, 2026

Fixes #1393

This PR introduces the initial implementation of RayJaxBackend and RayJaxBackendImpl to enable running skyrl-tx on a Ray cluster with a single ray job submit command. When enabled, the Tinker API and Engine run on the driver/head node, while the actual JAX backend operations are distributed across Ray actors running on worker nodes. This removes the need for manual multi-node orchestration for JAX distributed training.

Tested the changes by running the following commands on my 4x4 v6e TPU cluster:

ray job submit --runtime-env-json '{"py_executable": "uv run"}' -- sh -c 'cd /home/ray/SkyRL && uv run --extra tpu --extra jax --extra tinker -m skyrl.tinker.api --base-model Qwen/Qwen3-0.6B --backend ray-jax  --backend-config '\''{"ray_pg_bundles": [{"CPU": 4, "TPU": 4},{"CPU": 4, "TPU": 4},{"CPU": 4, "TPU": 4},{"CPU": 4, "TPU": 4}], "tensor_parallel_size": 4, "sample_max_num_sequences": 256, "train_micro_batch_size": 8, "fully_sharded_data_parallel_size": 4, "num_processes": 4}'\'''

ray job submit -- sh -c $'cd /home/ray/SkyRL && uv run --extra tpu --extra jax --extra tinker -m skyrl.tinker.api --base-model Qwen/Qwen3-8B --backend ray-jax  --backend-config \'{"ray_actor_options": {"resources": {"TPU": 4}}, "ray_pg_bundles": [{"CPU": 4, "TPU": 4},{"CPU": 4, "TPU": 4},{"CPU": 4, "TPU": 4},{"CPU": 4, "TPU": 4}], "sample_max_num_sequences": 256, "train_micro_batch_size": 32, "tensor_parallel_size": 4, "fully_sharded_data_parallel_size": 4, "num_processes": 4}\''

Create a training script (copied from the skyrl-tx examples):

import tinker
import numpy as np
from tinker import types

# Connect to the local server
service_client = tinker.ServiceClient(base_url="http://localhost:8000", api_key="tml-dummy")
training_client = service_client.create_lora_training_client(base_model="Qwen/Qwen3-8B")
tokenizer = training_client.get_tokenizer()

# Training examples
examples = [
    {"input": "banana split", "output": "anana-bay plit-say"},
    {"input": "quantum physics", "output": "uantum-qay ysics-phay"},
    {"input": "coding wizard", "output": "oding-cay izard-way"},
]

def process_example(example, tokenizer):
    prompt = f"English: {example['input']}\nPig Latin:"
    prompt_tokens = tokenizer.encode(prompt, add_special_tokens=True)
    completion_tokens = tokenizer.encode(f" {example['output']}\n\n", add_special_tokens=False)

    tokens = prompt_tokens + completion_tokens
    weights = [0] * len(prompt_tokens) + [1] * len(completion_tokens)

    return types.Datum(
        model_input=types.ModelInput.from_ints(tokens=tokens[:-1]),
        loss_fn_inputs=dict(weights=weights[1:], target_tokens=tokens[1:])
    )

processed = [process_example(ex, tokenizer) for ex in examples]

# Training loop
for _ in range(6):
    fwdbwd = training_client.forward_backward(processed, "cross_entropy").result()
    training_client.optim_step(types.AdamParams(learning_rate=1e-4)).result()

    logprobs = np.concatenate([o['logprobs'].tolist() for o in fwdbwd.loss_fn_outputs])
    weights = np.concatenate([e.loss_fn_inputs['weights'].tolist() for e in processed])
    print(f"Loss: {-np.dot(logprobs, weights) / weights.sum():.4f}")

Run this training script on the driver:

ray job submit --working-dir . -- sh -c 'cp rl_loop.py /home/ray/SkyRL/ && cd /home/ray/SkyRL && uv run rl_loop.py'

@andrewsykim (Contributor, Author) commented:

@pcmoritz looking for your high-level feedback on the approach before I test and polish the PR further.

Comment thread skyrl/tinker/engine.py Outdated
)

return SkyRLTrainBackend, MegatronBackendOverrides
elif backend_name == "ray-jax":
@pcmoritz (Collaborator) commented Apr 15, 2026

From a user perspective, I feel like it would be most natural to not have a separate backend, since it is just the jax backend but running on top of Ray. You probably introduced a new backend to be able to nicely handle the dependencies, right?

Another way to do this would be to introduce a ray extra, so the user can run

uv run --extra ray --extra tinker -m skyrl.tinker.api --backend jax

And the ray backend will be chosen if a use_ray flag is set in the backend config.

I would have a slight preference for this, let me know what you think :)

The code organization of making a ray_jax file makes a lot of sense to me.

@andrewsykim (Contributor, Author) replied:

Makes sense. I added a use_ray flag to the backend config and removed the ray-jax backend.
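The resulting selection can be sketched as follows. This is a minimal illustration; the actual dispatch lives in skyrl/tinker/engine.py and the return values here are just labels standing in for the real backend classes:

```python
def select_backend(backend_name: str, backend_config: dict) -> str:
    """Pick the backend; a use_ray flag in the config routes the jax
    backend onto Ray instead of requiring a separate 'ray-jax' name."""
    if backend_name == "jax":
        if backend_config.get("use_ray"):
            return "RayJaxBackend"  # JAX workers run as Ray actors
        return "JaxBackend"         # plain single/multi-process JAX
    raise ValueError(f"Unknown backend: {backend_name}")
```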

Comment thread skyrl/backends/ray_jax.py Outdated

if process_id == 0:
self.node_ip = ray.util.get_node_ip_address()
self.port = 7777
@pcmoritz (Collaborator) commented Apr 15, 2026

I like that with the Ray backend the coordinator_address doesn't need to be configured. It would be better not to hard-code the port, though. I can see why it is tricky to auto-discover the port (JAX doesn't seem to support 0 as the port). If we must hard-code it, it might be worth introducing an environment variable to override it (e.g. SKYRL_JAX_COORDINATOR_PORT).

@andrewsykim (Contributor, Author) replied:

Added a _get_random_port function to assign a random port; we can add the override if the random port doesn't work (though I can't see why it wouldn't).

Comment thread skyrl/backends/ray_jax.py Outdated
self.base_model = base_model
self.config = config.model_copy()

num_processes = self.config.num_processes or 1
@pcmoritz (Collaborator) commented Apr 15, 2026

I think it makes most sense to remove the default 1 here and make num_processes required if the Ray backend is used (since it will be multi-node / multi process in most cases).

@andrewsykim (Contributor, Author) replied:

Updated to raise an exception if num_processes is not set.
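The validation could look roughly like this (a sketch; the helper name and signature are assumptions, not the PR's actual code):

```python
def validate_num_processes(num_processes, use_ray: bool) -> int:
    """With the Ray backend, num_processes must be explicit: the backend
    spawns one actor per process, so a silent default of 1 would hide
    misconfiguration on multi-node clusters."""
    if use_ray and num_processes is None:
        raise ValueError("num_processes must be set when use_ray is enabled")
    return num_processes or 1  # non-Ray path keeps the old default
```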

@andrewsykim andrewsykim changed the title [WIP][tx] Add initial implementation of RayJaxBackend [tx] Add initial implementation of RayJaxBackend Apr 16, 2026
@andrewsykim andrewsykim marked this pull request as ready for review April 16, 2026 18:41
@gemini-code-assist bot left a comment

Code Review

This pull request introduces a Ray-based backend for JAX to support multi-node execution, including a new RayJaxBackend implementation and configuration options for Ray actors and placement groups. The feedback focuses on improving the robustness and performance of this new backend. Specifically, it is recommended to explicitly support the ray_jax backend name in the engine logic, validate that the number of placement group bundles matches the process count to prevent index errors, and optimize data transfer by using ray.put() to broadcast batches to workers.

Comment thread skyrl/tinker/engine.py
Comment on lines 160 to +161
if backend_name == "jax":
from skyrl.backends.jax import JaxBackend, JaxBackendConfig
if use_ray:

Severity: high

The backend name ray_jax (or ray-jax) should be explicitly supported in the if condition to match the updated error message and the usage shown in the PR description. Currently, the logic only triggers the Ray backend if the name is exactly "jax" and the use_ray flag is set in the configuration, which contradicts the suggested command-line usage.

Suggested change

- if backend_name == "jax":
-     from skyrl.backends.jax import JaxBackend, JaxBackendConfig
-     if use_ray:
+ if backend_name in ["jax", "ray_jax", "ray-jax"]:
+     if use_ray or backend_name != "jax":

@andrewsykim (Contributor, Author) replied:

Updated the error message.

Comment thread skyrl/backends/ray_jax.py
Comment on lines +111 to +113
bundles = self.config.ray_pg_bundles
if not bundles:
bundles = [{"CPU": 1}] * num_processes

Severity: medium

When ray_pg_bundles is provided in the configuration, its length should be validated to ensure it matches num_processes. If the user provides fewer bundles than processes, the actor initialization loop will fail with an index error when trying to access bundles[i] during placement group bundle assignment.

Suggested change

  bundles = self.config.ray_pg_bundles
+ if bundles and len(bundles) != num_processes:
+     raise ValueError(f"Number of bundles in ray_pg_bundles ({len(bundles)}) must match num_processes ({num_processes})")
  if not bundles:
      bundles = [{"CPU": 1}] * num_processes

Comment thread skyrl/backends/ray_jax.py
@devin-ai-integration bot left a comment

Devin Review found 2 potential issues.

View 4 additional findings in Devin Review.


Comment thread skyrl/backends/ray_jax.py
Comment on lines +191 to +192
def save_sampler_checkpoint(self, output_path: AnyPath, model_id: str, persist: bool = True) -> None:
ray.get([w.save_sampler_checkpoint.remote(output_path, model_id, persist) for w in self.workers])

🔴 Missing probe file write in RayJaxBackend.save_sampler_checkpoint causes concurrent writes on shared filesystems

RayJaxBackend.save_sampler_checkpoint dispatches save_sampler_checkpoint to all Ray actors without first writing a probe file. The existing JaxBackend.save_sampler_checkpoint (skyrl/backends/jax.py:1146-1150) writes a .probe file so that non-coordinator workers can detect a shared filesystem and skip redundant writes (see skyrl/utils/storage.py:29). Without this probe, all Ray actors will concurrently write to the same output_path via pack_and_upload, which performs non-atomic file I/O (skyrl/utils/storage.py:35-39), leading to file corruption on shared filesystems.

Prompt for agents
In RayJaxBackend.save_sampler_checkpoint (skyrl/backends/ray_jax.py:191-192), a probe file is not written before dispatching the save to all Ray actors. The existing JaxBackend (skyrl/backends/jax.py:1146-1150) writes a probe file at output_path.with_name(output_path.name + ".probe") so workers can detect shared filesystems and skip redundant writes (see skyrl/utils/storage.py:28-31 and the pack_and_upload context manager).

The fix should mirror what JaxBackend.save_sampler_checkpoint does: before dispatching to workers, the driver should create the parent directory and write the probe file. Something like:
  output_path.parent.mkdir(parents=True, exist_ok=True)
  output_path.with_name(output_path.name + ".probe").write_text("write_probe")

Note that the RayJaxBackend driver may not have access to the same filesystem as the actors (since it's the Ray driver process), so the correct approach may need to be adapted for the Ray execution model. Consider having the rank-0 actor write the probe before the other actors proceed.
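The probe mechanism this finding refers to can be sketched as follows. The ".probe" suffix and "write_probe" sentinel follow the finding above; the helper names and the existence-check helper are assumptions for illustration:

```python
from pathlib import Path


def write_probe(output_path: Path) -> None:
    """Coordinator writes a sentinel file next to the checkpoint path
    before any worker starts writing."""
    output_path.parent.mkdir(parents=True, exist_ok=True)
    output_path.with_name(output_path.name + ".probe").write_text("write_probe")


def on_shared_filesystem(output_path: Path) -> bool:
    """A non-coordinator worker sees the probe only if it shares a
    filesystem with the coordinator; if so, it skips redundant writes
    instead of racing on the same output_path."""
    return output_path.with_name(output_path.name + ".probe").exists()
```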


@andrewsykim (Contributor, Author) commented:

I was able to test this on TPUs using the following commands:

To start the Tinker API server, engine, and JAX workers:

ray job submit -- sh -c $'uv run --python 3.12 --extra tpu --extra jax --extra tinker --extra ray -m skyrl.tinker.api --base-model Qwen/Qwen3-0.6B --backend jax  --backend-config \'{"use_ray": true, "ray_actor_options": {"resources": {"TPU": 4}}, "ray_pg_bundles": [{"CPU": 4, "TPU": 4},{"CPU": 4, "TPU": 4},{"CPU": 4, "TPU": 4},{"CPU": 4, "TPU": 4}], "sample_max_num_sequences": 256, "train_micro_batch_size": 32, "tensor_parallel_size": 4, "fully_sharded_data_parallel_size": 4, "num_processes": 4}\''

To run RL training:

ray job submit -- sh -c 'cd SkyRL && export TINKER_API_KEY=tml-dummy && uv run --with "tinker-cookbook[math-rl] @ git+https://github.com/thinking-machines-lab/tinker-cookbook.git@nightly" python -m tinker_cookbook.recipes.math_rl.train base_url=http://localhost:8000 model_name="Qwen/Qwen3-0.6B" group_size=4 groups_per_batch=100 learning_rate=1e-4 max_tokens=512 save_every=0'


Successfully merging this pull request may close these issues:

[tinker] Support using Ray to manage tinker API server, engine and Jax workers