[FEATURE] Support non-blocking masked scene reset. #2468

Open
Kashu7100 wants to merge 4 commits into Genesis-Embodied-AI:main from Kashu7100:feat-scene_reset

Conversation

Collaborator

@Kashu7100 Kashu7100 commented Feb 26, 2026

No description provided.

Kashu7100 and others added 4 commits February 25, 2026 19:44
Allow passing a bool tensor directly to scene.reset(envs_idx=...) in
zerocopy mode, eliminating the GPU-blocking nonzero() call that was
previously required to convert bool reset buffers to int indices in RL
training loops.

Changes:
- scene._sanitize_envs_idx: short-circuit for bool masks in zerocopy
- rigid_solver.set_state: PyTorch tensor ops path for bool masks +
  masked FK kernel dispatch
- collider.clear: dispatch to kernel_masked_collider_clear for bool masks
- contact.py: add kernel_masked_collider_clear kernel
- constraint/solver.py: dispatch + add constraint_solver_kernel_masked_clear
- legacy_coupler.py: dispatch + add _kernel_masked_reset_mpm/sph
- Add 4 tests and an example script

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Compares reset latency at 10k envs across different reset fractions.
Bool mask path is ~1.7-1.95x faster by eliminating the GPU-CPU sync
from nonzero().

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
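
To make the commit descriptions above concrete, here is a minimal pure-Python sketch of why passing a bool mask avoids the index conversion. Plain lists stand in for device tensors, and `nonzero`, `reset_by_indices` and `reset_by_mask` are hypothetical helpers written for this illustration, not Genesis APIs. The point is that the index list has a data-dependent length, so producing it on a GPU forces the host to wait for the device, whereas the mask always has one entry per environment and can be consumed directly.

```python
def nonzero(mask):
    """Stand-in for tensor.nonzero(): the output length depends on the data."""
    return [i for i, flag in enumerate(mask) if flag]


def reset_by_indices(values, envs_idx, fill=0.0):
    """Index path: needs the (data-dependent) list of environments to reset."""
    out = list(values)
    for i_b in envs_idx:
        out[i_b] = fill
    return out


def reset_by_mask(values, envs_mask, fill=0.0):
    """Mask path: fixed-size pass over all environments, no index list needed."""
    return [fill if flag else v for v, flag in zip(values, envs_mask)]


heights = [0.50, 0.02, 0.31, 0.01]
reset_buf = [h < 0.05 for h in heights]  # same termination rule as the PR example

via_indices = reset_by_indices(heights, nonzero(reset_buf))
via_mask = reset_by_mask(heights, reset_buf)
assert via_indices == via_mask  # both paths reset exactly the same environments
```

Both paths produce the same state; only the mask path does so without first materializing the indices.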
@github-actions

⚠️ Abnormal Benchmark Result Detected ➡️ Report

@Kashu7100 Kashu7100 marked this pull request as ready for review February 26, 2026 16:57
Comment on lines +139 to +149
@qd.kernel
def _kernel_masked_reset_mpm(self, envs_mask: qd.types.ndarray()):
    for i_p, i_g, i_b in qd.ndrange(self.mpm_solver.n_particles, self.rigid_solver.n_geoms, envs_mask.shape[0]):
        if envs_mask[i_b]:
            self.mpm_rigid_normal[i_p, i_g, i_b] = 0.0

@qd.kernel
def _kernel_masked_reset_sph(self, envs_mask: qd.types.ndarray()):
    for i_p, i_g, i_b in qd.ndrange(self.sph_solver.n_particles, self.rigid_solver.n_geoms, envs_mask.shape[0]):
        if envs_mask[i_b]:
            self.sph_rigid_normal[i_p, i_g, i_b] = 0.0
Collaborator

You should rather add a generic kernel_masked_set_zero (we already have kernel_set_zero).
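
A sketch of what such a generic helper could look like, written in plain Python rather than as a kernel so that it runs anywhere. The name `masked_set_zero` and the nested-list layout with the batch index last are assumptions made for illustration; the actual signature of `kernel_set_zero` is not shown in this thread.

```python
def masked_set_zero(field, envs_mask):
    """Zero every entry of `field` whose last (batch) index is flagged in `envs_mask`.

    `field` is a non-empty nested list of arbitrary depth whose innermost lists
    hold one entry per environment. It is modified in place.
    """
    if isinstance(field[0], list):
        # Not yet at the batch axis: recurse into each sub-array.
        for sub in field:
            masked_set_zero(sub, envs_mask)
    else:
        # Innermost axis is the batch axis: zero only the flagged environments.
        for i_b, flag in enumerate(envs_mask):
            if flag:
                field[i_b] = 0.0


# Hypothetical buffers laid out as [particle][geom][env], one per solver.
mpm_rigid_normal = [[[1.0, 2.0, 3.0] for _ in range(2)] for _ in range(2)]
sph_rigid_normal = [[[4.0, 5.0, 6.0]] for _ in range(3)]
envs_mask = [True, False, True]

# One helper serves both buffers, instead of one dedicated kernel per solver.
masked_set_zero(mpm_rigid_normal, envs_mask)
masked_set_zero(sph_rigid_normal, envs_mask)
```

The design point is that the helper depends only on the buffer and the mask, so it does not need to know which solver owns the data.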

Comment on lines +134 to +183
if qd.static(static_rigid_sim_config.use_hibernation):
    collider_state.n_contacts_hibernated[i_b] = 0

    # advect hibernated contacts
    for i_c in range(collider_state.n_contacts[i_b]):
        i_la = collider_state.contact_data.link_a[i_c, i_b]
        i_lb = collider_state.contact_data.link_b[i_c, i_b]

        I_la = [i_la, i_b] if qd.static(static_rigid_sim_config.batch_links_info) else i_la
        I_lb = [i_lb, i_b] if qd.static(static_rigid_sim_config.batch_links_info) else i_lb

        if (links_state.hibernated[i_la, i_b] and links_info.is_fixed[I_lb]) or (
            links_state.hibernated[i_lb, i_b] and links_info.is_fixed[I_la]
        ):
            i_c_hibernated = collider_state.n_contacts_hibernated[i_b]
            if i_c != i_c_hibernated:
                # fmt: off
                collider_state.contact_data.geom_a[i_c_hibernated, i_b] = collider_state.contact_data.geom_a[i_c, i_b]
                collider_state.contact_data.geom_b[i_c_hibernated, i_b] = collider_state.contact_data.geom_b[i_c, i_b]
                collider_state.contact_data.penetration[i_c_hibernated, i_b] = collider_state.contact_data.penetration[i_c, i_b]
                collider_state.contact_data.normal[i_c_hibernated, i_b] = collider_state.contact_data.normal[i_c, i_b]
                collider_state.contact_data.pos[i_c_hibernated, i_b] = collider_state.contact_data.pos[i_c, i_b]
                collider_state.contact_data.friction[i_c_hibernated, i_b] = collider_state.contact_data.friction[i_c, i_b]
                collider_state.contact_data.sol_params[i_c_hibernated, i_b] = collider_state.contact_data.sol_params[i_c, i_b]
                collider_state.contact_data.force[i_c_hibernated, i_b] = collider_state.contact_data.force[i_c, i_b]
                collider_state.contact_data.link_a[i_c_hibernated, i_b] = collider_state.contact_data.link_a[i_c, i_b]
                collider_state.contact_data.link_b[i_c_hibernated, i_b] = collider_state.contact_data.link_b[i_c, i_b]
                # fmt: on

            collider_state.n_contacts_hibernated[i_b] = i_c_hibernated + 1

# Clear contacts
for i_c in range(collider_state.n_contacts[i_b]):
    should_clear = True
    if qd.static(static_rigid_sim_config.use_hibernation):
        should_clear = i_c >= collider_state.n_contacts_hibernated[i_b]
    if should_clear:
        collider_state.contact_data.link_a[i_c, i_b] = -1
        collider_state.contact_data.link_b[i_c, i_b] = -1
        collider_state.contact_data.geom_a[i_c, i_b] = -1
        collider_state.contact_data.geom_b[i_c, i_b] = -1
        collider_state.contact_data.penetration[i_c, i_b] = 0.0
        collider_state.contact_data.pos[i_c, i_b] = qd.Vector.zero(gs.qd_float, 3)
        collider_state.contact_data.normal[i_c, i_b] = qd.Vector.zero(gs.qd_float, 3)
        collider_state.contact_data.force[i_c, i_b] = qd.Vector.zero(gs.qd_float, 3)

if qd.static(static_rigid_sim_config.use_hibernation):
    collider_state.n_contacts[i_b] = collider_state.n_contacts_hibernated[i_b]
else:
    collider_state.n_contacts[i_b] = 0
Collaborator

Move this into a taichi func, so that the core implementation can be shared without duplication between the masked and standard implementations.
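
The structure being asked for can be sketched in plain Python as follows. The per-environment logic lives in one helper, and the indexed and masked entry points differ only in how they choose which environments to visit. `ColliderState`, `_clear_env`, `clear` and `masked_clear` are hypothetical, heavily reduced stand-ins for this illustration; the real kernels operate on many more fields and handle hibernation.

```python
from dataclasses import dataclass, field


@dataclass
class ColliderState:
    """Hypothetical, reduced stand-in for the collider state."""
    n_contacts: list                            # n_contacts[i_b]
    link_a: list = field(default_factory=list)  # link_a[i_c][i_b]


def _clear_env(state, i_b):
    """Core per-environment logic, written exactly once (the shared func)."""
    for i_c in range(state.n_contacts[i_b]):
        state.link_a[i_c][i_b] = -1
    state.n_contacts[i_b] = 0


def clear(state, envs_idx):
    """Standard entry point: visit the explicitly listed environments."""
    for i_b in envs_idx:
        _clear_env(state, i_b)


def masked_clear(state, envs_mask):
    """Masked entry point: visit every environment, skip the unflagged ones."""
    for i_b, flag in enumerate(envs_mask):
        if flag:
            _clear_env(state, i_b)


def make_state():
    return ColliderState(n_contacts=[2, 1, 2], link_a=[[3, 4, 5], [6, 7, 8]])


a, b = make_state(), make_state()
clear(a, [0, 2])
masked_clear(b, [True, False, True])
assert a == b  # both entry points yield the same state
```

With this layout, a later change to the clearing logic only has to be made in `_clear_env`.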

Comment on lines +514 to +523
constraint_state.n_constraints[i_b] = 0
constraint_state.n_constraints_equality[i_b] = 0
constraint_state.n_constraints_frictionloss[i_b] = 0
# Reset dynamic equality count to static count to avoid stale constraints after partial reset
constraint_state.qd_n_equalities[i_b] = rigid_global_info.n_equalities[None]
for i_d, i_c in qd.ndrange(n_dofs, len_constraints):
    constraint_state.jac[i_c, i_d, i_b] = 0.0
if qd.static(static_rigid_sim_config.sparse_solve):
    for i_c in range(len_constraints):
        constraint_state.jac_n_relevant_dofs[i_c, i_b] = 0
Collaborator

Same. Share code using taichi func to avoid duplication.

Comment on lines +1728 to +1729
if gs.use_zerocopy and isinstance(envs_idx, torch.Tensor) and envs_idx.dtype == torch.bool:
    # === ZEROCOPY BOOL MASK FAST PATH ===
Collaborator

This is not the correct pattern. It is trickier than this:

            if isinstance(envs_idx, torch.Tensor) and (not IS_OLD_TORCH or envs_idx.dtype == torch.bool):
                if envs_idx.dtype == torch.bool:
                    is_warmstart.masked_fill_(envs_idx, False)
                    qacc_ws.masked_fill_(envs_idx[None], 0.0)
                else:
                    is_warmstart.scatter_(0, envs_idx, False)
                    qacc_ws.scatter_(1, envs_idx[None].expand((qacc_ws.shape[0], -1)), 0.0)
            else:
                is_warmstart[envs_idx] = False
                qacc_ws[:, envs_idx] = 0.0

Comment on lines +1422 to +1424
# Pass bool masks through in zerocopy mode to avoid blocking GPU sync
if gs.use_zerocopy and isinstance(envs_idx, torch.Tensor) and envs_idx.dtype == torch.bool:
    return envs_idx
Collaborator

@duburcqa duburcqa Feb 28, 2026

Remove this. This is not how it works. You should find another way. This was not needed until now, and I see no reason why it would need to be added.

# RL-style termination: reset envs where box fell below threshold
box_height = pre_reset_pos[:, 2]
reset_buf = box_height < 0.05
scene.reset(state=init_state, envs_idx=reset_buf)
Collaborator

@duburcqa duburcqa Feb 28, 2026

All of this is way too complicated. Just do some random set_qpos followed by a masked reset, then check that the reset values are equal to zero while the others still have their original values.

Collaborator

You should also check that the bounding boxes are consistent.

@pytest.mark.required
def test_bool_mask_reset_selective(tol):
    """Test that bool mask reset only affects masked environments and leaves others untouched."""
    n_envs = 4
Collaborator

Remove this parameter. Just use scene.n_envs.

"""Test that bool mask reset only affects masked environments and leaves others untouched."""
n_envs = 4
scene = gs.Scene(
    sim_options=gs.options.SimOptions(dt=0.01),
Collaborator

Remove this option.

    sim_options=gs.options.SimOptions(dt=0.01),
    show_viewer=False,
)
scene.add_entity(gs.morphs.URDF(file="urdf/plane/plane.urdf", fixed=True))
Collaborator

Remove this.

box = scene.add_entity(gs.morphs.Box(size=(0.1, 0.1, 0.1), pos=(0, 0, 0.5)))
scene.build(n_envs=n_envs)

init_state = scene.get_state()
Collaborator

Remove this.

Collaborator

Remove this benchmark. You should rather update this block to add reset:

Collaborator

Remove this. Thank you for the effort, but I will write comprehensive documentation for the release.

@duburcqa duburcqa changed the title from "[FEATURE] support bool mask as envs_idx for non-blocking scene reset" to "[FEATURE] Support non-blocking masked scene reset." Feb 28, 2026