diff --git a/docs/source/overview/sim/solvers/index.rst b/docs/source/overview/sim/solvers/index.rst index 8ffa5570..30ad1043 100644 --- a/docs/source/overview/sim/solvers/index.rst +++ b/docs/source/overview/sim/solvers/index.rst @@ -80,7 +80,9 @@ Choosing a solver - Use analytic solvers (OPW for 6-DOF arms or SRS for 7-DOF arms) when available for speed and determinism. - Use numerical solvers (PyTorch/optimization, Differential) when you need - flexibility.. + flexibility. +- Use the neural IK solver (experimental) when you have a trained checkpoint and need + fast batch inference on a supported robot. See also -------- @@ -94,3 +96,4 @@ See also pinocchio_solver.md opw_solver.md srs_solver.md + neural_ik_solver.md diff --git a/docs/source/overview/sim/solvers/neural_ik_solver.md b/docs/source/overview/sim/solvers/neural_ik_solver.md new file mode 100644 index 00000000..f2690d01 --- /dev/null +++ b/docs/source/overview/sim/solvers/neural_ik_solver.md @@ -0,0 +1,71 @@ +# NeuralIKSolver + +````{admonition} Experimental +:class: warning + +`NeuralIKSolver` is an **experimental** feature. The API, checkpoint format, +and default parameters may change without a deprecation cycle. It is currently +only validated on the **Franka Panda** robot. +```` + +`NeuralIKSolver` is a learning-based inverse kinematics (IK) solver that uses a +trained neural network policy to iteratively solve IK queries. It requires a +pre-trained checkpoint and supports batch processing. + +## Key Features + +* Iterative neural policy inference for IK solving +* Batch processing for multiple target poses simultaneously +* Multi-seed sampling: generate several random initial guesses and return the successful solution closest to the seed (or all solutions with `return_all_solutions=True`) +* Joint limit enforcement at every iteration +* PyTorch-based — supports both CPU and CUDA devices + +## Configuration + +The solver is configured using the `NeuralIKSolverCfg` class. 
Pre-trained +checkpoints are hosted on HuggingFace and can be downloaded with +`download_neural_ik_checkpoint()` (requires the `HF_TOKEN` environment variable or a prior `huggingface-cli login`; the repository is gated). + +```python +from embodichain.data.assets.solver_assets import download_neural_ik_checkpoint +from embodichain.lab.sim.solvers.neural_ik_solver import NeuralIKSolverCfg + +checkpoint_path = download_neural_ik_checkpoint() + +cfg = NeuralIKSolverCfg( + checkpoint_path=checkpoint_path, + num_arm_joints=7, + max_steps=30, + action_scale=0.2, + hidden_dims=[256, 256], + pos_eps=0.01, + rot_eps=0.1, + num_samples=1, +) +``` + +## Main Methods + +* `get_ik(self, target_xpos, qpos_seed=None, num_samples=None, **kwargs)` + Solve IK for the given target end-effector pose(s). + + **Parameters:** + + `target_xpos` (`torch.Tensor`): Target pose(s) as 4x4 matrix, shape `(4, 4)` or `(B, 4, 4)`. + + `qpos_seed` (`torch.Tensor`, optional): Initial joint positions, shape `(dof,)` or `(B, dof)`. + + `num_samples` (`int`, optional): Override `cfg.num_samples` for this call. + + `return_all_solutions` (`bool`, optional, passed via `**kwargs`, default `False`): If `True`, return all sampled solutions with shape `(B, num_samples, dof)` instead of the single closest one. + + **Returns:** + + `Tuple[torch.Tensor, torch.Tensor]`: + - Success flags, shape `(B,)`. + - Joint positions, shape `(B, 1, dof)` or `(B, num_samples, dof)`. + + **Example:** + +```python + import torch + success, ik_qpos = solver.get_ik(target_xpos=target_pose, qpos_seed=qpos_seed) + print("Success:", success) + print("IK solution:", ik_qpos) +``` + diff --git a/embodichain/data/assets/solver_assets.py b/embodichain/data/assets/solver_assets.py new file mode 100644 index 00000000..a4fb2e65 --- /dev/null +++ b/embodichain/data/assets/solver_assets.py @@ -0,0 +1,89 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- +from __future__ import annotations + +import os + +from huggingface_hub import hf_hub_download + +# HuggingFace endpoint. Mirrors (e.g. hf-mirror.com) often redirect to the +# real hub without forwarding the required commit-hash response headers, so we +# default to the canonical endpoint and rely on the system proxy when needed. +_HF_ENDPOINT = "https://huggingface.co" + + +def download_neural_ik_checkpoint( + repo_id: str = "dexforce/neural_ik_solver", + filename: str = "franka.pt", + token: str | None = None, + endpoint: str = _HF_ENDPOINT, +) -> str: + """Download a neural IK solver checkpoint from HuggingFace. + + The repository is gated. Either set the ``HF_TOKEN`` environment variable or + run ``huggingface-cli login`` before calling this function. + + If your network requires an HTTP proxy, set ``HTTPS_PROXY`` or + ``https_proxy`` in the environment before launching Python. + + Args: + repo_id: HuggingFace repository ID. + filename: Checkpoint filename to download. + token: HuggingFace API token. Falls back to the ``HF_TOKEN`` + environment variable or the cached token from + ``huggingface-cli login``. + endpoint: HuggingFace-compatible endpoint URL. Defaults to + ``https://huggingface.co``. Mirrors that merely redirect to the + real hub are not supported. + + Returns: + str: Local path to the downloaded checkpoint file. + + Raises: + RuntimeError: If the download fails, with authentication instructions. 
+ """ + # Normalize proxy env vars: the ``requests`` library on Linux requires the + # lowercase form (``https_proxy``), but users typically export the uppercase + # form (``HTTPS_PROXY``). + https_proxy = os.environ.get("HTTPS_PROXY") or os.environ.get("https_proxy") + if https_proxy: + os.environ.setdefault("https_proxy", https_proxy) + os.environ.setdefault("HTTPS_PROXY", https_proxy) + + # Allow callers to pass the token explicitly; otherwise fall back to + # HF_TOKEN (huggingface_hub also reads this automatically, but being + # explicit makes the fallback order transparent). + if token is None: + token = os.environ.get("HF_TOKEN") or None + + try: + return hf_hub_download( + repo_id=repo_id, + filename=filename, + token=token, + endpoint=endpoint, + ) + except Exception as exc: + raise RuntimeError( + f"Failed to download '{filename}' from '{repo_id}'.\n" + "The repository is gated and requires an authenticated HuggingFace account.\n" + "To fix this:\n" + " 1. Accept the model license at https://huggingface.co/dexforce/neural_ik_solver\n" + " 2. Create an access token at https://huggingface.co/settings/tokens\n" + " 3. 
Export the token: export HF_TOKEN=\n" + " or run: huggingface-cli login\n" + f"Original error: {exc}" + ) from exc diff --git a/embodichain/lab/sim/objects/articulation.py b/embodichain/lab/sim/objects/articulation.py index 6d995e85..15d377b7 100644 --- a/embodichain/lab/sim/objects/articulation.py +++ b/embodichain/lab/sim/objects/articulation.py @@ -1105,6 +1105,10 @@ def set_qpos( else: qpos_set = self.body_data._qpos + if not isinstance(local_env_ids, torch.Tensor): + local_env_ids = torch.as_tensor( + local_env_ids, dtype=torch.long, device=self.device + ) indices = self.body_data.gpu_indices[local_env_ids] qpos_set[local_env_ids[:, None], local_joint_ids] = qpos self._ps.gpu_apply_joint_data( @@ -1181,6 +1185,10 @@ def set_qvel( else: qvel_set = self.body_data._qvel + if not isinstance(local_env_ids, torch.Tensor): + local_env_ids = torch.as_tensor( + local_env_ids, dtype=torch.long, device=self.device + ) indices = self.body_data.gpu_indices[local_env_ids] qvel_set[local_env_ids[:, None], local_joint_ids] = qvel self._ps.gpu_apply_joint_data( diff --git a/embodichain/lab/sim/sim_manager.py b/embodichain/lab/sim/sim_manager.py index 1998192d..757c83c3 100644 --- a/embodichain/lab/sim/sim_manager.py +++ b/embodichain/lab/sim/sim_manager.py @@ -307,7 +307,7 @@ def __init__( if sim_config.headless is False: self._window = self._world.get_windows() - # self._register_default_window_control() + self._register_default_window_control() @classmethod def get_instance(cls, instance_id: int = 0) -> SimulationManager: @@ -550,12 +550,12 @@ def open_window(self) -> None: self._window = self._world.get_windows() # TODO: will open these features after fix the related blocking issues. 
- # self._register_default_window_control() - # if ( - # self._window_record_hotkey_cfg is not None - # and self._window_record_input_control is None - # ): - # self.enable_window_record_hotkey(**self._window_record_hotkey_cfg) + self._register_default_window_control() + if ( + self._window_record_hotkey_cfg is not None + and self._window_record_input_control is None + ): + self.enable_window_record_hotkey(**self._window_record_hotkey_cfg) self.is_window_opened = True def close_window(self) -> None: diff --git a/embodichain/lab/sim/solvers/__init__.py b/embodichain/lab/sim/solvers/__init__.py index 901ab401..25a932ba 100644 --- a/embodichain/lab/sim/solvers/__init__.py +++ b/embodichain/lab/sim/solvers/__init__.py @@ -21,3 +21,4 @@ from .pink_solver import PinkSolverCfg, PinkSolver from .opw_solver import OPWSolverCfg, OPWSolver from .srs_solver import SRSSolverCfg, SRSSolver +from .neural_ik_solver import NeuralIKSolverCfg, NeuralIKSolver diff --git a/embodichain/lab/sim/solvers/base_solver.py b/embodichain/lab/sim/solvers/base_solver.py index ae04cb41..ed69d9c5 100644 --- a/embodichain/lab/sim/solvers/base_solver.py +++ b/embodichain/lab/sim/solvers/base_solver.py @@ -177,6 +177,14 @@ def __init__(self, cfg: SolverCfg = None, device: str = None, **kwargs): fullgraph=True, dynamic=True, ) + # Warm up on the solver device so Dynamo guards match CUDA/CPU at init + # instead of on the first get_fk call (avoids recompile_limit hits in CI). 
+ if self.dof > 0: + with torch.no_grad(): + warmup_qpos = torch.zeros( + 1, self.dof, device=self.device, dtype=torch.float32 + ) + self.compiled_fk(warmup_qpos) self._init_qpos_limits() diff --git a/embodichain/lab/sim/solvers/neural_ik_solver.py b/embodichain/lab/sim/solvers/neural_ik_solver.py new file mode 100644 index 00000000..7f1cb1d1 --- /dev/null +++ b/embodichain/lab/sim/solvers/neural_ik_solver.py @@ -0,0 +1,302 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ---------------------------------------------------------------------------- +from __future__ import annotations + +import torch +import torch.nn as nn + +from embodichain.utils import configclass +from embodichain.utils.math import ( + convert_quat, + quat_error_magnitude, + quat_from_matrix, +) +from embodichain.lab.sim.solvers import SolverCfg, BaseSolver +from embodichain.lab.sim.solvers.qpos_seed_sampler import QposSeedSampler + +__all__ = ["NeuralIKSolverCfg", "NeuralIKSolver"] + + +@configclass +class NeuralIKSolverCfg(SolverCfg): + """Configuration for the neural network IK solver.""" + + class_type: str = "NeuralIKSolver" + + checkpoint_path: str = "" + """Path to the trained policy checkpoint (.pt file).""" + + max_steps: int = 30 + """Number of policy inference iterations per IK solve.""" + + action_scale: float = 0.2 + """Action scaling factor (radians).""" + + obs_dim: int | None = None + """Observation dimension. If None, auto-computed as ``2 * num_arm_joints + 14``.""" + + num_arm_joints: int = 7 + """Number of arm joints (policy only controls arm, not fingers).""" + + hidden_dims: list[int] = [256, 256] + """Hidden layer dimensions for the MLP policy network.""" + + pos_eps: float = 0.01 + """Position convergence tolerance (meters) for success check.""" + + rot_eps: float = 0.1 + """Rotation convergence tolerance (radians) for success check.""" + + num_samples: int = 1 + """Number of random initial qpos seeds to sample per target pose.""" + + def init_solver( + self, device: torch.device = torch.device("cpu"), **kwargs + ) -> NeuralIKSolver: + if self.obs_dim is None: + self.obs_dim = 2 * self.num_arm_joints + 14 + solver = NeuralIKSolver(cfg=self, device=device, **kwargs) + solver.set_tcp(self._get_tcp_as_numpy()) + return solver + + +def _build_mlp(obs_dim: int, hidden_dims: list[int], action_dim: int) -> nn.Sequential: + """Build an MLP with Tanh activations between hidden layers.""" + layers = [] + in_dim = obs_dim + for h in hidden_dims: + 
layers.append(nn.Linear(in_dim, h)) + layers.append(nn.Tanh()) + in_dim = h + layers.append(nn.Linear(in_dim, action_dim)) + return nn.Sequential(*layers) + + +class NeuralIKSolver(BaseSolver): + """IK solver using a trained neural network policy. + + Loads a checkpoint containing actor_mean weights and obs_normalizer stats, + then runs iterative inference to solve IK queries. + """ + + def __init__(self, cfg: NeuralIKSolverCfg, device=None, **kwargs): + super().__init__(cfg=cfg, device=device, **kwargs) + + self._max_steps = cfg.max_steps + self._action_scale = cfg.action_scale + self._num_arm_joints = cfg.num_arm_joints + self._pos_eps = cfg.pos_eps + self._rot_eps = cfg.rot_eps + self._num_samples = cfg.num_samples + + ckpt = torch.load( + cfg.checkpoint_path, map_location=self.device, weights_only=False + ) + + if "agent" not in ckpt: + raise KeyError( + f"Checkpoint at '{cfg.checkpoint_path}' is missing 'agent' key. " + f"Available keys: {list(ckpt.keys())}. " + f"Expected a checkpoint from the analytic_policy_gradients training pipeline." + ) + actor_keys = [k for k in ckpt["agent"] if k.startswith("actor_mean.")] + if not actor_keys: + raise KeyError( + f"Checkpoint 'agent' has no 'actor_mean.*' keys. " + f"Available: {list(ckpt['agent'].keys())}." + ) + if "obs_normalizer" not in ckpt: + raise KeyError( + f"Checkpoint at '{cfg.checkpoint_path}' is missing 'obs_normalizer'. " + f"Available keys: {list(ckpt.keys())}." + ) + for subkey in ("mean", "var"): + if subkey not in ckpt["obs_normalizer"]: + raise KeyError( + f"Checkpoint 'obs_normalizer' is missing '{subkey}'. " + f"Available: {list(ckpt['obs_normalizer'].keys())}." 
+ ) + + self.mlp = _build_mlp(cfg.obs_dim, cfg.hidden_dims, cfg.num_arm_joints) + + state_dict = { + k.replace("actor_mean.", ""): v + for k, v in ckpt["agent"].items() + if k.startswith("actor_mean.") + } + self.mlp.load_state_dict(state_dict) + self.mlp.to(self.device).eval() + + self._obs_mean = ckpt["obs_normalizer"]["mean"].to(self.device) + self._obs_var = ckpt["obs_normalizer"]["var"].to(self.device) + + def _normalize_obs(self, obs: torch.Tensor) -> torch.Tensor: + """Normalize observations using stored running mean/var.""" + return (obs - self._obs_mean) / (self._obs_var.sqrt() + 1e-8) + + def _build_obs( + self, + qpos: torch.Tensor, + ee_pos: torch.Tensor, + ee_quat: torch.Tensor, + target_pos: torch.Tensor, + target_quat: torch.Tensor, + last_action: torch.Tensor, + ) -> torch.Tensor: + """Build observation vector: [joint_pos(N), ee_pose(7), target_pose(7), last_action(N)].""" + return torch.cat( + [ + qpos[:, : self._num_arm_joints], + ee_pos, + ee_quat, + target_pos, + target_quat, + last_action, + ], + dim=-1, + ) + + def _run_policy( + self, + qpos: torch.Tensor, + target_xpos: torch.Tensor, + target_pos: torch.Tensor, + target_quat: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Run the iterative neural policy loop and check convergence. + + Args: + qpos: Joint positions, shape (B, dof). Modified in-place. + target_xpos: Target poses, shape (B, 4, 4). + target_pos: Target positions, shape (B, 3). + target_quat: Target quaternions (xyzw), shape (B, 4). + + Returns: + Tuple of (success [B], ik_qpos [B, dof]). 
+ """ + B = qpos.shape[0] + last_action = torch.zeros(B, self._num_arm_joints, device=self.device) + + with torch.no_grad(): + for _ in range(self._max_steps): + ee_xpos = self.get_fk(qpos) + ee_pos = ee_xpos[:, :3, 3] + ee_quat = convert_quat(quat_from_matrix(ee_xpos[:, :3, :3]), to="xyzw") + + obs = self._build_obs( + qpos, ee_pos, ee_quat, target_pos, target_quat, last_action + ) + action = self.mlp(self._normalize_obs(obs)).clamp(-1.0, 1.0) + + qpos[:, : self._num_arm_joints] += action * self._action_scale + qpos[:, : self._num_arm_joints] = torch.clamp( + qpos[:, : self._num_arm_joints], + self.lower_qpos_limits[: self._num_arm_joints], + self.upper_qpos_limits[: self._num_arm_joints], + ) + last_action = action + + # Convergence check + ik_xpos = self.get_fk(qpos) + pos_err = (ik_xpos[:, :3, 3] - target_pos).norm(dim=-1) + ik_quat_wxyz = quat_from_matrix(ik_xpos[:, :3, :3]) + target_quat_wxyz = quat_from_matrix(target_xpos[:, :3, :3]) + rot_err = quat_error_magnitude(target_quat_wxyz, ik_quat_wxyz) + success = (pos_err < self._pos_eps) & (rot_err < self._rot_eps) + + return success, qpos + + def get_ik( + self, + target_xpos: torch.Tensor, + qpos_seed: torch.Tensor | None = None, + num_samples: int | None = None, + **kwargs, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Solve IK using the trained neural policy. + + Args: + target_xpos: Target pose as 4x4 matrix, shape (4,4) or (B,4,4). + qpos_seed: Initial joint positions, shape (dof,) or (B,dof). + num_samples: Number of random initial seeds per target pose. + Defaults to ``cfg.num_samples`` (1). When > 1, generates + multiple random seeds within joint limits and returns the + solution closest to ``qpos_seed``. + return_all_solutions: If True, return all sampled solutions + with shape (B, num_samples, dof) instead of the closest. + + Returns: + Tuple of (success [B], target_joints [B,1,dof] or [B,num_samples,dof]). 
+ """ + return_all_solutions = kwargs.get("return_all_solutions", False) + + n = num_samples if num_samples is not None else self._num_samples + + target_xpos = torch.as_tensor( + target_xpos, device=self.device, dtype=torch.float32 + ) + if target_xpos.dim() == 2: + target_xpos = target_xpos.unsqueeze(0) + B = target_xpos.shape[0] + + target_pos = target_xpos[:, :3, 3] + target_quat = convert_quat(quat_from_matrix(target_xpos[:, :3, :3]), to="xyzw") + + if qpos_seed is None: + qpos_seed = torch.zeros(B, self.dof, device=self.device) + else: + qpos_seed = torch.as_tensor( + qpos_seed, device=self.device, dtype=torch.float32 + ) + if qpos_seed.dim() == 1: + qpos_seed = qpos_seed.unsqueeze(0).expand(B, -1) + qpos_seed = qpos_seed.clone() + + # Single sample: run directly without QposSeedSampler overhead. + if n <= 1: + success, ik_qpos = self._run_policy( + qpos_seed, target_xpos, target_pos, target_quat + ) + return success, ik_qpos.unsqueeze(1) + + # Multiple samples: use QposSeedSampler for random seeds. + sampler = QposSeedSampler(num_samples=n, dof=self.dof, device=self.device) + all_seeds = sampler.sample( + qpos_seed, self.lower_qpos_limits, self.upper_qpos_limits, B + ) + target_xpos_repeated = sampler.repeat_target_xpos(target_xpos, n) + target_pos_rep = target_xpos_repeated[:, :3, 3] + target_quat_rep = convert_quat( + quat_from_matrix(target_xpos_repeated[:, :3, :3]), to="xyzw" + ) + + success_flat, ik_qpos_flat = self._run_policy( + all_seeds, target_xpos_repeated, target_pos_rep, target_quat_rep + ) + + all_success = success_flat.reshape(B, n) + all_results = ik_qpos_flat.reshape(B, n, self.dof) + + if return_all_solutions: + return all_success.any(dim=1), all_results + + # Pick solution closest to seed. 
+ seed_repeat = qpos_seed.unsqueeze(1).repeat(1, n, 1) + dist = (all_results - seed_repeat).norm(dim=-1) + dist[~all_success] = float("inf") + closest_idx = torch.argmin(dist, dim=1) + closest_qpos = all_results[torch.arange(B, device=self.device), closest_idx] + return all_success.any(dim=1), closest_qpos[:, None, :] diff --git a/embodichain/lab/sim/solvers/srs_solver.py b/embodichain/lab/sim/solvers/srs_solver.py index d68f470b..967e8ec7 100644 --- a/embodichain/lab/sim/solvers/srs_solver.py +++ b/embodichain/lab/sim/solvers/srs_solver.py @@ -1175,7 +1175,7 @@ def __init__(self, cfg: SRSSolverCfg, num_envs: int, device: str, **kwargs): # Compute root base transform fk_dict = self.pk_serial_chain.forward_kinematics( - th=np.zeros(7), end_only=False + th=torch.zeros(7, dtype=torch.float32, device=self.device), end_only=False ) root_tf = fk_dict[list(fk_dict.keys())[0]] self.root_base_xpos = root_tf.get_matrix().cpu().numpy() diff --git a/examples/sim/solvers/neural_ik_solver.py b/examples/sim/solvers/neural_ik_solver.py new file mode 100644 index 00000000..39c59dab --- /dev/null +++ b/examples/sim/solvers/neural_ik_solver.py @@ -0,0 +1,269 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ---------------------------------------------------------------------------- +import argparse +import math +import os +import time + +import numpy as np +import torch +from IPython import embed + +from embodichain.data import get_data_path +from embodichain.data.assets.solver_assets import download_neural_ik_checkpoint +from embodichain.lab.sim.cfg import MarkerCfg, RobotCfg +from embodichain.lab.sim.objects import Robot +from embodichain.lab.sim import SimulationManager, SimulationManagerCfg + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="NeuralIKSolver example") + parser.add_argument( + "--device", + type=str, + default="cpu", + choices=["cpu", "cuda"], + help="Compute device for tensors and the neural IK solver (default: cpu).", + ) + parser.add_argument( + "--num-envs", + type=int, + default=1, + help="Number of parallel environments to simulate. IK is solved for all " + "environments simultaneously at each step (default: 1).", + ) + return parser.parse_args() + + +def _resolve_device(device: str) -> str: + if device == "cuda" and not torch.cuda.is_available(): + raise RuntimeError( + "CUDA was requested but is not available. Use --device cpu or install " + "a CUDA-enabled PyTorch build." 
+ ) + return device + + +def _squeeze_ik_qpos(ik_qpos: torch.Tensor) -> torch.Tensor: + """Normalize IK output to (num_envs, dof).""" + if ik_qpos.dim() == 3: + return ik_qpos[:, 0, :] + return ik_qpos + + +def _pose_with_arena_offset( + pose: torch.Tensor | np.ndarray, arena_offset: torch.Tensor +) -> np.ndarray: + """Convert arena-local 4x4 pose to world frame by adding arena translation.""" + if isinstance(pose, torch.Tensor): + xpos = pose.detach().cpu().numpy() + else: + xpos = np.asarray(pose) + xpos = np.array(xpos, copy=True, dtype=np.float64) + offset = arena_offset.detach().cpu().numpy().reshape(3) + if xpos.ndim == 2: + xpos[:3, 3] += offset + elif xpos.ndim == 3: + xpos[:, :3, 3] += offset + return xpos + + +def main(): + args = parse_args() + np.set_printoptions(precision=5, suppress=True) + torch.set_printoptions(precision=5, sci_mode=False) + + sim_device = _resolve_device(args.device) + num_envs = args.num_envs + + config = SimulationManagerCfg( + headless=True, + sim_device=sim_device, + num_envs=num_envs, + arena_space=2.0, + ) + sim = SimulationManager(config) + + urdf = get_data_path("Franka/Panda/PandaWithHand.urdf") + assert os.path.isfile(urdf) + + checkpoint_path = download_neural_ik_checkpoint() + + c = math.cos(-math.pi / 4) + s = math.sin(-math.pi / 4) + tcp = [ + [c, -s, 0.0, 0.0], + [s, c, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.1034], + [0.0, 0.0, 0.0, 1.0], + ] + + cfg_dict = { + "fpath": urdf, + "control_parts": { + "main_arm": [ + "Joint1", + "Joint2", + "Joint3", + "Joint4", + "Joint5", + "Joint6", + "Joint7", + ], + }, + "solver_cfg": { + "main_arm": { + "class_type": "NeuralIKSolver", + "end_link_name": "ee_link", + "root_link_name": "base_link", + "tcp": tcp, + "checkpoint_path": checkpoint_path, + "num_arm_joints": 7, + "max_steps": 30, + "action_scale": 0.2, + "hidden_dims": [256, 256], + "pos_eps": 0.1, + "rot_eps": 0.5, + }, + }, + } + + robot: Robot = sim.add_robot(cfg=RobotCfg.from_dict(cfg_dict)) + + sim.open_window() + + arm_name 
= "main_arm" + device = robot.device + + seed_qpos = torch.tensor( + [0.0, -np.pi / 4, 0.0, -3 * np.pi / 4, 0.0, np.pi / 2, np.pi / 4], + dtype=torch.float32, + device=device, + ) + qpos = seed_qpos.unsqueeze(0).expand(num_envs, -1).clone() + robot.set_qpos(qpos=qpos, joint_ids=robot.get_joint_ids(arm_name)) + time.sleep(3.0) + + fk_xpos = robot.compute_fk(qpos=qpos, name=arm_name, to_matrix=True) + print(f"fk_xpos shape: {tuple(fk_xpos.shape)}") + + start_pose = fk_xpos.clone() + end_pose = fk_xpos.clone() + + # Per-environment target offsets (cycle if num_envs exceeds preset count) + move_vecs = torch.tensor( + [ + [0.3, 0.4, -0.2], + [0.2, 0.0, 0.0], + [0.0, 0.2, 0.0], + [0.0, -0.2, -0.1], + [-0.2, 0.0, 0.0], + [0.0, -0.2, 0.0], + [0.0, 0.0, -0.15], + [-0.2, 0.2, 0.0], + [0.0, 0.2, -0.15], + ], + dtype=torch.float32, + device=device, + ) + for env_id in range(num_envs): + end_pose[env_id, :3, 3] += move_vecs[env_id % move_vecs.shape[0]] + + num_steps = 50 + interpolated_poses = torch.stack( + [ + torch.lerp(start_pose, end_pose, t) + for t in torch.linspace(0.0, 1.0, num_steps, device=device) + ], + dim=1, + ) + + ik_qpos = qpos.clone() + ik_qpos_results: list[torch.Tensor] = [] + ik_success_flags: list[torch.Tensor] = [] + + print( + f"\nRunning {num_steps} batch IK steps: num_envs={num_envs}, device='{sim_device}' ..." 
+ ) + ik_compute_begin = time.time() + for step in range(num_steps): + poses = interpolated_poses[:, step, :, :] + res, ik_qpos_new = robot.compute_ik( + pose=poses, joint_seed=ik_qpos, name=arm_name + ) + ik_qpos = _squeeze_ik_qpos(ik_qpos_new) + ik_qpos_results.append(ik_qpos.clone()) + ik_success_flags.append(res) + ik_compute_end = time.time() + print( + f"IK compute time for {num_steps} steps and {num_envs} envs: " + f"{ik_compute_end - ik_compute_begin:.4f}s" + ) + + # Draw target and achieved EE axes for each environment (final step) + final_step = num_steps - 1 + final_ik_qpos = ik_qpos_results[final_step] + final_res = ik_success_flags[final_step] + ik_xpos_all = robot.compute_fk(qpos=final_ik_qpos, name=arm_name, to_matrix=True) + arena_offsets = sim.arena_offsets + + for env_id in range(num_envs): + target_axis = _pose_with_arena_offset(end_pose[env_id], arena_offsets[env_id]) + sim.draw_marker( + cfg=MarkerCfg( + name=f"fk_target_env{env_id}", + marker_type="axis", + axis_xpos=target_axis, + axis_size=0.002, + axis_len=0.005, + arena_index=-1, + ) + ) + + if final_res[env_id]: + ik_axis = _pose_with_arena_offset( + ik_xpos_all[env_id], arena_offsets[env_id] + ) + sim.draw_marker( + cfg=MarkerCfg( + name=f"ik_result_env{env_id}", + marker_type="axis", + axis_xpos=ik_axis, + axis_size=0.002, + axis_len=0.005, + arena_index=-1, + ) + ) + + # Animate: batch-apply IK qpos for successful envs, then step simulation + joint_ids = robot.get_joint_ids(arm_name) + for step in range(num_steps): + ik_qpos_step = ik_qpos_results[step] + res = ik_success_flags[step] + if res.any(): + success_ids = res.nonzero(as_tuple=True)[0] + robot.set_qpos( + qpos=ik_qpos_step[success_ids], + joint_ids=joint_ids, + env_ids=success_ids, + ) + sim.update(step=5) + + embed(header="NeuralIKSolver example. 
Press Ctrl+D to exit.") + + +if __name__ == "__main__": + main() diff --git a/tests/sim/solvers/test_neural_ik_solver.py b/tests/sim/solvers/test_neural_ik_solver.py new file mode 100644 index 00000000..67c8b37d --- /dev/null +++ b/tests/sim/solvers/test_neural_ik_solver.py @@ -0,0 +1,198 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- +from __future__ import annotations + +import math +import os + +import numpy as np +import pytest +import torch + +from embodichain.data import get_data_path +from embodichain.lab.sim import SimulationManager, SimulationManagerCfg +from embodichain.lab.sim.cfg import RobotCfg +from embodichain.lab.sim.objects import Robot +from embodichain.lab.sim.solvers.neural_ik_solver import _build_mlp +from embodichain.utils.utility import reset_all_seeds + +_c = math.cos(-math.pi / 4) +_s = math.sin(-math.pi / 4) +TCP = [ + [_c, -_s, 0.0, 0.0], + [_s, _c, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.1034], + [0.0, 0.0, 0.0, 1.0], +] + +NUM_ARM_JOINTS = 7 +OBS_DIM = 2 * NUM_ARM_JOINTS + 14 # 28 +HIDDEN_DIMS = [256, 256] + + +def _create_fake_checkpoint(tmp_path) -> str: + """Create a minimal fake checkpoint for testing the solver interface.""" + mlp = _build_mlp(OBS_DIM, HIDDEN_DIMS, NUM_ARM_JOINTS) + ckpt = { + "agent": {f"actor_mean.{k}": v for 
k, v in mlp.state_dict().items()}, + "obs_normalizer": { + "mean": torch.zeros(OBS_DIM), + "var": torch.ones(OBS_DIM), + }, + } + ckpt_path = str(tmp_path / "fake_neural_ik.pt") + torch.save(ckpt, ckpt_path) + return ckpt_path + + +class TestNeuralIKSolver: + sim: SimulationManager | None = None + robot: Robot | None = None + + def _setup(self, tmp_path): + checkpoint_path = _create_fake_checkpoint(tmp_path) + config = SimulationManagerCfg(headless=True, sim_device="cpu") + self.sim = SimulationManager(config) + + urdf = get_data_path("Franka/Panda/PandaWithHand.urdf") + assert os.path.isfile(urdf) + + cfg_dict = { + "fpath": urdf, + "control_parts": { + "main_arm": [ + "Joint1", + "Joint2", + "Joint3", + "Joint4", + "Joint5", + "Joint6", + "Joint7", + ], + }, + "solver_cfg": { + "main_arm": { + "class_type": "NeuralIKSolver", + "end_link_name": "ee_link", + "root_link_name": "base_link", + "tcp": TCP, + "checkpoint_path": checkpoint_path, + "num_arm_joints": NUM_ARM_JOINTS, + "max_steps": 30, + "action_scale": 0.2, + "hidden_dims": HIDDEN_DIMS, + "pos_eps": 0.1, + "rot_eps": 0.5, + }, + }, + } + + self.robot: Robot = self.sim.add_robot(cfg=RobotCfg.from_dict(cfg_dict)) + self.sim.update(step=100) + + def teardown_method(self): + if self.sim is not None: + self.sim.destroy() + + def _make_solver_input(self): + """Create a standard qpos and its FK target for solver tests.""" + arm_name = "main_arm" + qpos = torch.tensor( + [0.0, -np.pi / 4, 0.0, -3 * np.pi / 4, 0.0, np.pi / 2, np.pi / 4], + dtype=torch.float32, + device=self.robot.device, + ).unsqueeze(0) + target_xpos = self.robot.compute_fk(qpos=qpos, name=arm_name, to_matrix=True) + solver = self.robot.get_solver(arm_name) + return solver, qpos, target_xpos + + def test_ik_interface(self, tmp_path): + """Verify compute_ik returns correct shapes and types.""" + reset_all_seeds(0) + self._setup(tmp_path) + arm_name = "main_arm" + + qpos = torch.tensor( + [0.0, -np.pi / 4, 0.0, -3 * np.pi / 4, 0.0, np.pi / 2, np.pi 
/ 4], + dtype=torch.float32, + device=self.robot.device, + ).unsqueeze(0) + target_xpos = self.robot.compute_fk(qpos=qpos, name=arm_name, to_matrix=True) + + res, ik_qpos = self.robot.compute_ik( + pose=target_xpos, joint_seed=qpos, name=arm_name + ) + + assert res.shape == (1,) + assert res.dtype == torch.bool + dof = qpos.shape[-1] + assert ik_qpos.shape[-1] == dof + + # test for unreachable pose + invalid_pose = torch.tensor( + [ + [ + [1.0, 0.0, 0.0, 10.0], + [0.0, 1.0, 0.0, 10.0], + [0.0, 0.0, 1.0, 10.0], + [0.0, 0.0, 0.0, 1.0], + ] + ], + dtype=torch.float32, + device=self.robot.device, + ) + res, ik_qpos = self.robot.compute_ik( + pose=invalid_pose, joint_seed=qpos, name=arm_name + ) + assert res[0].item() is False + + def test_multi_sample_shape(self, tmp_path): + """Verify output shape when using multiple samples.""" + reset_all_seeds(0) + self._setup(tmp_path) + solver, qpos, target_xpos = self._make_solver_input() + + success, ik_qpos = solver.get_ik( + target_xpos=target_xpos, + qpos_seed=qpos, + num_samples=5, + ) + + dof = qpos.shape[-1] + assert success.shape == (1,) + assert ik_qpos.shape == (1, 1, dof) + + def test_multi_sample_return_all(self, tmp_path): + """Verify return_all_solutions returns all sampled solutions.""" + reset_all_seeds(0) + self._setup(tmp_path) + solver, qpos, target_xpos = self._make_solver_input() + num_samples = 5 + + success, ik_qpos = solver.get_ik( + target_xpos=target_xpos, + qpos_seed=qpos, + num_samples=num_samples, + return_all_solutions=True, + ) + + dof = qpos.shape[-1] + assert success.shape == (1,) + assert ik_qpos.shape == (1, num_samples, dof) + + +if __name__ == "__main__": + np.set_printoptions(precision=5, suppress=True) diff --git a/tests/sim/solvers/test_srs_solver.py b/tests/sim/solvers/test_srs_solver.py index cfd970e0..ada04e84 100644 --- a/tests/sim/solvers/test_srs_solver.py +++ b/tests/sim/solvers/test_srs_solver.py @@ -16,6 +16,8 @@ import os import torch + +torch._dynamo.config.cache_size_limit = 
128 # recompile_limit import pytest import numpy as np