Description
SingleReplicaArrayHandler checkpoint restore fails when a multi-device mesh is active via jax.set_mesh(). The jnp.expand_dims and jnp.zeros calls in _globalize_single_replica_arrays are single-device operations, but JAX 0.9's implicit mesh context causes them to target the full mesh.
Environment
- JAX 0.9, multi-host GPU (4 workers × 8 H200 GPUs = 32 devices)
- Orbax 0.11.33
- S3 checkpointing
- Mesh shape: (pp=1, dp=16, tp=2)
- Active mesh set via jax.set_mesh() in training loop
Error
File "orbax/checkpoint/_src/multihost/multislice.py", line 272, in _globalize_single_replica_arrays
source_device_map[s.device] = jnp.expand_dims(s.data, axis=0)
ValueError: Received incompatible devices for jitted computation. Got argument args[0] of
broadcast_in_dim with shape float32[1] and device ids [0] on platform GPU and jit's context
mesh with device ids [0, 1, 2, ..., 31] on platform GPU
Root Cause
https://github.com/google/orbax/blob/main/checkpoint/orbax/checkpoint/_src/multihost/multislice.py#L250-L280
jnp.expand_dims(s.data, axis=0) and jnp.zeros(..., device=d) are per-device ops that pick up the active multi-device mesh context in JAX 0.9, causing a device mismatch.
Suggested Fix
Scope single-device ops to a temporary per-device mesh:
```python
if is_source:
    for s in inp.addressable_shards:
        sd_mesh = jax.sharding.Mesh(np.array([s.device]), ('_single',))
        with jax.set_mesh(sd_mesh):
            source_device_map[s.device] = jnp.expand_dims(s.data, axis=0)
    ...
else:
    slice_shape = _get_slice_shape(index, global_shape)
    sd_mesh = jax.sharding.Mesh(np.array([d]), ('_single',))
    with jax.set_mesh(sd_mesh):
        zero_data = jnp.zeros(slice_shape, dtype=inp.dtype, device=d)
    device_buffers.append(zero_data)
```
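Both branches of the suggested fix repeat the same per-device scoping, so it could be factored into a small context manager. The sketch below is only an illustration: `single_device_mesh` is a hypothetical helper name (not part of Orbax or JAX), it assumes `jax.set_mesh` works as a context manager as in the snippet above, and on JAX releases without `jax.set_mesh` it assumes there is nothing to override and runs the body unscoped.

```python
import contextlib

import jax
import jax.numpy as jnp
import numpy as np


@contextlib.contextmanager
def single_device_mesh(device):
    """Hypothetical helper: scope the enclosed ops to a one-device mesh."""
    mesh = jax.sharding.Mesh(np.array([device]), ("_single",))
    set_mesh = getattr(jax, "set_mesh", None)
    if set_mesh is None:
        # Assumption: without jax.set_mesh there is no set_mesh context
        # to override, so the body simply runs unscoped.
        yield mesh
        return
    # Temporarily replace the active mesh with the one-device mesh.
    with set_mesh(mesh):
        yield mesh


# Usage: the same two per-device ops as in the fix, on one local device.
device = jax.local_devices()[0]
with single_device_mesh(device):
    zero_data = jnp.zeros((2, 3), dtype=jnp.float32, device=device)
    expanded = jnp.expand_dims(zero_data, axis=0)
```

With a helper like this, each branch in `_globalize_single_replica_arrays` would need only a single `with single_device_mesh(...)` line around the `jnp` call. Whether this is the preferred way to isolate single-device ops is still the open question in this issue.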
Questions
- Is this a known issue? Is scoping to a temporary per-device mesh via jax.set_mesh() the right fix here, or is there a preferred approach for isolating single-device ops from the active mesh context?