diff --git a/examples/teleop/franka.py b/examples/teleop/franka.py index 334fed08..0dca09e6 100644 --- a/examples/teleop/franka.py +++ b/examples/teleop/franka.py @@ -7,8 +7,10 @@ from rcs.envs.configs import EmptyWorldFR3Duo from rcs.envs.storage_wrapper import StorageWrapper from rcs.envs.tasks import PickTaskConfig +from rcs.operator.compose import ComposeOperator, ComposeOperatorConfig from rcs.operator.gello import GelloConfig, GelloOperator from rcs.operator.interface import TeleopLoop +from rcs.operator.keyboard import KeyboardOperatorConfig from rcs.operator.quest import QuestConfig, QuestOperator from rcs_fr3.configs import DefaultFR3MultiHardwareEnv from rcs_fr3.creators import HardwareCameraCreatorConfig @@ -117,7 +119,7 @@ def get_env(): # env_rel = StorageWrapper( # env_rel, DATASET_PATH, INSTRUCTION, batch_size=32, max_rows_per_group=100, max_rows_per_file=1000 # ) - operator = GelloOperator(config) if isinstance(config, GelloConfig) else QuestOperator(config) + operator = build_operator(config) else: # FR3 @@ -137,10 +139,21 @@ def get_env(): sim = env_rel.get_wrapper_attr("sim") MujocoPublisher(sim.model, sim.data, MQ3_ADDR, visible_geoms_groups=list(range(1, 3))) - operator = GelloOperator(config, sim) if isinstance(config, GelloConfig) else QuestOperator(config, sim) + operator = build_operator(config, sim) return env_rel, operator +def build_operator(config: QuestConfig | GelloConfig, sim=None): + if isinstance(config, GelloConfig): + compose_config = ComposeOperatorConfig( + base=config, + override=KeyboardOperatorConfig(control_mode=GelloOperator.control_mode), + simulation=config.simulation, + ) + return ComposeOperator(compose_config, sim) + return QuestOperator(config, sim) if sim is not None else QuestOperator(config) + + def main(): env_rel, operator = get_env() env_rel.reset() diff --git a/examples/teleop/so101.py b/examples/teleop/so101.py new file mode 100644 index 00000000..fe8e4945 --- /dev/null +++ b/examples/teleop/so101.py @@ -0,0 +1,61 @@ +import logging +from pathlib import Path + +from rcs.envs.base import ControlMode, CoverWrapper, MultiRobotWrapper +from rcs.operator.interface import TeleopLoop +from rcs.operator.so101 import SO101Operator, SO101OperatorConfig +from rcs_so101 import RCSSO101ConfigEnvCreator, SO101Config + +import rcs + +logger = logging.getLogger(__name__) + +FOLLOWER_PORT = "/dev/ttyACM0" +LEADER_PORT = "/dev/ttyACM1" +FOLLOWER_CALIBRATION_DIR = Path(".cache/so101/follower") +LEADER_CALIBRATION_DIR = Path(".cache/so101/leader") +ROBOT_NAME = "so101" + + +def get_env(): + robot_type = rcs.common.RobotType("SO101") + robot_meta = rcs.ROBOTS[robot_type] + robot_cfg = SO101Config( + id="follower", + port=FOLLOWER_PORT, + calibration_dir=str(FOLLOWER_CALIBRATION_DIR), + robot_type=robot_type, + attachment_site=robot_meta.attachment_site, + kinematic_model_path=robot_meta.mjcf_model_path, + dof=robot_meta.dof, + joint_limits=robot_meta.joint_limits, + q_home=robot_meta.q_home, + tcp_offset=rcs.common.Pose(), + ) + env = RCSSO101ConfigEnvCreator()( + robot_cfg=robot_cfg, + control_mode=ControlMode.JOINTS, + max_relative_movement=None, + relative_to=SO101Operator.control_mode[1], + ) + return CoverWrapper(MultiRobotWrapper({ROBOT_NAME: env})) + + +def main(): + env = get_env() + operator = SO101Operator( + SO101OperatorConfig( + controller_name=ROBOT_NAME, + id="leader", + port=LEADER_PORT, + calibration_dir=str(LEADER_CALIBRATION_DIR), + use_degrees=True, + ) + ) + tele = TeleopLoop(env, operator) + with env, tele: # type: ignore + tele.environment_step_loop() + + +if __name__ == "__main__": + main() diff --git a/extensions/rcs_so101/src/rcs_so101/__init__.py b/extensions/rcs_so101/src/rcs_so101/__init__.py index ac7ea3fb..c2c8d004 100644 --- a/extensions/rcs_so101/src/rcs_so101/__init__.py +++ b/extensions/rcs_so101/src/rcs_so101/__init__.py @@ -2,7 +2,7 @@ from rcs_so101._core.so101_ik import SO101IK from . import configs, creators, hw -from .creators import RCSSO101ConfigEnvCreator +from .creators import RCSSO101ConfigEnvCreator, make_so101_leader from .hw import SO101, SO101Config, SO101Gripper __all__ = [ @@ -14,5 +14,6 @@ "SO101", "SO101Config", "SO101Gripper", + "make_so101_leader", "__version__", ] diff --git a/extensions/rcs_so101/src/rcs_so101/creators.py b/extensions/rcs_so101/src/rcs_so101/creators.py index f92969c7..054cbab0 100644 --- a/extensions/rcs_so101/src/rcs_so101/creators.py +++ b/extensions/rcs_so101/src/rcs_so101/creators.py @@ -1,6 +1,8 @@ import logging import typing from dataclasses import dataclass, field +from pathlib import Path +from typing import Any import gymnasium as gym from rcs._core.common import BaseCameraConfig @@ -117,17 +119,24 @@ def config(self) -> SO101HardwareEnvCreatorConfig: msg = "Implement config() in a subclass or pass `cfg=` explicitly." raise NotImplementedError(msg) - # For now, the leader-follower teleop script uses the leader object directly - # and doesn't depend on an RCS-provided class. - # @staticmethod - # def teleoperator( - # id: str, - # port: str, - # calibration_dir: PathLike | str | None = None, - # ) -> SO101Leader: - # if isinstance(calibration_dir, str): - # calibration_dir = Path(calibration_dir) - # cfg = SO101LeaderConfig(id=id, calibration_dir=calibration_dir, port=port) - # teleop = make_teleoperator_from_config(cfg) - # teleop.connect() - # return teleop + +def make_so101_leader( + id: str = "leader", + port: str = "/dev/ttyACM1", + calibration_dir: str | Path | None = None, + use_degrees: bool = True, +) -> Any: + try: + from lerobot.teleoperators.so_leader.config_so_leader import SO101LeaderConfig + from lerobot.teleoperators.so_leader.so_leader import SO101Leader + except ImportError as exc: + msg = "lerobot SO101 leader dependencies are not available." + raise ImportError(msg) from exc + + if isinstance(calibration_dir, str): + calibration_dir = Path(calibration_dir) + + cfg = SO101LeaderConfig(id=id, calibration_dir=calibration_dir, port=port, use_degrees=use_degrees) + teleop = SO101Leader(cfg) + teleop.connect() + return teleop diff --git a/python/rcs/operator/__init__.py b/python/rcs/operator/__init__.py index e69de29b..6849bced 100644 --- a/python/rcs/operator/__init__.py +++ b/python/rcs/operator/__init__.py @@ -0,0 +1,31 @@ +from rcs.operator.compose import ComposeOperator, ComposeOperatorConfig +from rcs.operator.gello import GelloConfig, GelloOperator +from rcs.operator.interface import ( + BaseOperator, + BaseOperatorConfig, + TeleopCommands, + TeleopLoop, +) +from rcs.operator.keyboard import KeyboardOperator, KeyboardOperatorConfig +from rcs.operator.pedals import FootPedalOperator, FootPedalOperatorConfig +from rcs.operator.quest import QuestConfig, QuestOperator +from rcs.operator.so101 import SO101Operator, SO101OperatorConfig + +__all__ = [ + "BaseOperator", + "BaseOperatorConfig", + "ComposeOperator", + "ComposeOperatorConfig", + "GelloConfig", + "GelloOperator", + "FootPedalOperator", + "FootPedalOperatorConfig", + "KeyboardOperator", + "KeyboardOperatorConfig", + "QuestConfig", + "QuestOperator", + "SO101Operator", + "SO101OperatorConfig", + "TeleopCommands", + "TeleopLoop", +] diff --git a/python/rcs/operator/compose.py b/python/rcs/operator/compose.py new file mode 100644 index 00000000..bd78c2fc --- /dev/null +++ b/python/rcs/operator/compose.py @@ -0,0 +1,87 @@ +import logging +import threading +from dataclasses import dataclass, field + +from rcs.operator.interface import BaseOperator, BaseOperatorConfig, TeleopCommands +from rcs.sim.sim import Sim +from rcs.utils import SimpleFrameRate + +logger = logging.getLogger(__name__) + + +class ComposeOperator(BaseOperator): + """Compose two operators so the override action wins on overlapping controllers.""" + + def __init__(self, config: "ComposeOperatorConfig", sim: Sim | None = None): + super().__init__(config, sim) + self.config: ComposeOperatorConfig + self._exit_requested = False + self._child_lock = threading.Lock() + + self._base_operator = self.config.base.operator_class(self.config.base, sim) + self._override_operator = self.config.override.operator_class(self.config.override, sim) + + if self._base_operator.control_mode != self._override_operator.control_mode: + msg = ( + "ComposeOperator requires both child operators to use the same control_mode. " + f"Got base={self._base_operator.control_mode} and " + f"override={self._override_operator.control_mode}." + ) + raise ValueError(msg) + + self.control_mode = self._base_operator.control_mode + self.controller_names = list( + dict.fromkeys(self._base_operator.controller_names + self._override_operator.controller_names) + ) + + def consume_commands(self) -> TeleopCommands: + with self._child_lock: + base_commands = self._base_operator.consume_commands() + override_commands = self._override_operator.consume_commands() + return TeleopCommands.merged(base_commands, override_commands) + + def reset_operator_state(self): + with self._child_lock: + self._base_operator.reset_operator_state() + self._override_operator.reset_operator_state() + + def consume_action(self): + with self._child_lock: + actions = self._base_operator.consume_action() + override_actions = self._override_operator.consume_action() + return actions | override_actions + + def run(self): + self._base_operator.start() + self._override_operator.start() + + rate_limiter = SimpleFrameRate(self.config.read_frequency, "compose operator") + + try: + while not self._exit_requested: + if not self._base_operator.is_alive(): + logger.warning("ComposeOperator base child stopped.") + break + if not self._override_operator.is_alive(): + logger.warning("ComposeOperator override child stopped.") + break + rate_limiter() + finally: + self.close() + + def close(self): + self._exit_requested = True + self._base_operator.close() + self._override_operator.close() + + current_thread = threading.current_thread() + for operator in (self._base_operator, self._override_operator): + if operator.is_alive() and current_thread != operator: + operator.join(timeout=1.0) + + +@dataclass(kw_only=True) +class ComposeOperatorConfig(BaseOperatorConfig): + operator_class: type[BaseOperator] = field(default=ComposeOperator) + base: BaseOperatorConfig + override: BaseOperatorConfig diff --git a/python/rcs/operator/gello.py b/python/rcs/operator/gello.py index 418c736b..868a4173 100644 --- a/python/rcs/operator/gello.py +++ b/python/rcs/operator/gello.py @@ -1,5 +1,4 @@ import contextlib -import copy import logging import threading import time @@ -19,13 +18,6 @@ except ImportError: HAS_DYNAMIXEL_SDK = False -try: - from pynput import keyboard - - HAS_PYNPUT = True -except ImportError: - HAS_PYNPUT = False - from rcs.envs.base import ControlMode, RelativeTo from rcs.operator.interface import BaseOperator, BaseOperatorConfig, TeleopCommands from rcs.sim.sim import Sim @@ -346,40 +338,15 @@ def __init__(self, config: "GelloConfig", sim: Sim | None = None): super().__init__(config, sim) self.config: GelloConfig self._resource_lock = threading.Lock() - self._cmd_lock = threading.Lock() - self._exit_requested = False - self._commands = TeleopCommands() - self.controller_names = list(self.config.arms.keys()) self._last_joints: Dict[str, np.ndarray | None] = {name: None for name in self.controller_names} self._last_gripper = {name: 1.0 for name in self.controller_names} self._hws: Dict[str, GelloHardware] = {} - if HAS_PYNPUT: - self._listener = keyboard.Listener(on_press=self._on_press) - self._listener.start() - else: - logger.warning("pynput not found. Keyboard triggers disabled.") - - def _on_press(self, key): - try: - if hasattr(key, "char"): - if key.char == "s": - with self._cmd_lock: - self._commands.sync_position = True - elif key.char == "r": - with self._cmd_lock: - self._commands.failure = True - except AttributeError: - pass - def consume_commands(self) -> TeleopCommands: - with self._cmd_lock: - cmds = copy.copy(self._commands) - self._commands = TeleopCommands() - return cmds + return TeleopCommands() def reset_operator_state(self): # GELLO is absolute, no internal state to reset typically @@ -425,8 +392,6 @@ def run(self): def close(self): self._exit_requested = True - if HAS_PYNPUT and hasattr(self, "_listener"): - self._listener.stop() for hw in self._hws.values(): hw.close() if self.is_alive() and threading.current_thread() != self: diff --git a/python/rcs/operator/interface.py b/python/rcs/operator/interface.py index e7e6b3c0..915ffdb1 100644 --- a/python/rcs/operator/interface.py +++ b/python/rcs/operator/interface.py @@ -24,6 +24,20 @@ class TeleopCommands: sync_position: bool = False reset_origin_to_current: dict[str, bool] = field(default_factory=dict) + @classmethod + def merged(cls, *commands: "TeleopCommands") -> "TeleopCommands": + merged = cls() + for cmd in commands: + merged.record = merged.record or cmd.record + merged.success = merged.success or cmd.success + merged.failure = merged.failure or cmd.failure + merged.sync_position = merged.sync_position or cmd.sync_position + for controller, should_reset in cmd.reset_origin_to_current.items(): + merged.reset_origin_to_current[controller] = ( + merged.reset_origin_to_current.get(controller, False) or should_reset + ) + return merged + class BaseOperator(ABC, threading.Thread): control_mode: tuple[ControlMode, RelativeTo] diff --git a/python/rcs/operator/keyboard.py b/python/rcs/operator/keyboard.py new file mode 100644 index 00000000..d43bc1dd --- /dev/null +++ b/python/rcs/operator/keyboard.py @@ -0,0 +1,94 @@ +import copy +import logging +import threading +from dataclasses import dataclass, field + +try: + from pynput import keyboard + + HAS_PYNPUT = True +except ImportError: + HAS_PYNPUT = False + +from rcs.envs.base import ArmWithGripper, ControlMode, RelativeTo +from rcs.operator.interface import BaseOperator, BaseOperatorConfig, TeleopCommands +from rcs.sim.sim import Sim +from rcs.utils import SimpleFrameRate + +logger = logging.getLogger(__name__) + + +class KeyboardOperator(BaseOperator): + """Keyboard-only operator that emits teleop commands and no motion actions.""" + + control_mode = (ControlMode.JOINTS, RelativeTo.NONE) + + def __init__(self, config: "KeyboardOperatorConfig", sim: Sim | None = None): + super().__init__(config, sim) + self.config: KeyboardOperatorConfig + self._cmd_lock = threading.Lock() + self._exit_requested = False + self._commands = TeleopCommands() + self.control_mode = self.config.control_mode + self.controller_names = [] + + if HAS_PYNPUT: + self._listener = keyboard.Listener(on_press=self._on_press) + self._listener.start() + else: + logger.warning("pynput not found. Keyboard commands disabled.") + + def _on_press(self, key): + try: + if key == keyboard.Key.space: + with self._cmd_lock: + self._commands.record = True + return + + char = key.char + except AttributeError: + return + + with self._cmd_lock: + if char == self.config.sync_key: + self._commands.sync_position = True + elif char == self.config.record_key: + self._commands.record = True + elif char == self.config.success_key: + self._commands.success = True + elif char == self.config.failure_key: + self._commands.failure = True + + def consume_commands(self) -> TeleopCommands: + with self._cmd_lock: + cmds = copy.copy(self._commands) + self._commands = TeleopCommands() + return cmds + + def reset_operator_state(self): + pass + + def consume_action(self) -> dict[str, ArmWithGripper]: + return {} + + def run(self): + rate_limiter = SimpleFrameRate(self.config.read_frequency, "keyboard operator") + while not self._exit_requested: + rate_limiter() + + def close(self): + self._exit_requested = True + if HAS_PYNPUT and hasattr(self, "_listener"): + self._listener.stop() + if self.is_alive() and threading.current_thread() != self: + self.join(timeout=1.0) + + +@dataclass(kw_only=True) +class KeyboardOperatorConfig(BaseOperatorConfig): + operator_class: type[BaseOperator] = field(default=KeyboardOperator) + control_mode: tuple[ControlMode, RelativeTo] = (ControlMode.JOINTS, RelativeTo.NONE) + sync_key: str = "s" + record_key: str = " " + success_key: str = "x" + failure_key: str = "r" diff --git a/python/rcs/operator/pedals.py b/python/rcs/operator/pedals.py index 2d529a92..f4eb676b 100644 --- a/python/rcs/operator/pedals.py +++ b/python/rcs/operator/pedals.py @@ -1,116 +1,143 @@ +import copy +import logging import threading -import time +from dataclasses import dataclass, field -import evdev -from evdev import ecodes +try: + import evdev + from evdev import ecodes + HAS_EVDEV = True +except ImportError: + HAS_EVDEV = False -class FootPedal: - def __init__(self, device_name_substring="Foot Switch"): - """Initializes the foot pedal and starts the background reading thread.""" - self.device_path = self._find_device(device_name_substring) +from rcs.envs.base import ArmWithGripper, ControlMode, RelativeTo +from rcs.operator.interface import BaseOperator, BaseOperatorConfig, TeleopCommands +from rcs.sim.sim import Sim +from rcs.utils import SimpleFrameRate - if not self.device_path: - msg = f"Could not find a device matching '{device_name_substring}'" - raise FileNotFoundError(msg) +logger = logging.getLogger(__name__) - self.device = evdev.InputDevice(self.device_path) - self.device.grab() # Prevent events from leaking into the OS/terminal +_SUPPORTED_COMMANDS = frozenset({"record", "success", "failure", "sync_position"}) - # Dictionary to hold the current state of each key. - # True = Pressed/Held, False = Released - self._key_states = {} - self._lock = threading.Lock() - # Start the background thread - self._running = True - self._thread = threading.Thread(target=self._read_events, daemon=True) - self._thread.start() - print(f"Connected to {self.device.name} at {self.device_path}") +class FootPedalOperator(BaseOperator): + """Command-only operator for foot pedals exposed as a Linux evdev input device.""" - def _find_device(self, substring): - """Finds the device path for the foot pedal.""" - for path in evdev.list_devices(): - dev = evdev.InputDevice(path) - if substring.lower() in dev.name.lower(): - return path - return None + control_mode = (ControlMode.JOINTS, RelativeTo.NONE) - def _read_events(self): - """Background loop that updates the state dictionary.""" - try: - for event in self.device.read_loop(): - if not self._running: - break + def __init__(self, config: "FootPedalOperatorConfig", sim: Sim | None = None): + super().__init__(config, sim) + self.config: FootPedalOperatorConfig + self._cmd_lock = threading.Lock() + self._exit_requested = False + self._commands = TeleopCommands() + self.control_mode = self.config.control_mode + self.controller_names = [] + self._device: evdev.InputDevice | None = None + self._device_thread: threading.Thread | None = None - if event.type == ecodes.EV_KEY: - key_event = evdev.categorize(event) + invalid_commands = set(self.config.command_bindings.values()) - _SUPPORTED_COMMANDS + if invalid_commands: + msg = ( + "Unsupported foot pedal command bindings: " + f"{sorted(invalid_commands)}. Supported commands are {sorted(_SUPPORTED_COMMANDS)}." + ) + raise ValueError(msg) - if isinstance(key_event, evdev.KeyEvent): - with self._lock: - # keystate: 1 is DOWN, 2 is HOLD, 0 is UP - is_pressed = key_event.keystate in [1, 2] + if not HAS_EVDEV: + msg = "evdev is not installed. Install it to use FootPedalOperator." + raise ImportError(msg) - # Store state using the string name of the key (e.g., 'KEY_A') - # If a key resolves to a list (rare, but happens in evdev), take the first one - key_name = key_event.keycode - if isinstance(key_name, list): - key_name = key_name[0] + device_path = self._find_device_path(self.config.device_name_substring) + if device_path is None: + msg = f"Could not find a foot pedal input device matching '{self.config.device_name_substring}'." + raise FileNotFoundError(msg) - self._key_states[key_name] = is_pressed + self._device = evdev.InputDevice(device_path) + if self.config.grab_device: + self._device.grab() - except OSError: - pass # Device disconnected or closed - - def get_states(self): - """ - Returns a snapshot of the latest key states. - Example return: {'KEY_A': True, 'KEY_B': False, 'KEY_C': False} - """ - with self._lock: - # Return a copy to ensure thread safety - return self._key_states.copy() - - def get_key_state(self, key_name): - """Returns the state of a specific key, defaulting to False if never pressed.""" - with self._lock: - return self._key_states.get(key_name, False) + self._device_thread = threading.Thread(target=self._evdev_read_loop, daemon=True) + self._device_thread.start() + logger.info(f"Connected foot pedal device {self._device.name} at {device_path}") - def close(self): - """Cleans up the device and stops the thread.""" - self._running = False + def _find_device_path(self, substring: str) -> str | None: + for path in evdev.list_devices(): + device = evdev.InputDevice(path) + if substring.lower() in device.name.lower(): + return path + return None + + def _trigger_command(self, command_name: str): + with self._cmd_lock: + setattr(self._commands, command_name, True) + + def _evdev_read_loop(self): + assert self._device is not None try: - self.device.ungrab() - self.device.close() - except OSError: - pass + for event in self._device.read_loop(): + if self._exit_requested: + break + if event.type != ecodes.EV_KEY or event.value not in (1, 2): + continue + key_name = getattr(event, "keycode", None) + if isinstance(key_name, list): + key_name = key_name[0] -# ========================================== -# Example Usage -# ========================================== -if __name__ == "__main__": - try: - # Initialize the pedal - pedal = FootPedal("Foot Switch") + command_name = self.config.command_bindings.get(str(key_name)) + if command_name is not None: + self._trigger_command(command_name) + except OSError: + if not self._exit_requested: + logger.warning("Foot pedal device disconnected.", exc_info=True) - # Simulate a typical robotics control loop running at 10Hz - print("Starting control loop... Press Ctrl+C to exit.") - while True: - # Grab the latest states instantly without blocking - states = pedal.get_states() + def consume_commands(self) -> TeleopCommands: + with self._cmd_lock: + cmds = copy.copy(self._commands) + self._commands = TeleopCommands() + return cmds - if states: - # Print only the keys that are currently pressed - pressed_keys = [key for key, is_pressed in states.items() if is_pressed] - print(f"Currently pressed: {pressed_keys}") + def reset_operator_state(self): + pass - # Your teleoperation logic goes here... + def consume_action(self) -> dict[str, ArmWithGripper]: + return {} - time.sleep(0.1) # 10Hz loop + def run(self): + rate_limiter = SimpleFrameRate(self.config.read_frequency, "foot pedal operator") + while not self._exit_requested: + rate_limiter() - except KeyboardInterrupt: - print("\nShutting down...") - finally: - if "pedal" in locals(): - pedal.close() + def close(self): + self._exit_requested = True + + if self._device is not None: + try: + if self.config.grab_device: + self._device.ungrab() + self._device.close() + except OSError: + pass + + if self._device_thread is not None and self._device_thread.is_alive(): + self._device_thread.join(timeout=1.0) + + if self.is_alive() and threading.current_thread() != self: + self.join(timeout=1.0) + + +@dataclass(kw_only=True) +class FootPedalOperatorConfig(BaseOperatorConfig): + operator_class: type[BaseOperator] = field(default=FootPedalOperator) + control_mode: tuple[ControlMode, RelativeTo] = (ControlMode.JOINTS, RelativeTo.NONE) + device_name_substring: str = "Foot Switch" + grab_device: bool = True + command_bindings: dict[str, str] = field( + default_factory=lambda: { + "KEY_B": "sync_position", + "KEY_C": "record", + "KEY_A": "failure", + } + ) diff --git a/python/rcs/operator/so101.py b/python/rcs/operator/so101.py new file mode 100644 index 00000000..62dda472 --- /dev/null +++ b/python/rcs/operator/so101.py @@ -0,0 +1,142 @@ +import logging +import threading +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Callable + +import numpy as np +from rcs.envs.base import ControlMode, RelativeTo +from rcs.operator.interface import BaseOperator, BaseOperatorConfig, TeleopCommands +from rcs.sim.sim import Sim +from rcs.utils import SimpleFrameRate + +logger = logging.getLogger(__name__) + +LeaderFactory = Callable[["SO101OperatorConfig"], Any] + +JOINT_KEYS = ( + "shoulder_pan.pos", + "shoulder_lift.pos", + "elbow_flex.pos", + "wrist_flex.pos", + "wrist_roll.pos", +) +GRIPPER_KEY = "gripper.pos" + + +def _load_so101_leader_classes() -> tuple[type[Any], type[Any]]: + try: + from lerobot.teleoperators.so_leader.config_so_leader import ( + SO101LeaderConfig as LeRobotSO101LeaderConfig, + ) + from lerobot.teleoperators.so_leader.so_leader import ( + SO101Leader as LeRobotSO101Leader, + ) + except ImportError as exc: + msg = ( + "lerobot SO101 leader dependencies are not available. " + "Install lerobot and its teleoperator dependencies to use SO101Operator." + ) + raise ImportError(msg) from exc + + return LeRobotSO101LeaderConfig, LeRobotSO101Leader + + +def default_so101_leader_factory(config: "SO101OperatorConfig") -> Any: + leader_config_cls, leader_cls = _load_so101_leader_classes() + + leader_kwargs: dict[str, Any] = { + "id": config.id, + "port": config.port, + "use_degrees": config.use_degrees, + } + if config.calibration_dir is not None: + leader_kwargs["calibration_dir"] = Path(config.calibration_dir) + + leader = leader_cls(leader_config_cls(**leader_kwargs)) + leader.connect() + return leader + + +class SO101Operator(BaseOperator): + control_mode = (ControlMode.JOINTS, RelativeTo.NONE) + + def __init__(self, config: "SO101OperatorConfig", sim: Sim | None = None): + super().__init__(config, sim) + self.config: SO101OperatorConfig + self._resource_lock = threading.Lock() + self._exit_requested = False + self.controller_names = [self.config.controller_name] + self._last_joints: np.ndarray | None = None + self._last_gripper = 1.0 + self._leader: Any | None = None + self._leader_factory = self.config.leader_factory or default_so101_leader_factory + + @staticmethod + def _leader_action_to_target(action: dict[str, float], use_degrees: bool) -> tuple[np.ndarray, float]: + joints = np.array([action[key] for key in JOINT_KEYS], dtype=np.float64) + if use_degrees: + joints = np.deg2rad(joints) + + gripper = float(np.clip(action.get(GRIPPER_KEY, 100.0) / 100.0, 0.0, 1.0)) + return joints, gripper + + def consume_commands(self) -> TeleopCommands: + return TeleopCommands() + + def reset_operator_state(self): + pass + + def consume_action(self) -> dict[str, Any]: + with self._resource_lock: + if self._last_joints is None: + return {} + + return { + self.config.controller_name: { + "joints": self._last_joints.copy(), + "gripper": np.array([self._last_gripper], dtype=np.float32), + } + } + + def run(self): + try: + self._leader = self._leader_factory(self.config) + except Exception as exc: + logger.error(f"Failed to initialize SO101 leader: {exc}") + return + + rate_limiter = SimpleFrameRate(self.config.read_frequency, "so101 readout") + + while not self._exit_requested: + try: + leader_action = self._leader.get_action() + joints, gripper = self._leader_action_to_target(leader_action, self.config.use_degrees) + with self._resource_lock: + self._last_joints = joints + self._last_gripper = gripper + except Exception as exc: + logger.warning(f"Error reading SO101 leader state: {exc}") + + rate_limiter() + + def close(self): + self._exit_requested = True + if self._leader is not None: + try: + self._leader.disconnect() + except Exception: + logger.debug("Failed to disconnect SO101 leader cleanly.", exc_info=True) + if self.is_alive() and threading.current_thread() != self: + self.join(timeout=1.0) + + +@dataclass(kw_only=True) +class SO101OperatorConfig(BaseOperatorConfig): + operator_class: type[BaseOperator] = field(default=SO101Operator) + controller_name: str = "so101" + id: str = "leader" + port: str = "/dev/ttyACM1" + calibration_dir: str | None = None + use_degrees: bool = True + leader_factory: LeaderFactory | None = None diff --git a/python/tests/test_so101_operator.py b/python/tests/test_so101_operator.py new file mode 100644 index 00000000..c6532bf7 --- /dev/null +++ b/python/tests/test_so101_operator.py @@ -0,0 +1,90 @@ +import time +from importlib.util import module_from_spec, spec_from_file_location +from pathlib import Path + +import numpy as np + +operator_path = Path(__file__).resolve().parents[1] / "rcs" / "operator" / "so101.py" +spec = spec_from_file_location("rcs_local_operator_so101", operator_path) +assert spec is not None and spec.loader is not None +so101_module = module_from_spec(spec) +spec.loader.exec_module(so101_module) +SO101Operator = so101_module.SO101Operator +SO101OperatorConfig = so101_module.SO101OperatorConfig + + +class FakeLeader: + def __init__(self, action): + self._action = action + self.disconnected = False + + def get_action(self): + return self._action + + def disconnect(self): + self.disconnected = True + + +def test_leader_action_to_target_converts_degrees_to_radians(): + joints, gripper = SO101Operator._leader_action_to_target( + { + "shoulder_pan.pos": 90.0, + "shoulder_lift.pos": 0.0, + "elbow_flex.pos": -45.0, + "wrist_flex.pos": 30.0, + "wrist_roll.pos": 180.0, + "gripper.pos": 25.0, + }, + use_degrees=True, + ) + + np.testing.assert_allclose(joints, np.deg2rad([90.0, 0.0, -45.0, 30.0, 180.0])) + assert gripper == 0.25 + + +def test_leader_action_to_target_keeps_non_degree_joint_values(): + joints, gripper = SO101Operator._leader_action_to_target( + { + "shoulder_pan.pos": -100.0, + "shoulder_lift.pos": -50.0, + "elbow_flex.pos": 0.0, + "wrist_flex.pos": 50.0, + "wrist_roll.pos": 100.0, + "gripper.pos": 100.0, + }, + use_degrees=False, + ) + + np.testing.assert_allclose(joints, np.array([-100.0, -50.0, 0.0, 50.0, 100.0])) + assert gripper == 1.0 + + +def test_operator_run_updates_latest_action_and_closes_cleanly(): + fake_leader = FakeLeader( + { + "shoulder_pan.pos": 10.0, + "shoulder_lift.pos": 20.0, + "elbow_flex.pos": 30.0, + "wrist_flex.pos": 40.0, + "wrist_roll.pos": 50.0, + "gripper.pos": 60.0, + } + ) + cfg = SO101OperatorConfig( + read_frequency=200, + leader_factory=lambda _: fake_leader, + ) + operator = SO101Operator(cfg) + + operator.start() + time.sleep(0.05) + action = operator.consume_action() + operator.close() + + assert cfg.controller_name in action + np.testing.assert_allclose( + action[cfg.controller_name]["joints"], + np.deg2rad([10.0, 20.0, 30.0, 40.0, 50.0]), + ) + np.testing.assert_allclose(action[cfg.controller_name]["gripper"], np.array([0.6], dtype=np.float32)) + assert fake_leader.disconnected