Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ dependencies = [
"mujoco==3.2.6",
"pin==3.7.0",
"greenlet",
"duckdb",
]
readme = "README.md"
maintainers = [{ name = "Tobias Jülg", email = "tobias.juelg@utn.de" }]
Expand Down
2 changes: 2 additions & 0 deletions python/rcs/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@

import typer
from rcs.envs.storage_wrapper import StorageWrapper
from rcs.sim_state_replay import replay as replay_command

app = typer.Typer()
app.command()(replay_command)


@app.command()
Expand Down
1 change: 1 addition & 0 deletions python/rcs/_core/sim.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ class Sim:
def set_config(self, cfg: SimConfig) -> bool: ...
def step(self, k: int) -> None: ...
def step_until_convergence(self) -> None: ...
def sync_gui(self) -> None: ...

class SimCameraConfig(rcs._core.common.BaseCameraConfig):
type: CameraType
Expand Down
17 changes: 17 additions & 0 deletions python/rcs/envs/sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,23 @@ def reset(
return super().reset(seed=seed, options=options)


class SimStateObservationWrapper(ActObsInfoWrapper):
STATE_KEY = "sim_state"
STATE_SPEC_KEY = "sim_state_spec"

def __init__(self, env):
super().__init__(env)
assert self.env.get_wrapper_attr("PLATFORM") == RobotPlatform.SIMULATION, "Base environment must be simulation."
self.sim = cast(sim.Sim, self.get_wrapper_attr("sim"))

def observation(self, observation: dict[str, Any], info: dict[str, Any]) -> tuple[dict[str, Any], dict[str, Any]]:
observation = dict(observation)
sim_state = self.sim.get_state()
observation[self.STATE_KEY] = sim_state
observation[self.STATE_SPEC_KEY] = self.sim.get_state_spec()
return observation, info


class GripperWrapperSim(ActObsInfoWrapper):
def __init__(self, env):
super().__init__(env)
Expand Down
29 changes: 29 additions & 0 deletions python/rcs/sim/sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import mujoco as mj
import mujoco.viewer
import numpy as np
from rcs._core.sim import GuiClient as _GuiClient
from rcs._core.sim import Sim as _Sim
from rcs.sim import SimConfig, egl_bootstrap
Expand Down Expand Up @@ -43,6 +44,8 @@ def gui_loop(gui_uuid: str, close_event):


class Sim(_Sim):
STATE_SPEC = mj.mjtState.mjSTATE_INTEGRATION

def __init__(self, mjmdl: str | PathLike | ModelComposer, cfg: SimConfig | None = None):
if isinstance(mjmdl, ModelComposer):
self.model = mjmdl.get_model()
Expand All @@ -66,6 +69,32 @@ def __init__(self, mjmdl: str | PathLike | ModelComposer, cfg: SimConfig | None
if cfg is not None:
self.set_config(cfg)

def get_state_spec(self) -> int:
return int(self.STATE_SPEC)

def get_state_size(self, spec: int | None = None) -> int:
state_spec = self.STATE_SPEC if spec is None else mj.mjtState(spec)
return mj.mj_stateSize(self.model, state_spec)

def get_state(self, spec: int | None = None) -> np.ndarray:
state_spec = self.STATE_SPEC if spec is None else mj.mjtState(spec)
state = np.empty(self.get_state_size(int(state_spec)), dtype=np.float64)
mj.mj_getState(self.model, self.data, state, state_spec)
return state

def set_state(self, state: np.ndarray, spec: int | None = None):
state_spec = self.STATE_SPEC if spec is None else mj.mjtState(spec)
state_array = np.asarray(state, dtype=np.float64)
expected_size = self.get_state_size(int(state_spec))
if state_array.shape != (expected_size,):
msg = (
f"Expected MuJoCo state with shape ({expected_size},), "
f"got {state_array.shape} for spec {int(state_spec)}."
)
raise ValueError(msg)
mj.mj_setState(self.model, self.data, state_array, state_spec)
mj.mj_forward(self.model, self.data)

def close_gui(self):
if self._stop_event is not None:
self._stop_event.set()
Expand Down
242 changes: 242 additions & 0 deletions python/rcs/sim_state_replay.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,242 @@
from __future__ import annotations

import time
from dataclasses import dataclass
from pathlib import Path
from typing import Annotated, Any

import gymnasium as gym
import numpy as np
import pyarrow.compute as pc
import pyarrow.dataset as ds
import typer
from PIL import Image
from rcs.envs.base import ControlMode
from rcs.envs.sim import SimStateObservationWrapper

import rcs # noqa: F401

app = typer.Typer(help="Replay recorded MuJoCo trajectories from a parquet dataset.")

DATASET_ARGUMENT = typer.Argument(..., exists=True, file_okay=False, dir_okay=True)

ENV_ID_OPTION = typer.Option(help="Gymnasium env id used for replay.")
TRAJECTORY_UUID_OPTION = typer.Option(help="UUID of the recorded trajectory to replay.")
CAMERA_OPTION = typer.Option("--camera", help="Camera names to enable on the replay env.")
RESOLUTION_OPTION = typer.Option(help="Replay camera resolution as WIDTH HEIGHT.")
FRAME_RATE_OPTION = typer.Option(help="Replay camera frame rate.")
RENDER_MODE_OPTION = typer.Option(help="Gym render mode for the replay env.")
CONTROL_MODE_OPTION = typer.Option(help="Control mode name for env creation.")
SLEEP_OPTION = typer.Option(help="Optional delay between restored states.")
OUTPUT_DIR_OPTION = typer.Option(help="Optional directory for re-rendered RGB frames.")
PREFER_DUCKDB_OPTION = typer.Option(help="Use duckdb for parquet loading when it is available.")


@dataclass(frozen=True)
class RecordedSimStep:
step: int
uuid: str
timestamp: float | None
observation: dict[str, Any]

@property
def sim_state(self) -> np.ndarray:
return np.asarray(self.observation[SimStateObservationWrapper.STATE_KEY], dtype=np.float64)

@property
def sim_state_spec(self) -> int:
return int(self.observation.get(SimStateObservationWrapper.STATE_SPEC_KEY, 0))


class DuckDBUnavailableError(RuntimeError):
pass


def _get_duckdb_module():
try:
import duckdb
except ModuleNotFoundError as exc:
msg = (
"duckdb is required for the preferred parquet read path but is not installed. "
"Install the 'duckdb' Python package or rely on the pyarrow fallback in library calls."
)
raise DuckDBUnavailableError(msg) from exc
return duckdb


def _load_distinct_uuids_with_duckdb(dataset_path: Path) -> list[str]:
duckdb = _get_duckdb_module()
connection = duckdb.connect()
try:
rows = connection.execute(
"SELECT DISTINCT uuid FROM read_parquet(?) ORDER BY uuid",
[str(dataset_path)],
).fetchall()
finally:
connection.close()
return [row[0] for row in rows]


def _load_distinct_uuids_with_pyarrow(dataset_path: Path) -> list[str]:
dataset = ds.dataset(str(dataset_path), format="parquet")
uuids = dataset.to_table(columns=["uuid"])["uuid"]
return sorted(str(uuid) for uuid in pc.unique(uuids).to_pylist()) # type: ignore


def list_trajectory_ids(dataset_path: Path, prefer_duckdb: bool = True) -> list[str]:
if prefer_duckdb:
try:
return _load_distinct_uuids_with_duckdb(dataset_path)
except DuckDBUnavailableError:
pass
return _load_distinct_uuids_with_pyarrow(dataset_path)


def _load_trajectory_with_duckdb(dataset_path: Path, trajectory_uuid: str) -> list[RecordedSimStep]:
duckdb = _get_duckdb_module()
connection = duckdb.connect()
try:
table = connection.execute(
"SELECT uuid, step, timestamp, obs FROM read_parquet(?) WHERE uuid = ? ORDER BY step",
[str(dataset_path), trajectory_uuid],
).to_arrow_table()
finally:
connection.close()
return [
RecordedSimStep(
step=int(row["step"]),
uuid=str(row["uuid"]),
timestamp=float(row["timestamp"]) if row["timestamp"] is not None else None,
observation=row["obs"],
)
for row in table.to_pylist()
]


def _load_trajectory_with_pyarrow(dataset_path: Path, trajectory_uuid: str) -> list[RecordedSimStep]:
dataset = ds.dataset(str(dataset_path), format="parquet")
table = dataset.to_table(filter=pc.field("uuid") == trajectory_uuid, columns=["uuid", "step", "timestamp", "obs"])
rows = table.sort_by([("step", "ascending")]).to_pylist()
return [
RecordedSimStep(
step=int(row["step"]),
uuid=str(row["uuid"]),
timestamp=float(row["timestamp"]) if row["timestamp"] is not None else None,
observation=row["obs"],
)
for row in rows
]


def load_trajectory(dataset_path: Path, trajectory_uuid: str, prefer_duckdb: bool = True) -> list[RecordedSimStep]:
if prefer_duckdb:
try:
return _load_trajectory_with_duckdb(dataset_path, trajectory_uuid)
except DuckDBUnavailableError:
pass
return _load_trajectory_with_pyarrow(dataset_path, trajectory_uuid)


def resolve_trajectory_uuid(dataset_path: Path, trajectory_uuid: str | None, prefer_duckdb: bool = True) -> str:
if trajectory_uuid is not None:
return trajectory_uuid
available_uuids = list_trajectory_ids(dataset_path, prefer_duckdb=prefer_duckdb)
if len(available_uuids) == 1:
return available_uuids[0]
msg = (
f"Dataset {dataset_path} contains {len(available_uuids)} trajectories. "
f"Pass --trajectory-uuid and choose one of: {available_uuids}"
)
raise ValueError(msg)


def restore_sim_step(env: gym.Env, recorded_step: RecordedSimStep):
sim = env.get_wrapper_attr("sim")
sim.set_state(recorded_step.sim_state, spec=recorded_step.sim_state_spec)


def collect_rgb_frames(env: gym.Env) -> dict[str, np.ndarray]:
try:
camera_set = env.get_wrapper_attr("camera_set")
except AttributeError:
return {}

frameset = camera_set.get_latest_frames()
if frameset is None:
return {}

rgb_frames: dict[str, np.ndarray] = {}
for camera_name, frame in frameset.frames.items():
lower_name = camera_name.lower()
if "digit" in lower_name or "tactile" in lower_name:
continue
rgb_frames[camera_name] = np.asarray(frame.camera.color.data)
return rgb_frames


def save_rgb_frames(output_dir: Path, recorded_step: RecordedSimStep, rgb_frames: dict[str, np.ndarray]):
output_dir.mkdir(parents=True, exist_ok=True)
for camera_name, rgb_frame in rgb_frames.items():
Image.fromarray(rgb_frame).save(output_dir / f"step-{recorded_step.step:06d}-{camera_name}.png")


def replay_trajectory(
env: gym.Env,
recorded_steps: list[RecordedSimStep],
*,
sleep_s: float = 0.0,
output_dir: Path | None = None,
):
if not recorded_steps:
msg = "No recorded sim states found in the requested trajectory."
raise ValueError(msg)

env.reset()
sim = env.get_wrapper_attr("sim")
for recorded_step in recorded_steps:
restore_sim_step(env, recorded_step)
if output_dir is not None:
save_rgb_frames(output_dir, recorded_step, collect_rgb_frames(env))
sim.sync_gui()
if sleep_s > 0:
time.sleep(sleep_s)


@app.command()
def replay(
dataset: Annotated[Path, DATASET_ARGUMENT],
env_id: Annotated[str, ENV_ID_OPTION] = "rcs/FR3SimplePickUpSim-v0",
trajectory_uuid: Annotated[str | None, TRAJECTORY_UUID_OPTION] = None,
camera: Annotated[list[str] | None, CAMERA_OPTION] = None,
resolution: Annotated[tuple[int, int], RESOLUTION_OPTION] = (256, 256),
frame_rate: Annotated[int, FRAME_RATE_OPTION] = 0,
render_mode: Annotated[str, RENDER_MODE_OPTION] = "human",
control_mode: Annotated[str, CONTROL_MODE_OPTION] = ControlMode.CARTESIAN_TRPY.name,
sleep_s: Annotated[float, SLEEP_OPTION] = 0.0,
output_dir: Annotated[Path | None, OUTPUT_DIR_OPTION] = None,
prefer_duckdb: Annotated[bool, PREFER_DUCKDB_OPTION] = True,
):
if camera is None:
camera = []
resolved_uuid = resolve_trajectory_uuid(dataset, trajectory_uuid, prefer_duckdb=prefer_duckdb)
env = gym.make(
env_id,
render_mode=render_mode,
control_mode=ControlMode[control_mode],
resolution=resolution,
frame_rate=frame_rate,
cam_list=camera,
)
try:
recorded_steps = load_trajectory(dataset, resolved_uuid, prefer_duckdb=prefer_duckdb)
replay_trajectory(env, recorded_steps, sleep_s=sleep_s, output_dir=output_dir)
finally:
env.close()

typer.echo(f"Replayed {len(recorded_steps)} steps from trajectory {resolved_uuid}.")
if output_dir is not None:
typer.echo(f"Saved re-rendered RGB frames to {output_dir}.")


if __name__ == "__main__":
app()
Loading
Loading