Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion docs/source/overview/sim/solvers/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,9 @@ Choosing a solver
- Use analytic solvers (OPW for 6-DOF arms or SRS for 7-DOF arms) when available for speed and
determinism.
- Use numerical solvers (PyTorch/optimization, Differential) when you need
flexibility..
flexibility.
- Use the neural IK solver (experimental) when you have a trained checkpoint and need
fast batch inference on a supported robot.

See also
--------
Expand All @@ -94,3 +96,4 @@ See also
pinocchio_solver.md
opw_solver.md
srs_solver.md
neural_ik_solver.md
71 changes: 71 additions & 0 deletions docs/source/overview/sim/solvers/neural_ik_solver.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# NeuralIKSolver

````{admonition} Experimental
:class: warning

`NeuralIKSolver` is an **experimental** feature. The API, checkpoint format,
and default parameters may change without a deprecation cycle. It is currently
only validated on the **Franka Panda** robot.
````

`NeuralIKSolver` is a learning-based inverse kinematics (IK) solver that uses a
trained neural network policy to iteratively solve IK queries. It requires a
pre-trained checkpoint and supports batch processing.

## Key Features

* Iterative neural policy inference for IK solving
* Batch processing for multiple target poses simultaneously
* Multi-seed sampling: generate several random initial guesses and return the best solution
* Joint limit enforcement at every iteration
* PyTorch-based — supports both CPU and CUDA devices

## Configuration

The solver is configured using the `NeuralIKSolverCfg` class. Pre-trained
checkpoints are hosted on HuggingFace and can be downloaded with
`download_neural_ik_checkpoint()` (requires `HF_TOKEN` environment variable).

```python
from embodichain.data.assets.solver_assets import download_neural_ik_checkpoint
from embodichain.lab.sim.solvers.neural_ik_solver import NeuralIKSolverCfg

checkpoint_path = download_neural_ik_checkpoint()

cfg = NeuralIKSolverCfg(
checkpoint_path=checkpoint_path,
num_arm_joints=7,
max_steps=30,
action_scale=0.2,
hidden_dims=[256, 256],
pos_eps=0.01,
rot_eps=0.1,
num_samples=1,
)
```

## Main Methods

* `get_ik(self, target_xpos, qpos_seed=None, num_samples=None, **kwargs)`
Solve IK for the given target end-effector pose(s).

**Parameters:**
+ `target_xpos` (`torch.Tensor`): Target pose(s) as 4x4 matrix, shape `(4, 4)` or `(B, 4, 4)`.
+ `qpos_seed` (`torch.Tensor`, optional): Initial joint positions, shape `(dof,)` or `(B, dof)`.
+ `num_samples` (`int`, optional): Override `cfg.num_samples` for this call.
+ `return_all_solutions` (`bool`): If `True`, return all sampled solutions with shape `(B, num_samples, dof)`.

**Returns:**
+ `Tuple[torch.Tensor, torch.Tensor]`:
- Success flags, shape `(B,)`.
- Joint positions, shape `(B, 1, dof)` or `(B, num_samples, dof)`.

**Example:**

```python
import torch
success, ik_qpos = solver.get_ik(target_xpos=target_pose, qpos_seed=qpos_seed)
print("Success:", success)
print("IK solution:", ik_qpos)
```

89 changes: 89 additions & 0 deletions embodichain/data/assets/solver_assets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
# ----------------------------------------------------------------------------
# Copyright (c) 2021-2026 DexForce Technology Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ----------------------------------------------------------------------------
from __future__ import annotations

import os

from huggingface_hub import hf_hub_download

# HuggingFace endpoint. Mirrors (e.g. hf-mirror.com) often redirect to the
# real hub without forwarding the required commit-hash response headers, so we
# default to the canonical endpoint and rely on the system proxy when needed.
_HF_ENDPOINT = "https://huggingface.co"


def download_neural_ik_checkpoint(
repo_id: str = "dexforce/neural_ik_solver",
filename: str = "franka.pt",
token: str | None = None,
endpoint: str = _HF_ENDPOINT,
) -> str:
"""Download a neural IK solver checkpoint from HuggingFace.

The repository is gated. Either set the ``HF_TOKEN`` environment variable or
run ``huggingface-cli login`` before calling this function.

If your network requires an HTTP proxy, set ``HTTPS_PROXY`` or
``https_proxy`` in the environment before launching Python.

Args:
repo_id: HuggingFace repository ID.
filename: Checkpoint filename to download.
token: HuggingFace API token. Falls back to the ``HF_TOKEN``
environment variable or the cached token from
``huggingface-cli login``.
endpoint: HuggingFace-compatible endpoint URL. Defaults to
``https://huggingface.co``. Mirrors that merely redirect to the
real hub are not supported.

Returns:
str: Local path to the downloaded checkpoint file.

Raises:
RuntimeError: If the download fails, with authentication instructions.
"""
# Normalize proxy env vars: the ``requests`` library on Linux requires the
# lowercase form (``https_proxy``), but users typically export the uppercase
# form (``HTTPS_PROXY``).
https_proxy = os.environ.get("HTTPS_PROXY") or os.environ.get("https_proxy")
if https_proxy:
os.environ.setdefault("https_proxy", https_proxy)
os.environ.setdefault("HTTPS_PROXY", https_proxy)

# Allow callers to pass the token explicitly; otherwise fall back to
# HF_TOKEN (huggingface_hub also reads this automatically, but being
# explicit makes the fallback order transparent).
if token is None:
token = os.environ.get("HF_TOKEN") or None

try:
return hf_hub_download(
repo_id=repo_id,
filename=filename,
token=token,
endpoint=endpoint,
)
except Exception as exc:
raise RuntimeError(
f"Failed to download '{filename}' from '{repo_id}'.\n"
"The repository is gated and requires an authenticated HuggingFace account.\n"
"To fix this:\n"
" 1. Accept the model license at https://huggingface.co/dexforce/neural_ik_solver\n"
" 2. Create an access token at https://huggingface.co/settings/tokens\n"
" 3. Export the token: export HF_TOKEN=<your_token>\n"
" or run: huggingface-cli login\n"
f"Original error: {exc}"
) from exc
8 changes: 8 additions & 0 deletions embodichain/lab/sim/objects/articulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1105,6 +1105,10 @@ def set_qpos(
else:
qpos_set = self.body_data._qpos

if not isinstance(local_env_ids, torch.Tensor):
local_env_ids = torch.as_tensor(
local_env_ids, dtype=torch.long, device=self.device
)
indices = self.body_data.gpu_indices[local_env_ids]
qpos_set[local_env_ids[:, None], local_joint_ids] = qpos
self._ps.gpu_apply_joint_data(
Expand Down Expand Up @@ -1181,6 +1185,10 @@ def set_qvel(
else:
qvel_set = self.body_data._qvel

if not isinstance(local_env_ids, torch.Tensor):
local_env_ids = torch.as_tensor(
local_env_ids, dtype=torch.long, device=self.device
)
indices = self.body_data.gpu_indices[local_env_ids]
qvel_set[local_env_ids[:, None], local_joint_ids] = qvel
self._ps.gpu_apply_joint_data(
Expand Down
14 changes: 7 additions & 7 deletions embodichain/lab/sim/sim_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ def __init__(

if sim_config.headless is False:
self._window = self._world.get_windows()
# self._register_default_window_control()
self._register_default_window_control()

@classmethod
def get_instance(cls, instance_id: int = 0) -> SimulationManager:
Expand Down Expand Up @@ -550,12 +550,12 @@ def open_window(self) -> None:
self._window = self._world.get_windows()

# TODO: will open these features after fix the related blocking issues.
# self._register_default_window_control()
# if (
# self._window_record_hotkey_cfg is not None
# and self._window_record_input_control is None
# ):
# self.enable_window_record_hotkey(**self._window_record_hotkey_cfg)
self._register_default_window_control()
if (
self._window_record_hotkey_cfg is not None
and self._window_record_input_control is None
):
self.enable_window_record_hotkey(**self._window_record_hotkey_cfg)
self.is_window_opened = True

def close_window(self) -> None:
Expand Down
1 change: 1 addition & 0 deletions embodichain/lab/sim/solvers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,4 @@
from .pink_solver import PinkSolverCfg, PinkSolver
from .opw_solver import OPWSolverCfg, OPWSolver
from .srs_solver import SRSSolverCfg, SRSSolver
from .neural_ik_solver import NeuralIKSolverCfg, NeuralIKSolver
8 changes: 8 additions & 0 deletions embodichain/lab/sim/solvers/base_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,14 @@ def __init__(self, cfg: SolverCfg = None, device: str = None, **kwargs):
fullgraph=True,
dynamic=True,
)
# Warm up on the solver device so Dynamo guards match CUDA/CPU at init
# instead of on the first get_fk call (avoids recompile_limit hits in CI).
if self.dof > 0:
with torch.no_grad():
warmup_qpos = torch.zeros(
1, self.dof, device=self.device, dtype=torch.float32
)
self.compiled_fk(warmup_qpos)

self._init_qpos_limits()

Expand Down
Loading
Loading