From 6480253652c41a0f38fb331996e7c8dc33b1f744 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tobias=20J=C3=BClg?= Date: Mon, 20 Apr 2026 17:11:34 +0200 Subject: [PATCH 1/4] feat(sim): record and replay MuJoCo state --- pyproject.toml | 1 + python/rcs/__main__.py | 2 + python/rcs/envs/sim.py | 19 ++ python/rcs/sim/sim.py | 29 +++ python/rcs/sim_state_replay.py | 237 +++++++++++++++++++ python/tests/test_sim_state_record_replay.py | 176 ++++++++++++++ scripts/replay_sim_trajectory.py | 9 + 7 files changed, 473 insertions(+) create mode 100644 python/rcs/sim_state_replay.py create mode 100644 python/tests/test_sim_state_record_replay.py create mode 100644 scripts/replay_sim_trajectory.py diff --git a/pyproject.toml b/pyproject.toml index b4f00608..72da8c65 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,6 +40,7 @@ dependencies = [ "mujoco==3.2.6", "pin==3.7.0", "greenlet", + "duckdb", ] readme = "README.md" maintainers = [{ name = "Tobias Jülg", email = "tobias.juelg@utn.de" }] diff --git a/python/rcs/__main__.py b/python/rcs/__main__.py index 1cc552e4..a481aab1 100644 --- a/python/rcs/__main__.py +++ b/python/rcs/__main__.py @@ -3,8 +3,10 @@ import typer from rcs.envs.storage_wrapper import StorageWrapper +from rcs.sim_state_replay import replay as replay_command app = typer.Typer() +app.command()(replay_command) @app.command() diff --git a/python/rcs/envs/sim.py b/python/rcs/envs/sim.py index 8b3ac8cf..53fb60c3 100644 --- a/python/rcs/envs/sim.py +++ b/python/rcs/envs/sim.py @@ -43,6 +43,25 @@ def reset( return super().reset(seed=seed, options=options) +class SimStateObservationWrapper(ActObsInfoWrapper): + STATE_KEY = "sim_state" + STATE_SPEC_KEY = "sim_state_spec" + STATE_SIZE_KEY = "sim_state_size" + + def __init__(self, env): + super().__init__(env) + assert self.env.get_wrapper_attr("PLATFORM") == RobotPlatform.SIMULATION, "Base environment must be simulation." + self.sim = cast(sim.Sim, self.get_wrapper_attr("sim")) + + def observation(self, observation: dict[str, Any], info: dict[str, Any]) -> tuple[dict[str, Any], dict[str, Any]]: + observation = dict(observation) + sim_state = self.sim.get_state() + observation[self.STATE_KEY] = sim_state + observation[self.STATE_SPEC_KEY] = self.sim.get_state_spec() + observation[self.STATE_SIZE_KEY] = sim_state.shape[0] + return observation, info + + class GripperWrapperSim(ActObsInfoWrapper): def __init__(self, env): super().__init__(env) diff --git a/python/rcs/sim/sim.py b/python/rcs/sim/sim.py index 542e7370..0d46773c 100644 --- a/python/rcs/sim/sim.py +++ b/python/rcs/sim/sim.py @@ -10,6 +10,7 @@ import mujoco as mj import mujoco.viewer +import numpy as np from rcs._core.sim import GuiClient as _GuiClient from rcs._core.sim import Sim as _Sim from rcs.sim import SimConfig, egl_bootstrap @@ -42,6 +43,8 @@ def gui_loop(gui_uuid: str, close_event): class Sim(_Sim): + STATE_SPEC = mj.mjtState.mjSTATE_INTEGRATION + def __init__(self, mjmdl: str | PathLike, cfg: SimConfig | None = None): mjmdl = Path(mjmdl) if mjmdl.suffix == ".xml": @@ -61,6 +64,32 @@ def __init__(self, mjmdl: str | PathLike, cfg: SimConfig | None = None): if cfg is not None: self.set_config(cfg) + def get_state_spec(self) -> int: + return int(self.STATE_SPEC) + + def get_state_size(self, spec: int | None = None) -> int: + state_spec = self.STATE_SPEC if spec is None else mj.mjtState(spec) + return mj.mj_stateSize(self.model, state_spec) + + def get_state(self, spec: int | None = None) -> np.ndarray: + state_spec = self.STATE_SPEC if spec is None else mj.mjtState(spec) + state = np.empty(self.get_state_size(int(state_spec)), dtype=np.float64) + mj.mj_getState(self.model, self.data, state, state_spec) + return state + + def set_state(self, state: np.ndarray, spec: int | None = None): + state_spec = self.STATE_SPEC if spec is None else mj.mjtState(spec) + state_array = np.asarray(state, dtype=np.float64) + expected_size = self.get_state_size(int(state_spec)) + if state_array.shape != (expected_size,): + msg = ( + f"Expected MuJoCo state with shape ({expected_size},), " + f"got {state_array.shape} for spec {int(state_spec)}." + ) + raise ValueError(msg) + mj.mj_setState(self.model, self.data, state_array, state_spec) + mj.mj_forward(self.model, self.data) + def close_gui(self): if self._stop_event is not None: self._stop_event.set() diff --git a/python/rcs/sim_state_replay.py b/python/rcs/sim_state_replay.py new file mode 100644 index 00000000..f07eef91 --- /dev/null +++ b/python/rcs/sim_state_replay.py @@ -0,0 +1,237 @@ +from __future__ import annotations + +import time +from dataclasses import dataclass +from pathlib import Path +from typing import Annotated, Any + +import gymnasium as gym +import numpy as np +import pyarrow.compute as pc +import pyarrow.dataset as ds +import typer +from PIL import Image +from rcs.envs.base import ControlMode +from rcs.envs.sim import SimStateObservationWrapper + +import rcs # noqa: F401 + +app = typer.Typer(help="Replay recorded MuJoCo trajectories from a parquet dataset.") + +DATASET_ARGUMENT = typer.Argument(..., exists=True, file_okay=False, dir_okay=True) +ENV_ID_OPTION = typer.Option("rcs/FR3SimplePickUpSim-v0", help="Gymnasium env id used for replay.") +TRAJECTORY_UUID_OPTION = typer.Option(None, help="UUID of the recorded trajectory to replay.") +CAMERA_OPTION = typer.Option([], "--camera", help="Camera names to enable on the replay env.") +RESOLUTION_OPTION = typer.Option((256, 256), help="Replay camera resolution as WIDTH HEIGHT.") +FRAME_RATE_OPTION = typer.Option(0, help="Replay camera frame rate.") +RENDER_MODE_OPTION = typer.Option("human", help="Gym render mode for the replay env.") +CONTROL_MODE_OPTION = typer.Option(ControlMode.CARTESIAN_TRPY.name, help="Control mode name for env creation.") +SLEEP_OPTION = typer.Option(0.0, help="Optional delay between restored states.") +OUTPUT_DIR_OPTION = typer.Option(None, help="Optional directory for re-rendered RGB frames.") +PREFER_DUCKDB_OPTION = typer.Option(True, help="Use duckdb for parquet loading when it is available.") + + +@dataclass(frozen=True) +class RecordedSimStep: + step: int + uuid: str + timestamp: float | None + observation: dict[str, Any] + + @property + def sim_state(self) -> np.ndarray: + return np.asarray(self.observation[SimStateObservationWrapper.STATE_KEY], dtype=np.float64) + + @property + def sim_state_spec(self) -> int: + return int(self.observation.get(SimStateObservationWrapper.STATE_SPEC_KEY, 0)) + + +class DuckDBUnavailableError(RuntimeError): + pass + + +def _get_duckdb_module(): + try: + import duckdb + except ModuleNotFoundError as exc: + msg = ( + "duckdb is required for the preferred parquet read path but is not installed. " + "Install the 'duckdb' Python package or rely on the pyarrow fallback in library calls." + ) + raise DuckDBUnavailableError(msg) from exc + return duckdb + + +def _load_distinct_uuids_with_duckdb(dataset_path: Path) -> list[str]: + duckdb = _get_duckdb_module() + connection = duckdb.connect() + try: + rows = connection.execute( + "SELECT DISTINCT uuid FROM read_parquet(?) ORDER BY uuid", + [str(dataset_path)], + ).fetchall() + finally: + connection.close() + return [row[0] for row in rows] + + +def _load_distinct_uuids_with_pyarrow(dataset_path: Path) -> list[str]: + dataset = ds.dataset(str(dataset_path), format="parquet") + uuids = dataset.to_table(columns=["uuid"])["uuid"] + return sorted(str(uuid) for uuid in pc.unique(uuids).to_pylist()) + + +def list_trajectory_ids(dataset_path: Path, prefer_duckdb: bool = True) -> list[str]: + if prefer_duckdb: + try: + return _load_distinct_uuids_with_duckdb(dataset_path) + except DuckDBUnavailableError: + pass + return _load_distinct_uuids_with_pyarrow(dataset_path) + + +def _load_trajectory_with_duckdb(dataset_path: Path, trajectory_uuid: str) -> list[RecordedSimStep]: + duckdb = _get_duckdb_module() + connection = duckdb.connect() + try: + table = connection.execute( + "SELECT uuid, step, timestamp, obs FROM read_parquet(?) WHERE uuid = ? ORDER BY step", + [str(dataset_path), trajectory_uuid], + ).to_arrow_table() + finally: + connection.close() + return [ + RecordedSimStep( + step=int(row["step"]), + uuid=str(row["uuid"]), + timestamp=float(row["timestamp"]) if row["timestamp"] is not None else None, + observation=row["obs"], + ) + for row in table.to_pylist() + ] + + +def _load_trajectory_with_pyarrow(dataset_path: Path, trajectory_uuid: str) -> list[RecordedSimStep]: + dataset = ds.dataset(str(dataset_path), format="parquet") + table = dataset.to_table(filter=pc.field("uuid") == trajectory_uuid, columns=["uuid", "step", "timestamp", "obs"]) + rows = table.sort_by([("step", "ascending")]).to_pylist() + return [ + RecordedSimStep( + step=int(row["step"]), + uuid=str(row["uuid"]), + timestamp=float(row["timestamp"]) if row["timestamp"] is not None else None, + observation=row["obs"], + ) + for row in rows + ] + + +def load_trajectory(dataset_path: Path, trajectory_uuid: str, prefer_duckdb: bool = True) -> list[RecordedSimStep]: + if prefer_duckdb: + try: + return _load_trajectory_with_duckdb(dataset_path, trajectory_uuid) + except DuckDBUnavailableError: + pass + return _load_trajectory_with_pyarrow(dataset_path, trajectory_uuid) + + +def resolve_trajectory_uuid(dataset_path: Path, trajectory_uuid: str | None, prefer_duckdb: bool = True) -> str: + if trajectory_uuid is not None: + return trajectory_uuid + available_uuids = list_trajectory_ids(dataset_path, prefer_duckdb=prefer_duckdb) + if len(available_uuids) == 1: + return available_uuids[0] + msg = ( + f"Dataset {dataset_path} contains {len(available_uuids)} trajectories. " + f"Pass --trajectory-uuid and choose one of: {available_uuids}" + ) + raise ValueError(msg) + + +def restore_sim_step(env: gym.Env, recorded_step: RecordedSimStep): + sim = env.get_wrapper_attr("sim") + sim.set_state(recorded_step.sim_state, spec=recorded_step.sim_state_spec) + + +def collect_rgb_frames(env: gym.Env) -> dict[str, np.ndarray]: + try: + camera_set = env.get_wrapper_attr("camera_set") + except AttributeError: + return {} + + frameset = camera_set.get_latest_frames() + if frameset is None: + return {} + + rgb_frames: dict[str, np.ndarray] = {} + for camera_name, frame in frameset.frames.items(): + lower_name = camera_name.lower() + if "digit" in lower_name or "tactile" in lower_name: + continue + rgb_frames[camera_name] = np.asarray(frame.camera.color.data) + return rgb_frames + + +def save_rgb_frames(output_dir: Path, recorded_step: RecordedSimStep, rgb_frames: dict[str, np.ndarray]): + output_dir.mkdir(parents=True, exist_ok=True) + for camera_name, rgb_frame in rgb_frames.items(): + Image.fromarray(rgb_frame).save(output_dir / f"step-{recorded_step.step:06d}-{camera_name}.png") + + +def replay_trajectory( + env: gym.Env, + recorded_steps: list[RecordedSimStep], + *, + sleep_s: float = 0.0, + output_dir: Path | None = None, +): + if not recorded_steps: + msg = "No recorded sim states found in the requested trajectory." + raise ValueError(msg) + + env.reset() + for recorded_step in recorded_steps: + restore_sim_step(env, recorded_step) + if output_dir is not None: + save_rgb_frames(output_dir, recorded_step, collect_rgb_frames(env)) + if sleep_s > 0: + time.sleep(sleep_s) + + +@app.command() +def replay( + dataset: Annotated[Path, DATASET_ARGUMENT], + env_id: Annotated[str, ENV_ID_OPTION], + trajectory_uuid: Annotated[str | None, TRAJECTORY_UUID_OPTION], + camera: Annotated[list[str], CAMERA_OPTION], + resolution: Annotated[tuple[int, int], RESOLUTION_OPTION], + frame_rate: Annotated[int, FRAME_RATE_OPTION], + render_mode: Annotated[str, RENDER_MODE_OPTION], + control_mode: Annotated[str, CONTROL_MODE_OPTION], + sleep_s: Annotated[float, SLEEP_OPTION], + output_dir: Annotated[Path | None, OUTPUT_DIR_OPTION], + prefer_duckdb: Annotated[bool, PREFER_DUCKDB_OPTION], +): + resolved_uuid = resolve_trajectory_uuid(dataset, trajectory_uuid, prefer_duckdb=prefer_duckdb) + env = gym.make( + env_id, + render_mode=render_mode, + control_mode=ControlMode[control_mode], + resolution=resolution, + frame_rate=frame_rate, + cam_list=camera, + ) + try: + recorded_steps = load_trajectory(dataset, resolved_uuid, prefer_duckdb=prefer_duckdb) + replay_trajectory(env, recorded_steps, sleep_s=sleep_s, output_dir=output_dir) + finally: + env.close() + + typer.echo(f"Replayed {len(recorded_steps)} steps from trajectory {resolved_uuid}.") + if output_dir is not None: + typer.echo(f"Saved re-rendered RGB frames to {output_dir}.") + + +if __name__ == "__main__": + app() diff --git a/python/tests/test_sim_state_record_replay.py b/python/tests/test_sim_state_record_replay.py new file mode 100644 index 00000000..1e974812 --- /dev/null +++ b/python/tests/test_sim_state_record_replay.py @@ -0,0 +1,176 @@ +from __future__ import annotations + +import importlib.util +import sys +from dataclasses import dataclass +from pathlib import Path + +import gymnasium as gym +import mujoco as mj +import numpy as np +import pyarrow.dataset as ds +from rcs._core.common import RobotPlatform +from rcs.camera.interface import CameraFrame, DataFrame, Frame, FrameSet +from rcs.envs.storage_wrapper import StorageWrapper + +import rcs + +REPO_ROOT = Path(__file__).resolve().parents[2] + + +def _load_local_module(module_name: str, relative_path: str): + module_path = REPO_ROOT / relative_path + spec = importlib.util.spec_from_file_location(module_name, module_path) + if spec is None or spec.loader is None: + msg = f"Could not create an import spec for {module_name} from {module_path}." + raise ImportError(msg) + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + parent_name, _, child_name = module_name.rpartition(".") + if parent_name: + parent_module = sys.modules[parent_name] + setattr(parent_module, child_name, module) + spec.loader.exec_module(module) + return module + + +local_sim_module = _load_local_module("rcs.sim.sim", "python/rcs/sim/sim.py") +rcs.sim.__dict__["Sim"] = local_sim_module.Sim +_load_local_module("rcs.envs.sim", "python/rcs/envs/sim.py") +_load_local_module("rcs.sim_state_replay", "python/rcs/sim_state_replay.py") + +from rcs.envs.sim import SimStateObservationWrapper # noqa: E402 +from rcs.sim.sim import Sim # noqa: E402 +from rcs.sim_state_replay import ( # noqa: E402 + load_trajectory, + replay_trajectory, + restore_sim_step, +) + +XML = """ + + + + + + + + + +""" + + +@dataclass +class DummyCameraSet: + sim: Sim + + def get_latest_frames(self) -> FrameSet: + color_value = int(np.clip(round((self.sim.data.qpos[0] + 1.0) * 80.0), 0, 255)) + rgb = np.full((8, 8, 3), color_value, dtype=np.uint8) + return FrameSet( + frames={ + "main": Frame( + camera=CameraFrame( + color=DataFrame(data=rgb), + depth=None, + ), + ) + }, + avg_timestamp=None, + ) + + +class DummySimEnv(gym.Env): + PLATFORM = RobotPlatform.SIMULATION + + def __init__(self, sim: Sim, camera_set: DummyCameraSet | None = None): + super().__init__() + self.sim = sim + self.camera_set = camera_set + self.action_space = gym.spaces.Dict( + { + "delta": gym.spaces.Box(low=-1.0, high=1.0, shape=(1,), dtype=np.float64), + } + ) + self.observation_space = gym.spaces.Dict( + { + "qpos": gym.spaces.Box(low=-np.inf, high=np.inf, shape=(self.sim.model.nq,), dtype=np.float64), + "qvel": gym.spaces.Box(low=-np.inf, high=np.inf, shape=(self.sim.model.nv,), dtype=np.float64), + } + ) + + def _obs(self) -> dict[str, np.ndarray]: + return { + "qpos": self.sim.data.qpos.copy(), + "qvel": self.sim.data.qvel.copy(), + } + + def get_wrapper_attr(self, name: str): + return getattr(self, name) + + def reset(self, *, seed: int | None = None, options: dict | None = None): + super().reset(seed=seed) + mj.mj_resetData(self.sim.model, self.sim.data) + mj.mj_forward(self.sim.model, self.sim.data) + return self._obs(), {} + + def step(self, action: dict[str, np.ndarray]): + self.sim.data.qpos[0] += float(action["delta"][0]) + self.sim.data.qvel[:] = 0.0 + mj.mj_forward(self.sim.model, self.sim.data) + return self._obs(), 0.0, False, False, {} + + def close(self): + return None + + +def test_record_and_replay_sim_state(tmp_path: Path): + model_path = tmp_path / "dummy.xml" + model_path.write_text(XML) + + dataset_path = tmp_path / "dataset" + record_env: gym.Env = DummySimEnv(Sim(model_path)) + record_env = SimStateObservationWrapper(record_env) + record_env = StorageWrapper(record_env, str(dataset_path), "test sim replay", batch_size=1, always_record=True) + + obs, _ = record_env.reset() + assert SimStateObservationWrapper.STATE_KEY in obs + + record_env.step({"delta": np.array([0.125], dtype=np.float64)}) + record_env.close() + + table = ds.dataset(str(dataset_path), format="parquet").to_table().sort_by([("step", "ascending")]) + rows = table.to_pylist() + assert len(rows) == 1 + + recorded_obs = rows[0]["obs"] + assert SimStateObservationWrapper.STATE_KEY in recorded_obs + assert SimStateObservationWrapper.STATE_SPEC_KEY in recorded_obs + assert SimStateObservationWrapper.STATE_SIZE_KEY in recorded_obs + assert ( + len(recorded_obs[SimStateObservationWrapper.STATE_KEY]) + == recorded_obs[SimStateObservationWrapper.STATE_SIZE_KEY] + ) + + recorded_steps = load_trajectory(dataset_path, rows[0]["uuid"], prefer_duckdb=True) + assert len(recorded_steps) == 1 + assert np.allclose(recorded_steps[0].sim_state, np.asarray(recorded_obs[SimStateObservationWrapper.STATE_KEY])) + + replay_sim = Sim(model_path) + replay_env: gym.Env = DummySimEnv(replay_sim, camera_set=DummyCameraSet(replay_sim)) + replay_env = SimStateObservationWrapper(replay_env) + render_dir = tmp_path / "rendered" + + replay_env.reset() + restore_sim_step(replay_env, recorded_steps[0]) + assert np.allclose( + replay_env.get_wrapper_attr("sim").data.qpos, np.asarray(recorded_obs["qpos"]), atol=1e-9, rtol=0 + ) + assert np.allclose( + replay_env.get_wrapper_attr("sim").data.qvel, np.asarray(recorded_obs["qvel"]), atol=1e-9, rtol=0 + ) + + replay_trajectory(replay_env, recorded_steps, output_dir=render_dir) + + rendered_files = sorted(path.name for path in render_dir.glob("*.png")) + assert rendered_files == ["step-000000-main.png"] diff --git a/scripts/replay_sim_trajectory.py b/scripts/replay_sim_trajectory.py new file mode 100644 index 00000000..04933693 --- /dev/null +++ b/scripts/replay_sim_trajectory.py @@ -0,0 +1,9 @@ +from pathlib import Path +import sys + +sys.path.insert(0, str(Path(__file__).resolve().parents[1] / "python")) + +from rcs.sim_state_replay import app + +if __name__ == "__main__": + app() From 9a08a3c99e377332efbfe72ff9bc322c5e311c56 Mon Sep 17 00:00:00 2001 From: Pierre Krack Date: Wed, 22 Apr 2026 13:12:48 +0200 Subject: [PATCH 2/4] refactor: review of sim state recording and replay --- python/rcs/envs/sim.py | 2 - python/rcs/sim_state_replay.py | 46 +++++++++++--------- python/tests/test_sim_state_record_replay.py | 44 ++----------------- scripts/replay_sim_trajectory.py | 9 ---- 4 files changed, 28 insertions(+), 73 deletions(-) delete mode 100644 scripts/replay_sim_trajectory.py diff --git a/python/rcs/envs/sim.py b/python/rcs/envs/sim.py index 53fb60c3..9882d22f 100644 --- a/python/rcs/envs/sim.py +++ b/python/rcs/envs/sim.py @@ -46,7 +46,6 @@ def reset( class SimStateObservationWrapper(ActObsInfoWrapper): STATE_KEY = "sim_state" STATE_SPEC_KEY = "sim_state_spec" - STATE_SIZE_KEY = "sim_state_size" def __init__(self, env): super().__init__(env) @@ -58,7 +57,6 @@ def observation(self, observation: dict[str, Any], info: dict[str, Any]) -> tupl sim_state = self.sim.get_state() observation[self.STATE_KEY] = sim_state observation[self.STATE_SPEC_KEY] = self.sim.get_state_spec() - observation[self.STATE_SIZE_KEY] = sim_state.shape[0] return observation, info diff --git a/python/rcs/sim_state_replay.py b/python/rcs/sim_state_replay.py index f07eef91..d760a1a4 100644 --- a/python/rcs/sim_state_replay.py +++ b/python/rcs/sim_state_replay.py @@ -19,16 +19,17 @@ app = typer.Typer(help="Replay recorded MuJoCo trajectories from a parquet dataset.") DATASET_ARGUMENT = typer.Argument(..., exists=True, file_okay=False, dir_okay=True) -ENV_ID_OPTION = typer.Option("rcs/FR3SimplePickUpSim-v0", help="Gymnasium env id used for replay.") -TRAJECTORY_UUID_OPTION = typer.Option(None, help="UUID of the recorded trajectory to replay.") -CAMERA_OPTION = typer.Option([], "--camera", help="Camera names to enable on the replay env.") -RESOLUTION_OPTION = typer.Option((256, 256), help="Replay camera resolution as WIDTH HEIGHT.") -FRAME_RATE_OPTION = typer.Option(0, help="Replay camera frame rate.") -RENDER_MODE_OPTION = typer.Option("human", help="Gym render mode for the replay env.") -CONTROL_MODE_OPTION = typer.Option(ControlMode.CARTESIAN_TRPY.name, help="Control mode name for env creation.") -SLEEP_OPTION = typer.Option(0.0, help="Optional delay between restored states.") -OUTPUT_DIR_OPTION = typer.Option(None, help="Optional directory for re-rendered RGB frames.") -PREFER_DUCKDB_OPTION = typer.Option(True, help="Use duckdb for parquet loading when it is available.") + +ENV_ID_OPTION = typer.Option(help="Gymnasium env id used for replay.") +TRAJECTORY_UUID_OPTION = typer.Option(help="UUID of the recorded trajectory to replay.") +CAMERA_OPTION = typer.Option("--camera", help="Camera names to enable on the replay env.") +RESOLUTION_OPTION = typer.Option(help="Replay camera resolution as WIDTH HEIGHT.") +FRAME_RATE_OPTION = typer.Option(help="Replay camera frame rate.") +RENDER_MODE_OPTION = typer.Option(help="Gym render mode for the replay env.") +CONTROL_MODE_OPTION = typer.Option(help="Control mode name for env creation.") +SLEEP_OPTION = typer.Option(help="Optional delay between restored states.") +OUTPUT_DIR_OPTION = typer.Option(help="Optional directory for re-rendered RGB frames.") +PREFER_DUCKDB_OPTION = typer.Option(help="Use duckdb for parquet loading when it is available.") @dataclass(frozen=True) @@ -79,7 +80,7 @@ def _load_distinct_uuids_with_duckdb(dataset_path: Path) -> list[str]: def _load_distinct_uuids_with_pyarrow(dataset_path: Path) -> list[str]: dataset = ds.dataset(str(dataset_path), format="parquet") uuids = dataset.to_table(columns=["uuid"])["uuid"] - return sorted(str(uuid) for uuid in pc.unique(uuids).to_pylist()) + return sorted(str(uuid) for uuid in pc.unique(uuids).to_pylist()) # type: ignore def list_trajectory_ids(dataset_path: Path, prefer_duckdb: bool = True) -> list[str]: @@ -193,6 +194,7 @@ def replay_trajectory( env.reset() for recorded_step in recorded_steps: restore_sim_step(env, recorded_step) + env.get_wrapper_attr("sim").step(1) if output_dir is not None: save_rgb_frames(output_dir, recorded_step, collect_rgb_frames(env)) if sleep_s > 0: @@ -202,17 +204,19 @@ def replay_trajectory( @app.command() def replay( dataset: Annotated[Path, DATASET_ARGUMENT], - env_id: Annotated[str, ENV_ID_OPTION], - trajectory_uuid: Annotated[str | None, TRAJECTORY_UUID_OPTION], - camera: Annotated[list[str], CAMERA_OPTION], - resolution: Annotated[tuple[int, int], RESOLUTION_OPTION], - frame_rate: Annotated[int, FRAME_RATE_OPTION], - render_mode: Annotated[str, RENDER_MODE_OPTION], - control_mode: Annotated[str, CONTROL_MODE_OPTION], - sleep_s: Annotated[float, SLEEP_OPTION], - output_dir: Annotated[Path | None, OUTPUT_DIR_OPTION], - prefer_duckdb: Annotated[bool, PREFER_DUCKDB_OPTION], + env_id: Annotated[str, ENV_ID_OPTION] = "rcs/FR3SimplePickUpSim-v0", + trajectory_uuid: Annotated[str | None, TRAJECTORY_UUID_OPTION] = None, + camera: Annotated[list[str] | None, CAMERA_OPTION] = None, + resolution: Annotated[tuple[int, int], RESOLUTION_OPTION] = (256, 256), + frame_rate: Annotated[int, FRAME_RATE_OPTION] = 0, + render_mode: Annotated[str, RENDER_MODE_OPTION] = "human", + control_mode: Annotated[str, CONTROL_MODE_OPTION] = ControlMode.CARTESIAN_TRPY.name, + sleep_s: Annotated[float, SLEEP_OPTION] = 0.0, + output_dir: Annotated[Path | None, OUTPUT_DIR_OPTION] = None, + prefer_duckdb: Annotated[bool, PREFER_DUCKDB_OPTION] = True, ): + if camera is None: + camera = [] resolved_uuid = resolve_trajectory_uuid(dataset, trajectory_uuid, prefer_duckdb=prefer_duckdb) env = gym.make( env_id, diff --git a/python/tests/test_sim_state_record_replay.py b/python/tests/test_sim_state_record_replay.py index 1e974812..94d925b6 100644 --- a/python/tests/test_sim_state_record_replay.py +++ b/python/tests/test_sim_state_record_replay.py @@ -1,7 +1,5 @@ from __future__ import annotations -import importlib.util -import sys from dataclasses import dataclass from pathlib import Path @@ -11,41 +9,10 @@ import pyarrow.dataset as ds from rcs._core.common import RobotPlatform from rcs.camera.interface import CameraFrame, DataFrame, Frame, FrameSet +from rcs.envs.sim import SimStateObservationWrapper from rcs.envs.storage_wrapper import StorageWrapper - -import rcs - -REPO_ROOT = Path(__file__).resolve().parents[2] - - -def _load_local_module(module_name: str, relative_path: str): - module_path = REPO_ROOT / relative_path - spec = importlib.util.spec_from_file_location(module_name, module_path) - if spec is None or spec.loader is None: - msg = f"Could not create an import spec for {module_name} from {module_path}." - raise ImportError(msg) - module = importlib.util.module_from_spec(spec) - sys.modules[module_name] = module - parent_name, _, child_name = module_name.rpartition(".") - if parent_name: - parent_module = sys.modules[parent_name] - setattr(parent_module, child_name, module) - spec.loader.exec_module(module) - return module - - -local_sim_module = _load_local_module("rcs.sim.sim", "python/rcs/sim/sim.py") -rcs.sim.__dict__["Sim"] = local_sim_module.Sim -_load_local_module("rcs.envs.sim", "python/rcs/envs/sim.py") -_load_local_module("rcs.sim_state_replay", "python/rcs/sim_state_replay.py") - -from rcs.envs.sim import SimStateObservationWrapper # noqa: E402 -from rcs.sim.sim import Sim # noqa: E402 -from rcs.sim_state_replay import ( # noqa: E402 - load_trajectory, - replay_trajectory, - restore_sim_step, -) +from rcs.sim.sim import Sim +from rcs.sim_state_replay import load_trajectory, replay_trajectory, restore_sim_step XML = """ @@ -146,11 +113,6 @@ def test_record_and_replay_sim_state(tmp_path: Path): recorded_obs = rows[0]["obs"] assert SimStateObservationWrapper.STATE_KEY in recorded_obs assert SimStateObservationWrapper.STATE_SPEC_KEY in recorded_obs - assert SimStateObservationWrapper.STATE_SIZE_KEY in recorded_obs - assert ( - len(recorded_obs[SimStateObservationWrapper.STATE_KEY]) - == recorded_obs[SimStateObservationWrapper.STATE_SIZE_KEY] - ) recorded_steps = load_trajectory(dataset_path, rows[0]["uuid"], prefer_duckdb=True) assert len(recorded_steps) == 1 diff --git a/scripts/replay_sim_trajectory.py b/scripts/replay_sim_trajectory.py deleted file mode 100644 index 04933693..00000000 --- a/scripts/replay_sim_trajectory.py +++ /dev/null @@ -1,9 +0,0 @@ -from pathlib import Path -import sys - -sys.path.insert(0, str(Path(__file__).resolve().parents[1] / "python")) - -from rcs.sim_state_replay import app - -if __name__ == "__main__": - app() From b1c5fccfac08fb0b005187470de8cb2ee3b9f2d0 Mon Sep 17 00:00:00 2001 From: Pierre Krack Date: Thu, 23 Apr 2026 18:55:49 +0200 Subject: [PATCH 3/4] sim recording test, info cannot be empty with storage wrapper --- python/tests/test_sim_state_record_replay.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/python/tests/test_sim_state_record_replay.py b/python/tests/test_sim_state_record_replay.py index 94d925b6..67255f8f 100644 --- a/python/tests/test_sim_state_record_replay.py +++ b/python/tests/test_sim_state_record_replay.py @@ -79,13 +79,13 @@ def reset(self, *, seed: int | None = None, options: dict | None = None): super().reset(seed=seed) mj.mj_resetData(self.sim.model, self.sim.data) mj.mj_forward(self.sim.model, self.sim.data) - return self._obs(), {} + return self._obs(), {"dummy": True} def step(self, action: dict[str, np.ndarray]): self.sim.data.qpos[0] += float(action["delta"][0]) self.sim.data.qvel[:] = 0.0 mj.mj_forward(self.sim.model, self.sim.data) - return self._obs(), 0.0, False, False, {} + return self._obs(), 0.0, False, False, {"dummy": True} def close(self): return None @@ -98,12 +98,12 @@ def test_record_and_replay_sim_state(tmp_path: Path): dataset_path = tmp_path / "dataset" record_env: gym.Env = DummySimEnv(Sim(model_path)) record_env = SimStateObservationWrapper(record_env) - record_env = StorageWrapper(record_env, str(dataset_path), "test sim replay", batch_size=1, always_record=True) - + record_env = StorageWrapper(record_env, str(dataset_path), "test sim replay") obs, _ = record_env.reset() + record_env.start_record() assert SimStateObservationWrapper.STATE_KEY in obs - record_env.step({"delta": np.array([0.125], dtype=np.float64)}) + record_env.stop_record() record_env.close() table = ds.dataset(str(dataset_path), format="parquet").to_table().sort_by([("step", "ascending")]) From e83535ba9d123190e68528b1e49a7626a6c16210 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tobias=20J=C3=BClg?= Date: Thu, 23 Apr 2026 23:12:02 +0200 Subject: [PATCH 4/4] fix: sync gui without advancing replay state --- python/rcs/_core/sim.pyi | 1 + python/rcs/sim_state_replay.py | 3 +- python/tests/test_sim_state_record_replay.py | 59 +++++++++++++++++++- src/pybind/rcs.cpp | 1 + src/sim/gui.h | 2 + src/sim/gui_server.cpp | 14 +++-- src/sim/sim.cpp | 6 ++ src/sim/sim.h | 1 + 8 files changed, 81 insertions(+), 6 deletions(-) diff --git a/python/rcs/_core/sim.pyi b/python/rcs/_core/sim.pyi index 8d8ff075..fa5a8899 100644 --- a/python/rcs/_core/sim.pyi +++ b/python/rcs/_core/sim.pyi @@ -99,6 +99,7 @@ class Sim: def set_config(self, cfg: SimConfig) -> bool: ... def step(self, k: int) -> None: ... def step_until_convergence(self) -> None: ... + def sync_gui(self) -> None: ... class SimCameraConfig(rcs._core.common.BaseCameraConfig): type: CameraType diff --git a/python/rcs/sim_state_replay.py b/python/rcs/sim_state_replay.py index d760a1a4..7f5e6a32 100644 --- a/python/rcs/sim_state_replay.py +++ b/python/rcs/sim_state_replay.py @@ -192,11 +192,12 @@ def replay_trajectory( raise ValueError(msg) env.reset() + sim = env.get_wrapper_attr("sim") for recorded_step in recorded_steps: restore_sim_step(env, recorded_step) - env.get_wrapper_attr("sim").step(1) if output_dir is not None: save_rgb_frames(output_dir, recorded_step, collect_rgb_frames(env)) + sim.sync_gui() if sleep_s > 0: time.sleep(sleep_s) diff --git a/python/tests/test_sim_state_record_replay.py b/python/tests/test_sim_state_record_replay.py index 67255f8f..78abf1f3 100644 --- a/python/tests/test_sim_state_record_replay.py +++ b/python/tests/test_sim_state_record_replay.py @@ -12,7 +12,12 @@ from rcs.envs.sim import SimStateObservationWrapper from rcs.envs.storage_wrapper import StorageWrapper from rcs.sim.sim import Sim -from rcs.sim_state_replay import load_trajectory, replay_trajectory, restore_sim_step +from rcs.sim_state_replay import ( + RecordedSimStep, + load_trajectory, + replay_trajectory, + restore_sim_step, +) XML = """ @@ -91,6 +96,58 @@ def close(self): return None +class SpySim: + def __init__(self): + self.states: list[tuple[np.ndarray, int | None]] = [] + self.sync_calls = 0 + + def set_state(self, state: np.ndarray, spec: int | None = None): + self.states.append((np.asarray(state, dtype=np.float64), spec)) + + def sync_gui(self): + self.sync_calls += 1 + + +class SpyReplayEnv(gym.Env): + PLATFORM = RobotPlatform.SIMULATION + + def __init__(self, sim: SpySim): + super().__init__() + self.sim = sim + self.reset_calls = 0 + + def get_wrapper_attr(self, name: str): + return getattr(self, name) + + def reset(self, *, seed: int | None = None, options: dict | None = None): + self.reset_calls += 1 + return {}, {} + + +def test_replay_trajectory_syncs_gui_without_stepping(): + spy_sim = SpySim() + env = SpyReplayEnv(spy_sim) + recorded_steps = [ + RecordedSimStep( + step=3, + uuid="traj-1", + timestamp=1.23, + observation={ + SimStateObservationWrapper.STATE_KEY: np.array([1.0, 2.0, 3.0], dtype=np.float64), + SimStateObservationWrapper.STATE_SPEC_KEY: 7, + }, + ) + ] + + replay_trajectory(env, recorded_steps) + + assert env.reset_calls == 1 + assert len(spy_sim.states) == 1 + np.testing.assert_allclose(spy_sim.states[0][0], np.array([1.0, 2.0, 3.0], dtype=np.float64)) + assert spy_sim.states[0][1] == 7 + assert spy_sim.sync_calls == 1 + + def test_record_and_replay_sim_state(tmp_path: Path): model_path = tmp_path / "dummy.xml" model_path.write_text(XML) diff --git a/src/pybind/rcs.cpp b/src/pybind/rcs.cpp index f5d37520..ab7772c0 100644 --- a/src/pybind/rcs.cpp +++ b/src/pybind/rcs.cpp @@ -742,6 +742,7 @@ PYBIND11_MODULE(_core, m) { .def("get_config", &rcs::sim::Sim::get_config) .def("step", &rcs::sim::Sim::step, py::arg("k")) .def("reset", &rcs::sim::Sim::reset) + .def("sync_gui", &rcs::sim::Sim::sync_gui) .def("_start_gui_server", &rcs::sim::Sim::start_gui_server, py::arg("id")) .def("_stop_gui_server", &rcs::sim::Sim::stop_gui_server); diff --git a/src/sim/gui.h b/src/sim/gui.h index 36360492..465b6a12 100644 --- a/src/sim/gui.h +++ b/src/sim/gui.h @@ -53,6 +53,8 @@ class GuiServer { public: GuiServer(mjModel* m, mjData* d, const std::string& id); ~GuiServer(); + void publish_state(); + void publish_state_if_requested(); void update_mjdata_callback(); private: diff --git a/src/sim/gui_server.cpp b/src/sim/gui_server.cpp index 3e838e9a..36a916d5 100644 --- a/src/sim/gui_server.cpp +++ b/src/sim/gui_server.cpp @@ -46,12 +46,16 @@ GuiServer::~GuiServer() { } }; -void GuiServer::update_mjdata_callback() { +void GuiServer::publish_state() { + this->shm.state_lock.lock(); + mj_getState(this->m, this->d, this->shm.state.ptr, MJ_PHYSICS_SPEC); + this->shm.state_lock.unlock(); +} + +void GuiServer::publish_state_if_requested() { this->shm.info_lock.lock_upgradable(); if (*this->shm.info_byte) { - this->shm.state_lock.lock(); - mj_getState(this->m, this->d, this->shm.state.ptr, MJ_PHYSICS_SPEC); - this->shm.state_lock.unlock(); + this->publish_state(); this->shm.info_lock.unlock_upgradable_and_lock(); *this->shm.info_byte = false; this->shm.info_lock.unlock_and_lock_upgradable(); @@ -59,5 +63,7 @@ void GuiServer::update_mjdata_callback() { this->shm.info_lock.unlock_upgradable(); } +void GuiServer::update_mjdata_callback() { this->publish_state_if_requested(); } + } // namespace sim } // namespace rcs diff --git a/src/sim/sim.cpp b/src/sim/sim.cpp index 164ad2c8..b988d4b3 100644 --- a/src/sim/sim.cpp +++ b/src/sim/sim.cpp @@ -182,5 +182,11 @@ void Sim::start_gui_server(const std::string& id) { // TODO: when stop_gui_server is called, the callback still exists but now // points to an non existing gui void Sim::stop_gui_server() { this->gui.reset(); } + +void Sim::sync_gui() { + if (this->gui.has_value()) { + this->gui->publish_state(); + } +} } // namespace sim } // namespace rcs diff --git a/src/sim/sim.h b/src/sim/sim.h index 9aa5f0af..c75121d3 100644 --- a/src/sim/sim.h +++ b/src/sim/sim.h @@ -101,6 +101,7 @@ class Sim { const std::string& id, int frame_rate, size_t width, size_t height); void start_gui_server(const std::string& id); void stop_gui_server(); + void sync_gui(); }; } // namespace sim } // namespace rcs