[FEATURE] Support non-blocking masked scene reset.#2468
[FEATURE] Support non-blocking masked scene reset.#2468Kashu7100 wants to merge 4 commits intoGenesis-Embodied-AI:mainfrom
Conversation
Allow passing a bool tensor directly to scene.reset(envs_idx=...) in zerocopy mode, eliminating the GPU-blocking nonzero() call that was previously required to convert bool reset buffers to int indices in RL training loops. Changes: - scene._sanitize_envs_idx: short-circuit for bool masks in zerocopy - rigid_solver.set_state: PyTorch tensor ops path for bool masks + masked FK kernel dispatch - collider.clear: dispatch to kernel_masked_collider_clear for bool masks - contact.py: add kernel_masked_collider_clear kernel - constraint/solver.py: dispatch + add constraint_solver_kernel_masked_clear - legacy_coupler.py: dispatch + add _kernel_masked_reset_mpm/sph - Add 4 tests and an example script Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Compares reset latency at 10k envs across different reset fractions. Bool mask path is ~1.7-1.95x faster by eliminating the GPU-CPU sync from nonzero(). Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
|
|
| @qd.kernel | ||
| def _kernel_masked_reset_mpm(self, envs_mask: qd.types.ndarray()): | ||
| for i_p, i_g, i_b in qd.ndrange(self.mpm_solver.n_particles, self.rigid_solver.n_geoms, envs_mask.shape[0]): | ||
| if envs_mask[i_b]: | ||
| self.mpm_rigid_normal[i_p, i_g, i_b] = 0.0 | ||
|
|
||
| @qd.kernel | ||
| def _kernel_masked_reset_sph(self, envs_mask: qd.types.ndarray()): | ||
| for i_p, i_g, i_b in qd.ndrange(self.sph_solver.n_particles, self.rigid_solver.n_geoms, envs_mask.shape[0]): | ||
| if envs_mask[i_b]: | ||
| self.sph_rigid_normal[i_p, i_g, i_b] = 0.0 |
There was a problem hiding this comment.
You should rather add a generic kernel_masked_set_zero (with already have kernel_set_zero)
| if qd.static(static_rigid_sim_config.use_hibernation): | ||
| collider_state.n_contacts_hibernated[i_b] = 0 | ||
|
|
||
| # advect hibernated contacts | ||
| for i_c in range(collider_state.n_contacts[i_b]): | ||
| i_la = collider_state.contact_data.link_a[i_c, i_b] | ||
| i_lb = collider_state.contact_data.link_b[i_c, i_b] | ||
|
|
||
| I_la = [i_la, i_b] if qd.static(static_rigid_sim_config.batch_links_info) else i_la | ||
| I_lb = [i_lb, i_b] if qd.static(static_rigid_sim_config.batch_links_info) else i_lb | ||
|
|
||
| if (links_state.hibernated[i_la, i_b] and links_info.is_fixed[I_lb]) or ( | ||
| links_state.hibernated[i_lb, i_b] and links_info.is_fixed[I_la] | ||
| ): | ||
| i_c_hibernated = collider_state.n_contacts_hibernated[i_b] | ||
| if i_c != i_c_hibernated: | ||
| # fmt: off | ||
| collider_state.contact_data.geom_a[i_c_hibernated, i_b] = collider_state.contact_data.geom_a[i_c, i_b] | ||
| collider_state.contact_data.geom_b[i_c_hibernated, i_b] = collider_state.contact_data.geom_b[i_c, i_b] | ||
| collider_state.contact_data.penetration[i_c_hibernated, i_b] = collider_state.contact_data.penetration[i_c, i_b] | ||
| collider_state.contact_data.normal[i_c_hibernated, i_b] = collider_state.contact_data.normal[i_c, i_b] | ||
| collider_state.contact_data.pos[i_c_hibernated, i_b] = collider_state.contact_data.pos[i_c, i_b] | ||
| collider_state.contact_data.friction[i_c_hibernated, i_b] = collider_state.contact_data.friction[i_c, i_b] | ||
| collider_state.contact_data.sol_params[i_c_hibernated, i_b] = collider_state.contact_data.sol_params[i_c, i_b] | ||
| collider_state.contact_data.force[i_c_hibernated, i_b] = collider_state.contact_data.force[i_c, i_b] | ||
| collider_state.contact_data.link_a[i_c_hibernated, i_b] = collider_state.contact_data.link_a[i_c, i_b] | ||
| collider_state.contact_data.link_b[i_c_hibernated, i_b] = collider_state.contact_data.link_b[i_c, i_b] | ||
| # fmt: on | ||
|
|
||
| collider_state.n_contacts_hibernated[i_b] = i_c_hibernated + 1 | ||
|
|
||
| # Clear contacts | ||
| for i_c in range(collider_state.n_contacts[i_b]): | ||
| should_clear = True | ||
| if qd.static(static_rigid_sim_config.use_hibernation): | ||
| should_clear = i_c >= collider_state.n_contacts_hibernated[i_b] | ||
| if should_clear: | ||
| collider_state.contact_data.link_a[i_c, i_b] = -1 | ||
| collider_state.contact_data.link_b[i_c, i_b] = -1 | ||
| collider_state.contact_data.geom_a[i_c, i_b] = -1 | ||
| collider_state.contact_data.geom_b[i_c, i_b] = -1 | ||
| collider_state.contact_data.penetration[i_c, i_b] = 0.0 | ||
| collider_state.contact_data.pos[i_c, i_b] = qd.Vector.zero(gs.qd_float, 3) | ||
| collider_state.contact_data.normal[i_c, i_b] = qd.Vector.zero(gs.qd_float, 3) | ||
| collider_state.contact_data.force[i_c, i_b] = qd.Vector.zero(gs.qd_float, 3) | ||
|
|
||
| if qd.static(static_rigid_sim_config.use_hibernation): | ||
| collider_state.n_contacts[i_b] = collider_state.n_contacts_hibernated[i_b] | ||
| else: | ||
| collider_state.n_contacts[i_b] = 0 |
There was a problem hiding this comment.
Move this in a taichi func, so that the core implementation can be shared without duplication between masked and standard implementation
| constraint_state.n_constraints[i_b] = 0 | ||
| constraint_state.n_constraints_equality[i_b] = 0 | ||
| constraint_state.n_constraints_frictionloss[i_b] = 0 | ||
| # Reset dynamic equality count to static count to avoid stale constraints after partial reset | ||
| constraint_state.qd_n_equalities[i_b] = rigid_global_info.n_equalities[None] | ||
| for i_d, i_c in qd.ndrange(n_dofs, len_constraints): | ||
| constraint_state.jac[i_c, i_d, i_b] = 0.0 | ||
| if qd.static(static_rigid_sim_config.sparse_solve): | ||
| for i_c in range(len_constraints): | ||
| constraint_state.jac_n_relevant_dofs[i_c, i_b] = 0 |
There was a problem hiding this comment.
Same. Share code using taichi func to avoid duplication.
| if gs.use_zerocopy and isinstance(envs_idx, torch.Tensor) and envs_idx.dtype == torch.bool: | ||
| # === ZEROCOPY BOOL MASK FAST PATH === |
There was a problem hiding this comment.
This is not the correct pattern. it is more tricky than this:
if isinstance(envs_idx, torch.Tensor) and (not IS_OLD_TORCH or envs_idx.dtype == torch.bool):
if envs_idx.dtype == torch.bool:
is_warmstart.masked_fill_(envs_idx, False)
qacc_ws.masked_fill_(envs_idx[None], 0.0)
else:
is_warmstart.scatter_(0, envs_idx, False)
qacc_ws.scatter_(1, envs_idx[None].expand((qacc_ws.shape[0], -1)), 0.0)
else:
is_warmstart[envs_idx] = False
qacc_ws[:, envs_idx] = 0.0| # Pass bool masks through in zerocopy mode to avoid blocking GPU sync | ||
| if gs.use_zerocopy and isinstance(envs_idx, torch.Tensor) and envs_idx.dtype == torch.bool: | ||
| return envs_idx |
There was a problem hiding this comment.
Remove this. This is not how it works. You should find anyway. This was not needed until now, I see no reason why it would need to be added.
| # RL-style termination: reset envs where box fell below threshold | ||
| box_height = pre_reset_pos[:, 2] | ||
| reset_buf = box_height < 0.05 | ||
| scene.reset(state=init_state, envs_idx=reset_buf) |
There was a problem hiding this comment.
All of this is way too complicated. Just do some random set_qpos followed by some masked reset and checks that resetted values are equal to zero, while the others still have the original value.
There was a problem hiding this comment.
You should also check that the bounding boxes are consistent.
| @pytest.mark.required | ||
| def test_bool_mask_reset_selective(tol): | ||
| """Test that bool mask reset only affects masked environments and leaves others untouched.""" | ||
| n_envs = 4 |
There was a problem hiding this comment.
Remove this parameters. Just use scene.n_envs
| """Test that bool mask reset only affects masked environments and leaves others untouched.""" | ||
| n_envs = 4 | ||
| scene = gs.Scene( | ||
| sim_options=gs.options.SimOptions(dt=0.01), |
| sim_options=gs.options.SimOptions(dt=0.01), | ||
| show_viewer=False, | ||
| ) | ||
| scene.add_entity(gs.morphs.URDF(file="urdf/plane/plane.urdf", fixed=True)) |
| box = scene.add_entity(gs.morphs.Box(size=(0.1, 0.1, 0.1), pos=(0, 0, 0.5))) | ||
| scene.build(n_envs=n_envs) | ||
|
|
||
| init_state = scene.get_state() |
There was a problem hiding this comment.
Remove this benchmark. You should rather update this block to add reset:
Genesis/tests/test_rigid_benchmarks.py
Line 510 in 7726492
There was a problem hiding this comment.
Remove this. Thank you for the effort by I will write a comprehensive documentation for the release.
No description provided.