_globalize_single_replica_arrays fails with "incompatible devices" under active mesh context (JAX 0.9/Orbax 0.11.33) #3164

@evelyn22chen

Description

Restoring a checkpoint with SingleReplicaArrayHandler fails when a multi-device mesh is active via jax.set_mesh(). The jnp.expand_dims and jnp.zeros calls in _globalize_single_replica_arrays operate on single-device arrays, but under JAX 0.9 they pick up the implicit mesh context and target the full mesh instead.

Environment

  • JAX 0.9, multi-host GPU (4 workers × 8 H200 GPUs = 32 devices)
  • Orbax 0.11.33
  • S3 checkpointing
  • Mesh shape: (pp=1, dp=16, tp=2)
  • Active mesh set via jax.set_mesh() in training loop

Error

File "orbax/checkpoint/_src/multihost/multislice.py", line 272, in _globalize_single_replica_arrays
    source_device_map[s.device] = jnp.expand_dims(s.data, axis=0)
ValueError: Received incompatible devices for jitted computation. Got argument args[0] of
broadcast_in_dim with shape float32[1] and device ids [0] on platform GPU and jit's context
mesh with device ids [0, 1, 2, ..., 31] on platform GPU

Root Cause

https://github.com/google/orbax/blob/main/checkpoint/orbax/checkpoint/_src/multihost/multislice.py#L250-L280

jnp.expand_dims(s.data, axis=0) and jnp.zeros(..., device=d) are per-device ops that pick up the active multi-device mesh context in JAX 0.9, causing a device mismatch.

Suggested Fix

Scope single-device ops to a temporary per-device mesh:

```python
if is_source:
    for s in inp.addressable_shards:
        # Wrap the per-shard op in a one-device mesh so it does not inherit
        # the active multi-device mesh.
        sd_mesh = jax.sharding.Mesh(np.array([s.device]), ('_single',))
        with jax.set_mesh(sd_mesh):
            source_device_map[s.device] = jnp.expand_dims(s.data, axis=0)

...

else:
    slice_shape = _get_slice_shape(index, global_shape)
    # Same scoping for the zero buffer allocated on non-source device d.
    sd_mesh = jax.sharding.Mesh(np.array([d]), ('_single',))
    with jax.set_mesh(sd_mesh):
        zero_data = jnp.zeros(slice_shape, dtype=inp.dtype, device=d)
    device_buffers.append(zero_data)
```

Notes

Questions

  • Is this a known issue? Is scoping to a temporary per-device mesh via jax.set_mesh() the right fix here, or is there a preferred approach for isolating single-device ops from the active mesh context?
