diff --git a/.gitignore b/.gitignore index 4045db012e..640b5b9224 100644 --- a/.gitignore +++ b/.gitignore @@ -49,6 +49,9 @@ data/* FastSAM-x.pt yolo11n.pt +# MuJoCo dumps a GL-warning log in the CWD at runtime. +MUJOCO_LOG.TXT + /thread_monitor_report.csv # symlink one of .envrc.* if you'd like to use diff --git a/data/.lfs/command_center.html.tar.gz b/data/.lfs/command_center.html.tar.gz index 9f7bfe1979..eb86adabea 100644 --- a/data/.lfs/command_center.html.tar.gz +++ b/data/.lfs/command_center.html.tar.gz @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:7663ac06572e3b9490859b400e9ddbf45ac3ef52a58fcdb8c2c41936dc9d43b5 -size 137675 +oid sha256:1d8ce5f2a30928e254607e7562eec00cac56e8a1b17c898eca98572b34e2a674 +size 138506 diff --git a/data/.lfs/mujoco_sim.tar.gz b/data/.lfs/mujoco_sim.tar.gz index 57833fbbc6..0abd6ac057 100644 --- a/data/.lfs/mujoco_sim.tar.gz +++ b/data/.lfs/mujoco_sim.tar.gz @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:d178439569ed81dfad05455419dc51da2c52021313b6d7b9259d9e30946db7c6 -size 60186340 +oid sha256:a0ba1b3363c772f5470717b1ce991ee247d7984d3bb2adf15f4c3ef62d90843f +size 60290020 diff --git a/dimos/control/components.py b/dimos/control/components.py index 69dc195f4a..d7c3f0742f 100644 --- a/dimos/control/components.py +++ b/dimos/control/components.py @@ -50,6 +50,9 @@ class HardwareComponent: address: Connection address - IP for TCP, port for CAN auto_enable: Whether to auto-enable servos gripper_joints: Joints that use adapter gripper methods (separate from joints). + domain_id: DDS domain ID for adapters that use DDS transport + (e.g. Unitree G1). Real robot uses 0; unitree_mujoco sim + defaults to 1. Ignored by non-DDS adapters. 
""" hardware_id: HardwareId @@ -59,6 +62,12 @@ class HardwareComponent: address: str | None = None auto_enable: bool = True gripper_joints: list[JointName] = field(default_factory=list) + domain_id: int = 0 + # Per-joint PD gains used by ConnectedWholeBody when translating + # position commands to MotorCommand. None → adapter/component + # defaults. Must match `joints` length when set. + kp: list[float] | None = None + kd: list[float] | None = None @property def all_joints(self) -> list[JointName]: diff --git a/dimos/control/coordinator.py b/dimos/control/coordinator.py index b0efaa4bdc..f78dc35f16 100644 --- a/dimos/control/coordinator.py +++ b/dimos/control/coordinator.py @@ -60,6 +60,7 @@ from dimos.msgs.sensor_msgs import ( JointState, ) +from dimos.msgs.std_msgs.Bool import Bool from dimos.teleop.quest.quest_types import ( Buttons, ) @@ -83,22 +84,26 @@ class TaskConfig: Attributes: name: Task name (e.g., "traj_arm") - type: Task type ("trajectory", "servo", "velocity", "cartesian_ik", "teleop_ik") + type: Task type ("trajectory", "servo", "velocity", "cartesian_ik", "teleop_ik", "groot_wbc") joint_names: List of joint names this task controls priority: Task priority (higher wins arbitration) - model_path: Path to URDF/MJCF for IK solver (cartesian_ik/teleop_ik only) + model_path: Path to URDF/MJCF for IK solver (cartesian_ik/teleop_ik) + or directory containing balance.onnx/walk.onnx (groot_wbc). ee_joint_id: End-effector joint ID in model (cartesian_ik/teleop_ik only) hand: "left" or "right" controller hand (teleop_ik only) gripper_joint: Joint name for gripper virtual joint gripper_open_pos: Gripper position at trigger 0.0 gripper_closed_pos: Gripper position at trigger 1.0 + hardware_id: Hardware id this task reads extra state from + (required by groot_wbc — pulls the WholeBodyAdapter for IMU + and the full joint list for observation assembly). 
""" name: str type: str = "trajectory" joint_names: list[str] = field(default_factory=lambda: []) priority: int = 10 - # Cartesian IK / Teleop IK specific + # Cartesian IK / Teleop IK / GR00T WBC specific model_path: str | Path | None = None ee_joint_id: int = 6 hand: Literal["left", "right"] | None = None # teleop_ik only @@ -106,6 +111,27 @@ class TaskConfig: gripper_joint: str | None = None gripper_open_pos: float = 0.0 gripper_closed_pos: float = 0.0 + # Tasks that need a hardware reference (e.g. groot_wbc for IMU + 29-DOF state) + hardware_id: str | None = None + # Servo task: optional initial target held until/unless a new one arrives. + default_positions: list[float] | None = None + # Call ``task.start()`` right after registration so the task is live + # from the first tick (e.g. GR00T balance/walk needs to drive joints + # immediately). Default False keeps the existing convention where + # tasks wait for an explicit activation (e.g. from teleop). + auto_start: bool = False + # Arm the task's policy automatically on ``start()`` (applies to + # tasks exposing ``arm()``, e.g. ``GrootWBCTask``). Simulation + # blueprints set this True; real-hardware blueprints leave it False + # so the operator arms via dashboard button after settling. + auto_arm: bool = False + # Start the task in dry-run mode (policy computes but output is + # suppressed). For real-hardware safety checks. + auto_dry_run: bool = False + # Ramp duration (seconds) used by ``arm()`` when called without an + # explicit argument — applies to tasks that interpolate from the + # current pose toward a default on arming. + default_ramp_seconds: float = 10.0 @dataclass @@ -178,6 +204,17 @@ class ControlCoordinator(Module[ControlCoordinatorConfig]): # Input: Teleop buttons for engage/disengage signaling buttons: In[Buttons] + # Input: Arm/disarm velocity-policy tasks (e.g. GrootWBCTask). True + # → task.arm(); False → task.disarm(). 
Routed to every task that + # duck-types an ``arm`` method (and ``disarm`` for False). + activate: In[Bool] + + # Input: Toggle dry-run on velocity-policy tasks. In dry-run the + # policy keeps computing but the coordinator forwards no command to + # the adapter — operators use this to sanity-check commands on real + # hardware before committing motor torques. + dry_run: In[Bool] + config: ControlCoordinatorConfig default_config = ControlCoordinatorConfig @@ -203,6 +240,8 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: self._cartesian_command_unsub: Callable[[], None] | None = None self._twist_command_unsub: Callable[[], None] | None = None self._buttons_unsub: Callable[[], None] | None = None + self._activate_unsub: Callable[[], None] | None = None + self._dry_run_unsub: Callable[[], None] | None = None logger.info(f"ControlCoordinator initialized at {self.config.tick_rate}Hz") @@ -222,6 +261,10 @@ def _setup_from_config(self) -> None: for task_cfg in self.config.tasks: task = self._create_task_from_config(task_cfg) self.add_task(task) + if task_cfg.auto_start: + start = getattr(task, "start", None) + if callable(start): + start() except Exception: # Rollback: clean up all successfully added hardware @@ -288,6 +331,7 @@ def _create_whole_body_adapter(self, component: HardwareComponent) -> object: return whole_body_adapter_registry.create( component.adapter_type, network_interface=addr if addr is not None else 0, + domain_id=component.domain_id, ) def _create_task_from_config(self, cfg: TaskConfig) -> ControlTask: @@ -308,12 +352,18 @@ def _create_task_from_config(self, cfg: TaskConfig) -> ControlTask: elif task_type == "servo": from dimos.control.tasks import JointServoTask, JointServoTaskConfig + servo_cfg_kwargs: dict[str, object] = { + "joint_names": cfg.joint_names, + "priority": cfg.priority, + } + if cfg.default_positions is not None: + servo_cfg_kwargs["default_positions"] = cfg.default_positions + # Zero timeout pairs naturally with default-hold 
— otherwise + # the task times out even though it's holding a valid target. + servo_cfg_kwargs["timeout"] = 0.0 return JointServoTask( cfg.name, - JointServoTaskConfig( - joint_names=cfg.joint_names, - priority=cfg.priority, - ), + JointServoTaskConfig(**servo_cfg_kwargs), # type: ignore[arg-type] ) elif task_type == "velocity": @@ -363,6 +413,43 @@ def _create_task_from_config(self, cfg: TaskConfig) -> ControlTask: ), ) + elif task_type == "groot_wbc": + from dimos.control.tasks.groot_wbc_task import ( + GrootWBCTask, + GrootWBCTaskConfig, + ) + + if cfg.model_path is None: + raise ValueError( + f"GrootWBCTask '{cfg.name}' requires model_path " + f"(directory containing balance.onnx + walk.onnx)" + ) + if cfg.hardware_id is None: + raise ValueError(f"GrootWBCTask '{cfg.name}' requires hardware_id in TaskConfig") + hw = self._hardware.get(cfg.hardware_id) + if hw is None: + raise ValueError( + f"GrootWBCTask '{cfg.name}' references unknown hardware " + f"'{cfg.hardware_id}'. List the hardware before the task " + f"in the blueprint config." + ) + + model_dir = Path(cfg.model_path) + return GrootWBCTask( + cfg.name, + GrootWBCTaskConfig( + balance_onnx=model_dir / "balance.onnx", + walk_onnx=model_dir / "walk.onnx", + joint_names=cfg.joint_names, + all_joint_names=hw.joint_names, + priority=cfg.priority, + auto_arm=cfg.auto_arm, + auto_dry_run=cfg.auto_dry_run, + default_ramp_seconds=cfg.default_ramp_seconds, + ), + adapter=hw.adapter, + ) + else: raise ValueError(f"Unknown task type: {task_type}") @@ -612,12 +699,52 @@ def _on_twist_command(self, msg: Twist) -> None: joint_state = JointState(name=names, velocity=velocities) self._on_joint_command(joint_state) + # Also route to tasks that accept a (vx, vy, yaw_rate) command — + # e.g. locomotion policies like GrootWBCTask. Duck-typed: any + # task exposing set_velocity_command opts in. 
+ t_now = time.perf_counter() + with self._task_lock: + for task in self._tasks.values(): + set_vel = getattr(task, "set_velocity_command", None) + if set_vel is not None: + set_vel(msg.linear.x, msg.linear.y, msg.angular.z, t_now) + def _on_buttons(self, msg: Buttons) -> None: """Forward button state to all tasks.""" with self._task_lock: for task in self._tasks.values(): task.on_buttons(msg) + def _on_activate(self, msg: Bool) -> None: + """Arm/disarm every task exposing ``arm()`` / ``disarm()``. + + Duck-typed to match the ``set_velocity_command`` convention used + by ``_on_twist_command``. The blueprint wires this input to a + dashboard button; operators can also drive it directly via LCM. + """ + engage = bool(msg.data) + with self._task_lock: + for task in self._tasks.values(): + method_name = "arm" if engage else "disarm" + handler = getattr(task, method_name, None) + if callable(handler): + try: + handler() + except Exception: + logger.exception(f"{method_name}() raised on task {task.name!r}") + + def _on_dry_run(self, msg: Bool) -> None: + """Forward dry-run toggle to every task exposing ``set_dry_run``.""" + enabled = bool(msg.data) + with self._task_lock: + for task in self._tasks.values(): + handler = getattr(task, "set_dry_run", None) + if callable(handler): + try: + handler(enabled) + except Exception: + logger.exception(f"set_dry_run() raised on task {task.name!r}") + @rpc def task_invoke( self, task_name: TaskName, method: str, kwargs: dict[str, Any] | None = None @@ -737,16 +864,25 @@ def start(self) -> None: "Use task_invoke RPC or set transport via blueprint." ) - # Subscribe to twist commands if any twist base hardware configured + # Subscribe to twist commands if any twist base hardware is configured + # OR if any task accepts velocity commands (locomotion policies like + # GrootWBCTask duck-type with set_velocity_command). 
Without the + # latter check, a whole-body locomotion blueprint with no BASE + # hardware silently drops every Twist on /cmd_vel. has_twist_base = any(c.hardware_type == HardwareType.BASE for c in self.config.hardware) - if has_twist_base: + with self._task_lock: + has_velocity_task = any( + callable(getattr(task, "set_velocity_command", None)) + for task in self._tasks.values() + ) + if has_twist_base or has_velocity_task: try: self._twist_command_unsub = self.twist_command.subscribe(self._on_twist_command) - logger.info("Subscribed to twist_command for twist base control") + logger.info("Subscribed to twist_command for twist base / velocity-capable tasks") except Exception: logger.warning( - "Twist base configured but could not subscribe to twist_command. " - "Use task_invoke RPC or set transport via blueprint." + "Twist base or velocity-capable task configured but could not subscribe " + "to twist_command. Use task_invoke RPC or set transport via blueprint." ) # Subscribe to buttons if any teleop_ik tasks configured (engage/disengage) @@ -755,6 +891,32 @@ def start(self) -> None: self._buttons_unsub = self.buttons.subscribe(self._on_buttons) logger.info("Subscribed to buttons for engage/disengage") + # Subscribe to activate / dry_run if any task exposes arm() / set_dry_run() + # (duck-typed, same convention as twist_command / set_velocity_command). + with self._task_lock: + has_arm = any(callable(getattr(t, "arm", None)) for t in self._tasks.values()) + has_dry_run = any( + callable(getattr(t, "set_dry_run", None)) for t in self._tasks.values() + ) + if has_arm: + try: + self._activate_unsub = self.activate.subscribe(self._on_activate) + logger.info("Subscribed to activate for arm()/disarm() routing") + except Exception: + logger.warning( + "Arm-capable task configured but could not subscribe to activate. " + "Use task_invoke RPC or set transport via blueprint." 
+ ) + if has_dry_run: + try: + self._dry_run_unsub = self.dry_run.subscribe(self._on_dry_run) + logger.info("Subscribed to dry_run for dry-run routing") + except Exception: + logger.warning( + "Dry-run-capable task configured but could not subscribe to dry_run. " + "Use task_invoke RPC or set transport via blueprint." + ) + logger.info(f"ControlCoordinator started at {self.config.tick_rate}Hz") @rpc @@ -772,6 +934,12 @@ def stop(self) -> None: if self._twist_command_unsub: self._twist_command_unsub() self._twist_command_unsub = None + if self._activate_unsub: + self._activate_unsub() + self._activate_unsub = None + if self._dry_run_unsub: + self._dry_run_unsub() + self._dry_run_unsub = None if self._buttons_unsub: self._buttons_unsub() self._buttons_unsub = None diff --git a/dimos/control/examples/go2_standup.py b/dimos/control/examples/go2_standup.py index e85634c789..2414058e68 100644 --- a/dimos/control/examples/go2_standup.py +++ b/dimos/control/examples/go2_standup.py @@ -46,24 +46,48 @@ # Order: FR_0,1,2 FL_0,1,2 RR_0,1,2 RL_0,1,2 # --------------------------------------------------------------------------- POS_CROUCH = [ - 0.0, 1.36, -2.65, # FR - 0.0, 1.36, -2.65, # FL - -0.2, 1.36, -2.65, # RR - 0.2, 1.36, -2.65, # RL + 0.0, + 1.36, + -2.65, # FR + 0.0, + 1.36, + -2.65, # FL + -0.2, + 1.36, + -2.65, # RR + 0.2, + 1.36, + -2.65, # RL ] POS_STAND = [ - 0.0, 0.67, -1.3, # FR - 0.0, 0.67, -1.3, # FL - 0.0, 0.67, -1.3, # RR - 0.0, 0.67, -1.3, # RL + 0.0, + 0.67, + -1.3, # FR + 0.0, + 0.67, + -1.3, # FL + 0.0, + 0.67, + -1.3, # RR + 0.0, + 0.67, + -1.3, # RL ] POS_SHIFT = [ - -0.35, 1.36, -2.65, # FR - 0.35, 1.36, -2.65, # FL - -0.5, 1.36, -2.65, # RR - 0.5, 1.36, -2.65, # RL + -0.35, + 1.36, + -2.65, # FR + 0.35, + 1.36, + -2.65, # FL + -0.5, + 1.36, + -2.65, # RR + 0.5, + 1.36, + -2.65, # RL ] # Per-leg crouch positions (used for single-leg lower) @@ -72,21 +96,21 @@ LEG_STAND = [0.0, 0.67, -1.3] # Leg index ranges within the 12-DOF array -LEG_FR = slice(0, 3) 
# indices 0,1,2 -LEG_FL = slice(3, 6) # indices 3,4,5 -LEG_RR = slice(6, 9) # indices 6,7,8 +LEG_FR = slice(0, 3) # indices 0,1,2 +LEG_FL = slice(3, 6) # indices 3,4,5 +LEG_RR = slice(6, 9) # indices 6,7,8 LEG_RL = slice(9, 12) # indices 9,10,11 # Phase durations in ticks (at CMD_HZ) CMD_HZ = 50 # Command publish rate (Hz) -PHASE_1_TICKS = 250 # current → crouch -PHASE_2_TICKS = 250 # crouch → stand -PHASE_3_TICKS = 400 # hold stand -PHASE_4_TICKS = 400 # stand → shift +PHASE_1_TICKS = 250 # current → crouch +PHASE_2_TICKS = 250 # crouch → stand +PHASE_3_TICKS = 400 # hold stand +PHASE_4_TICKS = 400 # stand → shift LEG_LOWER_TICKS = 150 # per-leg lower/raise -LEG_HOLD_TICKS = 100 # hold at bottom per leg +LEG_HOLD_TICKS = 100 # hold at bottom per leg DIAG_LOWER_TICKS = 200 # diagonal pair lower/raise -DIAG_HOLD_TICKS = 150 # hold at bottom diagonal +DIAG_HOLD_TICKS = 150 # hold at bottom diagonal class Go2LowLevelControl(Module): @@ -201,9 +225,7 @@ def _run(self) -> None: # Phase 5: Lower each leg one by one logger.info("Phase 5: single leg lowers (FR → FL → RR → RL)") - for leg_name, leg_slice in [ - ("FR", LEG_FR), ("FL", LEG_FL), ("RR", LEG_RR), ("RL", LEG_RL) - ]: + for leg_name, leg_slice in [("FR", LEG_FR), ("FL", LEG_FL), ("RR", LEG_RR), ("RL", LEG_RL)]: if self._stop_event.is_set(): return logger.info(f" Lowering {leg_name}...") @@ -211,9 +233,7 @@ def _run(self) -> None: # Phase 6: Diagonal pairs logger.info("Phase 6: diagonal pair lowers (FR+RL, then FL+RR)") - for pair_name, slices in [ - ("FR+RL", [LEG_FR, LEG_RL]), ("FL+RR", [LEG_FL, LEG_RR]) - ]: + for pair_name, slices in [("FR+RL", [LEG_FR, LEG_RL]), ("FL+RR", [LEG_FL, LEG_RR])]: if self._stop_event.is_set(): return logger.info(f" Lowering {pair_name}...") @@ -239,7 +259,7 @@ def _interp( if self._stop_event.is_set(): return alpha = (t + 1) / ticks - pos = [(1 - alpha) * s + alpha * e for s, e in zip(start, end)] + pos = [(1 - alpha) * s + alpha * e for s, e in zip(start, end, strict=False)] 
self._pub(names, pos) time.sleep(dt) diff --git a/dimos/control/hardware_interface.py b/dimos/control/hardware_interface.py index f4c13685a2..bdd553a0d5 100644 --- a/dimos/control/hardware_interface.py +++ b/dimos/control/hardware_interface.py @@ -351,6 +351,30 @@ def __init__( self._component = component self._joint_names = component.joints + # Resolve per-joint PD gains once at wire-up time. Falls back + # to _DEFAULT_KP / _DEFAULT_KD if the blueprint didn't specify. + n = len(self._joint_names) + if component.kp is not None: + if len(component.kp) != n: + raise ValueError( + f"HardwareComponent '{component.hardware_id}': kp length " + f"{len(component.kp)} does not match joints length {n}" + ) + self._kp = list(component.kp) + else: + self._kp = [_DEFAULT_KP] * n + if component.kd is not None: + if len(component.kd) != n: + raise ValueError( + f"HardwareComponent '{component.hardware_id}': kd length " + f"{len(component.kd)} does not match joints length {n}" + ) + self._kd = list(component.kd) + else: + self._kd = [_DEFAULT_KD] * n + self._kp_by_name = dict(zip(self._joint_names, self._kp, strict=False)) + self._kd_by_name = dict(zip(self._joint_names, self._kd, strict=False)) + self._last_commanded: dict[str, float] = {} self._initialized = False self._warned_unknown_joints: set[str] = set() @@ -408,8 +432,8 @@ def write_command(self, commands: dict[str, float], _mode: ControlMode) -> bool: MotorCommand( q=self._last_commanded[name], dq=0.0, - kp=_DEFAULT_KP, - kd=_DEFAULT_KD, + kp=self._kp_by_name[name], + kd=self._kd_by_name[name], tau=0.0, ) for name in self._joint_names diff --git a/dimos/control/tasks/groot_wbc_task.py b/dimos/control/tasks/groot_wbc_task.py new file mode 100644 index 0000000000..5a321b9ddf --- /dev/null +++ b/dimos/control/tasks/groot_wbc_task.py @@ -0,0 +1,642 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""GR00T whole-body-control task for the Unitree G1 humanoid. + +Runs the two-model GR00T WBC locomotion policy (balance + walk) inside +the coordinator tick loop. Claims the 15 legs+waist joints at high +priority; arm joints are left to lower-priority tasks in the blueprint. + +Reference implementation: g1_control/backends/groot_wbc_backend.py. +Observation, action, and model-selection semantics are preserved +verbatim — changing them drifts us away from the ONNX policies trained +by GR00T-WholeBodyControl. + +CRITICAL: Uses t_now from CoordinatorState, never calls time.time(). +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +import threading +from typing import TYPE_CHECKING + +import numpy as np +import onnxruntime as ort # type: ignore[import-untyped] + +from dimos.control.task import ( + BaseControlTask, + ControlMode, + CoordinatorState, + JointCommandOutput, + ResourceClaim, +) +from dimos.utils.logging_config import setup_logger + +if TYPE_CHECKING: + from pathlib import Path + + from dimos.hardware.whole_body.spec import WholeBodyAdapter + from dimos.msgs.geometry_msgs import Twist + +logger = setup_logger() + + +# Default joint angles copied verbatim from +# g1_control/backends/groot_wbc_backend.py DEFAULT_29. Policy was trained +# against these as the zero-offset pose. 
+_DEFAULT_POSITIONS_29 = [ + -0.1, + 0.0, + 0.0, + 0.3, + -0.2, + 0.0, # left leg + -0.1, + 0.0, + 0.0, + 0.3, + -0.2, + 0.0, # right leg + 0.0, + 0.0, + 0.0, # waist + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, # left arm (not driven by policy) + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, # right arm (not driven by policy) +] + +_SINGLE_OBS_DIM = 86 +_OBS_HISTORY_LEN = 6 +_NUM_ACTIONS = 15 +_NUM_MOTORS = 29 + + +@dataclass +class GrootWBCTaskConfig: + """Configuration for the GR00T WBC task. + + Attributes: + balance_onnx: Path to the balance ONNX model. Used when + ``||cmd|| <= cmd_norm_threshold``. + walk_onnx: Path to the walk ONNX model. Used otherwise. + joint_names: The 15 coordinator joint names this task claims + (legs 0-11 + waist 12-14, in DDS order). + all_joint_names: All 29 coordinator joint names in DDS order + (legs 0-11 + waist 12-14 + arms 15-28). Required to build + the observation, which feeds all 29 joint states. + default_positions_29: Default joint angles for all 29 joints + (DDS order). First 15 are the policy's zero-offset pose. + priority: Arbitration priority (higher wins). 50 is the + recommended WBC priority per the task.py conventions. + decimation: Run inference every N ticks. At 500 Hz tick / + 50 Hz policy → decimation=10. + action_scale: Multiplier on raw policy output before adding + defaults. + obs_ang_vel_scale: Scale for base angular velocity in obs. + obs_dof_pos_scale: Scale for joint position offset in obs. + obs_dof_vel_scale: Scale for joint velocity in obs. + cmd_scale: Per-axis scale applied to (vx, vy, wz) in obs. + cmd_norm_threshold: ||cmd|| below this selects the balance + model, otherwise walk. + height_cmd: Fixed height command slot in obs. + timeout: Seconds without a velocity command before zeroing it. + auto_arm: Arm the policy automatically on ``start()``. Default + False — safe for real hardware; the blueprint sets True for + simulation. + auto_dry_run: Enter dry-run mode on ``start()``. 
Policy still + runs but outputs are not emitted to the adapter — useful for + verifying on real hardware without commanding motors. + default_ramp_seconds: Duration of the arming ramp (current pose + → ``default_15``) when ``arm()`` is called without an + explicit duration. Set to 0 in simulation (no ramp needed); + 10 s on real hardware mirrors the g1-control-api default. + """ + + balance_onnx: str | Path + walk_onnx: str | Path + joint_names: list[str] + all_joint_names: list[str] + default_positions_29: list[float] = field(default_factory=lambda: list(_DEFAULT_POSITIONS_29)) + priority: int = 50 + decimation: int = 10 + action_scale: float = 0.25 + obs_ang_vel_scale: float = 0.5 + obs_dof_pos_scale: float = 1.0 + obs_dof_vel_scale: float = 0.05 + cmd_scale: tuple[float, float, float] = (2.0, 2.0, 0.5) + cmd_norm_threshold: float = 0.05 + height_cmd: float = 0.74 + timeout: float = 1.0 + auto_arm: bool = False + auto_dry_run: bool = False + default_ramp_seconds: float = 10.0 + + +class GrootWBCTask(BaseControlTask): + """Runs the GR00T balance / walk ONNX policies inside the coordinator tick loop. + + Observation vector (86 dims, built each inference tick, replicates + ``groot_wbc_backend.GrootWBCBackend._compute_obs`` verbatim): + + [0:3] cmd_vel * cmd_scale # scaled velocity command + [3] height_cmd # fixed slot (0.74) + [4:7] (0, 0, 0) # rpy_cmd, zeros + [7:10] gyro * obs_ang_vel_scale # body-frame ang vel + [10:13] projected_gravity(quat) # gravity in body frame + [13:42] (q_29 - default_29) * dof_pos_scale + [42:71] dq_29 * dof_vel_scale + [71:86] last_action (15 dims) + + The observation is stacked into a 6-frame history buffer (516 dims) + before being fed to ONNX. + + Action (15 dims, legs + waist only): + + target_q_15 = action * action_scale + default_15 + + Arms are NOT driven by this task — the blueprint pairs this task + with a lower-priority servo task scoped to the 14 arm joints. 
+ """ + + def __init__( + self, + name: str, + config: GrootWBCTaskConfig, + adapter: WholeBodyAdapter, + ) -> None: + if len(config.joint_names) != _NUM_ACTIONS: + raise ValueError( + f"GrootWBCTask '{name}' requires exactly {_NUM_ACTIONS} joint names " + f"(legs + waist), got {len(config.joint_names)}" + ) + if len(config.all_joint_names) != _NUM_MOTORS: + raise ValueError( + f"GrootWBCTask '{name}' requires exactly {_NUM_MOTORS} all_joint_names " + f"(full 29-DOF G1), got {len(config.all_joint_names)}" + ) + if len(config.default_positions_29) != _NUM_MOTORS: + raise ValueError( + f"GrootWBCTask '{name}' requires exactly {_NUM_MOTORS} " + f"default_positions_29, got {len(config.default_positions_29)}" + ) + if config.decimation < 1: + raise ValueError(f"GrootWBCTask '{name}' requires decimation >= 1") + + self._name = name + self._config = config + self._adapter = adapter + self._joint_names_list = list(config.joint_names) + self._joint_names_set = frozenset(config.joint_names) + self._all_joint_names = list(config.all_joint_names) + + providers = ort.get_available_providers() + self._balance_session = ort.InferenceSession(str(config.balance_onnx), providers=providers) + self._walk_session = ort.InferenceSession(str(config.walk_onnx), providers=providers) + self._balance_input = self._balance_session.get_inputs()[0].name + self._walk_input = self._walk_session.get_inputs()[0].name + logger.info( + f"GrootWBCTask '{name}' loaded balance={config.balance_onnx}, " + f"walk={config.walk_onnx} (providers: {providers})" + ) + + self._default_29 = np.asarray(config.default_positions_29, dtype=np.float32) + self._default_15 = self._default_29[:_NUM_ACTIONS] + self._cmd_scale = np.asarray(config.cmd_scale, dtype=np.float32) + + # Inference state + self._last_action = np.zeros(_NUM_ACTIONS, dtype=np.float32) + self._obs_buf = np.zeros((1, _SINGLE_OBS_DIM * _OBS_HISTORY_LEN), dtype=np.float32) + self._first_inference = True + self._tick_count = 0 + self._last_targets: 
list[float] | None = None + + # Lifecycle state machine. + # + # _active — task is registered and compute() is being invoked + # by the coordinator. Gate for the whole compute() + # path; start()/stop() toggle this. + # _armed — policy outputs are emitted to the adapter. Flip + # via arm()/disarm(). + # _arming — currently ramping current-pose → default_15 over + # ``_arming_duration`` seconds. Set by arm() with + # a non-zero ramp, cleared when alpha reaches 1. + # _arm_pending — arm() was called; compute() captures the ramp + # start pose from state on the next tick and flips + # into _arming (or _armed directly if ramp=0). + # _dry_run — compute() still runs the policy (obs history + # stays hot) but returns None so the coordinator + # sends no command to the adapter. Throttled log + # lets the operator see what WOULD have gone out. + # + # When active-but-unarmed, compute() echoes back the current + # joint positions so the PD error is zero and the robot sits in + # pure damping (kd-only) — this mirrors the reference backend's + # "hold current pose" inactive state. 
+ self._active = False + self._armed = False + self._arming = False + self._arm_pending = False + self._dry_run = bool(config.auto_dry_run) + self._arming_duration = 0.0 + self._arming_start_t = 0.0 + self._ramp_start: np.ndarray | None = None + self._last_dry_run_log_t: float = 0.0 + + self._cmd_lock = threading.Lock() + self._cmd = np.zeros(3, dtype=np.float32) + self._last_cmd_time: float = 0.0 + + @property + def name(self) -> str: + return self._name + + def claim(self) -> ResourceClaim: + return ResourceClaim( + joints=self._joint_names_set, + priority=self._config.priority, + mode=ControlMode.SERVO_POSITION, + ) + + def is_active(self) -> bool: + return self._active + + def compute(self, state: CoordinatorState) -> JointCommandOutput | None: + if not self._active: + return None + + # Read our 15 claimed joints' current positions — needed for the + # hold-pose / ramp-start / unarmed echo paths below. + current_15 = np.zeros(_NUM_ACTIONS, dtype=np.float32) + for i, jname in enumerate(self._joint_names_list): + pos = state.joints.get_position(jname) + current_15[i] = pos if pos is not None else 0.0 + + # ------------------------------------------------------------------ + # arm() was called — snapshot the ramp start and enter arming / + # armed state (ramp=0 arms immediately). + # ------------------------------------------------------------------ + if self._arm_pending: + self._ramp_start = current_15.copy() + self._arming_start_t = state.t_now + if self._arming_duration > 0.0: + self._arming = True + self._armed = False + logger.info( + f"GrootWBCTask '{self._name}' arming: " + f"ramp → default_15 over {self._arming_duration:.1f}s" + ) + else: + self._arming = False + self._armed = True + self._reset_policy_state() + logger.info(f"GrootWBCTask '{self._name}' armed (no ramp)") + self._arm_pending = False + + # ------------------------------------------------------------------ + # Unarmed & not arming: echo current joint positions. 
With the + # component's kp/kd applied downstream, q_tgt == q_actual yields + # pure damping (tau = -kd * dq), which mirrors the reference + # backend's inactive "hold current pose" behaviour. + # ------------------------------------------------------------------ + if not self._armed and not self._arming: + self._last_targets = current_15.tolist() + return JointCommandOutput( + joint_names=self._joint_names_list, + positions=self._last_targets, + mode=ControlMode.SERVO_POSITION, + ) + + # ------------------------------------------------------------------ + # Arming: lerp ramp_start → default_15 over arming_duration. + # ------------------------------------------------------------------ + if self._arming: + assert self._ramp_start is not None + elapsed = state.t_now - self._arming_start_t + alpha = ( + 1.0 if self._arming_duration <= 0.0 else min(1.0, elapsed / self._arming_duration) + ) + target = self._ramp_start + alpha * (self._default_15 - self._ramp_start) + self._last_targets = target.tolist() + if alpha >= 1.0: + self._arming = False + self._armed = True + self._reset_policy_state() + logger.info( + f"GrootWBCTask '{self._name}' ramp complete — policy armed " + f"({'dry-run' if self._dry_run else 'live'})" + ) + return JointCommandOutput( + joint_names=self._joint_names_list, + positions=self._last_targets, + mode=ControlMode.SERVO_POSITION, + ) + + # ------------------------------------------------------------------ + # Armed: run the policy. In dry-run mode we still compute (so + # the obs buffer stays hot), but return None so no command goes + # downstream. A throttled log line shows what WOULD have been + # sent, which is how g1-control-api lets operators verify pre-go. + # ------------------------------------------------------------------ + self._tick_count += 1 + + # Decimation: only run inference every N ticks. Between inference + # ticks, re-emit the last target so the coordinator keeps driving + # the joints (or nothing, in dry-run). 
+ if self._tick_count % self._config.decimation != 0: + if self._dry_run or self._last_targets is None: + return None + return JointCommandOutput( + joint_names=self._joint_names_list, + positions=self._last_targets, + mode=ControlMode.SERVO_POSITION, + ) + + # Read all 29 joints from CoordinatorState in DDS order. + q_29 = np.zeros(_NUM_MOTORS, dtype=np.float32) + dq_29 = np.zeros(_NUM_MOTORS, dtype=np.float32) + for i, jname in enumerate(self._all_joint_names): + pos = state.joints.get_position(jname) + vel = state.joints.get_velocity(jname) + q_29[i] = pos if pos is not None else 0.0 + dq_29[i] = vel if vel is not None else 0.0 + + # IMU comes from the adapter, not CoordinatorState. + imu = self._adapter.read_imu() + gyro = np.asarray(imu.gyroscope, dtype=np.float32) + gravity = self._projected_gravity(imu.quaternion) + + # Velocity command (with timeout → zero). + with self._cmd_lock: + if ( + self._config.timeout > 0.0 + and self._last_cmd_time > 0.0 + and (state.t_now - self._last_cmd_time) > self._config.timeout + ): + cmd = np.zeros(3, dtype=np.float32) + else: + cmd = self._cmd.copy() + + obs = self._build_obs(cmd=cmd, gyro=gyro, gravity=gravity, q=q_29, dq=dq_29) + + # History buffer: first inference fills all slots with the current + # obs (warm-start); subsequent ticks roll the window. + if self._first_inference: + tiled = np.tile(obs, _OBS_HISTORY_LEN) + self._obs_buf[0, :] = tiled + self._first_inference = False + else: + self._obs_buf[0, : _SINGLE_OBS_DIM * (_OBS_HISTORY_LEN - 1)] = self._obs_buf[ + 0, _SINGLE_OBS_DIM: + ] + self._obs_buf[0, _SINGLE_OBS_DIM * (_OBS_HISTORY_LEN - 1) :] = obs + + # Model selection: balance when near-stationary, walk otherwise. 
+ cmd_norm = float(np.linalg.norm(cmd)) + if cmd_norm <= self._config.cmd_norm_threshold: + raw = self._balance_session.run(None, {self._balance_input: self._obs_buf})[0] + else: + raw = self._walk_session.run(None, {self._walk_input: self._obs_buf})[0] + + action = raw[0, :_NUM_ACTIONS].astype(np.float32) + self._last_action[:] = action + + target_q_15 = action * self._config.action_scale + self._default_15 + self._last_targets = target_q_15.tolist() + + if self._dry_run: + # Throttled peek at the commanded pose so the operator can + # decide whether it looks sane before flipping dry-run off. + if (state.t_now - self._last_dry_run_log_t) >= 1.0: + max_delta = float(np.max(np.abs(target_q_15 - current_15))) + logger.info( + f"GrootWBCTask '{self._name}' DRY-RUN (|Δq|_max={max_delta:.3f} rad, " + f"model={'walk' if cmd_norm > self._config.cmd_norm_threshold else 'balance'})" + ) + self._last_dry_run_log_t = state.t_now + return None + + return JointCommandOutput( + joint_names=self._joint_names_list, + positions=self._last_targets, + mode=ControlMode.SERVO_POSITION, + ) + + def on_preempted(self, by_task: str, joints: frozenset[str]) -> None: + if joints & self._joint_names_set: + logger.warning(f"GrootWBCTask '{self._name}' preempted by {by_task} on {joints}") + + # ------------------------------------------------------------------ + # Velocity command input + # ------------------------------------------------------------------ + + def set_velocity_command(self, vx: float, vy: float, yaw_rate: float, t_now: float) -> None: + """Set the (vx, vy, yaw_rate) commanded to the policy. + + Called by the coordinator's twist_command dispatcher and by + external Python callers. Thread-safe. + """ + with self._cmd_lock: + self._cmd[:] = [vx, vy, yaw_rate] + self._last_cmd_time = t_now + + def on_twist(self, msg: Twist, t_now: float) -> bool: + """Accept a Twist message, e.g. 
from an LCM cmd_vel transport.""" + self.set_velocity_command( + float(msg.linear.x), + float(msg.linear.y), + float(msg.angular.z), + t_now, + ) + return True + + # ------------------------------------------------------------------ + # Lifecycle + # ------------------------------------------------------------------ + + def start(self) -> None: + """Enter the coordinator tick loop. + + Starts in "active but unarmed" — compute() echoes current joint + positions every tick, which (combined with the component's + kp/kd) produces damping-only behaviour on real hardware (the + robot sits quietly in dev mode). + + If ``config.auto_arm`` is set, schedules an immediate + ``arm()`` using ``config.default_ramp_seconds`` — this is how + the simulation blueprint bypasses the activation ritual. + If ``config.auto_dry_run`` is set, starts in dry-run mode. + """ + self._active = True + self._armed = False + self._arming = False + self._arm_pending = False + self._dry_run = bool(self._config.auto_dry_run) + self._last_targets = None + self._reset_policy_state() + with self._cmd_lock: + self._cmd[:] = 0.0 + self._last_cmd_time = 0.0 + logger.info( + f"GrootWBCTask '{self._name}' started (unarmed" + + (", dry-run" if self._dry_run else "") + + ")" + ) + if self._config.auto_arm: + self.arm(self._config.default_ramp_seconds) + + def stop(self) -> None: + """Leave the tick loop. Re-activation resets policy state.""" + self._active = False + self._armed = False + self._arming = False + self._arm_pending = False + self._last_targets = None + logger.info(f"GrootWBCTask '{self._name}' stopped") + + # ------------------------------------------------------------------ + # Arming / dry-run (RPC-callable via coordinator.task_invoke) + # ------------------------------------------------------------------ + + def arm(self, ramp_seconds: float | None = None) -> bool: + """Begin the arming sequence. 
+ + ``compute()`` will snapshot the current joint positions on the + next tick, lerp toward ``default_15`` over ``ramp_seconds``, + then flip ``_armed`` true and hand control to the ONNX policy. + A ramp of 0 arms immediately with no interpolation (what sim + uses — the subprocess already holds the MJCF's default pose). + + Safe to call redundantly; subsequent calls while already armed + are ignored. No-op if the task is not ``_active``. + """ + if not self._active: + logger.warning(f"GrootWBCTask '{self._name}' arm() called before start() — ignoring") + return False + if self._armed: + logger.info(f"GrootWBCTask '{self._name}' already armed — arm() ignored") + return False + ramp = ramp_seconds if ramp_seconds is not None else self._config.default_ramp_seconds + self._arming_duration = max(0.0, float(ramp)) + self._arm_pending = True + logger.info( + f"GrootWBCTask '{self._name}' arm requested (ramp={self._arming_duration:.1f}s)" + ) + return True + + def disarm(self) -> bool: + """Stop emitting policy outputs; fall back to hold-current-pose. + + Called either from an operator ``Disarm`` button or from + safety watchdogs. Resets obs history so the next ``arm()`` + starts with a clean buffer. + """ + if not self._armed and not self._arming and not self._arm_pending: + return False + self._armed = False + self._arming = False + self._arm_pending = False + self._ramp_start = None + self._reset_policy_state() + logger.info(f"GrootWBCTask '{self._name}' disarmed (holding current pose)") + return True + + def set_dry_run(self, enabled: bool) -> None: + """Enable/disable dry-run. + + In dry-run the policy still runs (obs history stays hot) but + ``compute()`` returns ``None``, so the coordinator forwards no + command to the adapter. Use to verify policy sanity on real + hardware before committing motor torques. 
+ """ + new_val = bool(enabled) + if new_val == self._dry_run: + return + self._dry_run = new_val + self._last_dry_run_log_t = 0.0 + logger.info(f"GrootWBCTask '{self._name}' dry_run = {new_val}") + + def state_snapshot(self) -> dict[str, object]: + """Return the current state-machine flags for UI / telemetry.""" + return { + "active": self._active, + "armed": self._armed, + "arming": self._arming, + "arm_pending": self._arm_pending, + "dry_run": self._dry_run, + "arming_duration": self._arming_duration, + } + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + def _reset_policy_state(self) -> None: + """Clear inference state — obs history, last action, tick count.""" + self._last_action[:] = 0.0 + self._obs_buf[:] = 0.0 + self._first_inference = True + self._tick_count = 0 + + def _build_obs( + self, + cmd: np.ndarray, + gyro: np.ndarray, + gravity: np.ndarray, + q: np.ndarray, + dq: np.ndarray, + ) -> np.ndarray: + """Build the 86-dim GR00T observation. Layout matches + ``groot_wbc_backend.py`` exactly.""" + obs = np.zeros(_SINGLE_OBS_DIM, dtype=np.float32) + obs[0:3] = cmd * self._cmd_scale + obs[3] = self._config.height_cmd + obs[4:7] = 0.0 + obs[7:10] = gyro * self._config.obs_ang_vel_scale + obs[10:13] = gravity + obs[13:42] = (q - self._default_29) * self._config.obs_dof_pos_scale + obs[42:71] = dq * self._config.obs_dof_vel_scale + obs[71:86] = self._last_action + return obs + + @staticmethod + def _projected_gravity(quaternion: tuple[float, ...]) -> np.ndarray: + """Project world gravity into body frame. + + Uses Unitree DDS quaternion order (w, x, y, z). Formula matches + ``groot_wbc_backend._get_gravity_orientation`` and is + algebraically equivalent to the Go2 RLPolicyTask helper. 
+ """ + w, x, y, z = quaternion + gx = 2.0 * (-x * z + w * y) + gy = 2.0 * (-y * z - w * x) + gz = -(w * w - x * x - y * y + z * z) + return np.array([gx, gy, gz], dtype=np.float32) + + +__all__ = [ + "GrootWBCTask", + "GrootWBCTaskConfig", +] diff --git a/dimos/control/tasks/servo_task.py b/dimos/control/tasks/servo_task.py index b69b4dd099..96f3a3d24c 100644 --- a/dimos/control/tasks/servo_task.py +++ b/dimos/control/tasks/servo_task.py @@ -46,11 +46,17 @@ class JointServoTaskConfig: joint_names: List of joint names this task controls priority: Priority for arbitration (higher wins) timeout: If no command received for this many seconds, go inactive (0 = never timeout) + default_positions: Optional initial target held until/unless a + new target arrives via set_target(). Must match joint_names + length if provided. Useful for "hold at this pose" tasks + (e.g. arms during whole-body locomotion). Pair with + timeout=0.0 to hold indefinitely. """ joint_names: list[str] priority: int = 10 timeout: float = 0.5 # 500ms default timeout + default_positions: list[float] | None = None class JointServoTask(BaseControlTask): @@ -99,6 +105,15 @@ def __init__(self, name: str, config: JointServoTaskConfig) -> None: self._last_update_time: float = 0.0 self._active = False + if config.default_positions is not None: + if len(config.default_positions) != self._num_joints: + raise ValueError( + f"JointServoTask '{name}': default_positions length " + f"{len(config.default_positions)} does not match " + f"joint_names length {self._num_joints}" + ) + self._target = list(config.default_positions) + logger.info(f"JointServoTask {name} initialized for joints: {config.joint_names}") @property diff --git a/dimos/control/tasks/test_groot_wbc_task.py b/dimos/control/tasks/test_groot_wbc_task.py new file mode 100644 index 0000000000..bc6f5f9c9a --- /dev/null +++ b/dimos/control/tasks/test_groot_wbc_task.py @@ -0,0 +1,463 @@ +# Copyright 2025-2026 Dimensional Inc. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for GrootWBCTask. + +ONNX runtime is monkey-patched with a stub that records which model +was called and returns a deterministic action — so the tests exercise +the obs-build, model-selection, decimation, and command-timeout logic +without depending on the actual GR00T ONNX weights. +""" + +from __future__ import annotations + +import math +from typing import Any +from unittest.mock import MagicMock + +import numpy as np +import pytest + +from dimos.control.components import make_humanoid_joints +from dimos.control.task import ( + ControlMode, + CoordinatorState, + JointStateSnapshot, +) +from dimos.control.tasks import groot_wbc_task +from dimos.control.tasks.groot_wbc_task import GrootWBCTask, GrootWBCTaskConfig +from dimos.hardware.whole_body.spec import IMUState + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +class _StubSession: + """ONNX InferenceSession stub that tracks call count and returns a fixed action.""" + + def __init__( + self, + model_path: str, + *, + label: str, + action: np.ndarray, + call_log: list[str], + ) -> None: + self.model_path = model_path + self._label = label + self._action = action + self._call_log = call_log + fake_input = MagicMock() + fake_input.name = "obs" + self._inputs = [fake_input] + + def get_inputs(self) -> list[Any]: + return 
self._inputs + + def run(self, _outputs: Any, _feed: dict[str, np.ndarray]) -> list[np.ndarray]: + self._call_log.append(self._label) + return [self._action.reshape(1, -1)] + + +@pytest.fixture +def patched_ort(monkeypatch): + """Patch onnxruntime so no real ONNX files are needed.""" + call_log: list[str] = [] + + def _factory(path: str, providers: Any = None) -> _StubSession: + label = "balance" if "balance" in str(path) else "walk" + return _StubSession( + str(path), + label=label, + action=np.full(15, 0.1, dtype=np.float32), + call_log=call_log, + ) + + monkeypatch.setattr(groot_wbc_task.ort, "InferenceSession", _factory) + monkeypatch.setattr( + groot_wbc_task.ort, "get_available_providers", lambda: ["CPUExecutionProvider"] + ) + return call_log + + +@pytest.fixture +def stub_adapter(): + """Stub WholeBodyAdapter returning a zeroed-out IMU (identity quat).""" + adapter = MagicMock() + adapter.read_imu.return_value = IMUState( + quaternion=(1.0, 0.0, 0.0, 0.0), # identity (w, x, y, z) + gyroscope=(0.0, 0.0, 0.0), + accelerometer=(0.0, 0.0, -9.81), + rpy=(0.0, 0.0, 0.0), + ) + return adapter + + +@pytest.fixture +def joints_29(): + return make_humanoid_joints("g1") + + +@pytest.fixture +def task(patched_ort, stub_adapter, joints_29) -> GrootWBCTask: + """Test fixture: auto-armed with no ramp so the existing policy + tests can run compute() immediately after start(). 
The arming/ + dry-run state-machine has its own dedicated tests below.""" + legs_waist = joints_29[:15] + return GrootWBCTask( + name="groot_wbc", + config=GrootWBCTaskConfig( + balance_onnx="/fake/balance.onnx", + walk_onnx="/fake/walk.onnx", + joint_names=legs_waist, + all_joint_names=joints_29, + priority=50, + auto_arm=True, + default_ramp_seconds=0.0, + ), + adapter=stub_adapter, + ) + + +@pytest.fixture +def unarmed_task(patched_ort, stub_adapter, joints_29) -> GrootWBCTask: + """Fixture mirroring the real-hardware blueprint: active but + unarmed on start(), so arm()/disarm()/set_dry_run() can be + exercised explicitly.""" + legs_waist = joints_29[:15] + return GrootWBCTask( + name="groot_wbc", + config=GrootWBCTaskConfig( + balance_onnx="/fake/balance.onnx", + walk_onnx="/fake/walk.onnx", + joint_names=legs_waist, + all_joint_names=joints_29, + priority=50, + auto_arm=False, + default_ramp_seconds=0.0, + ), + adapter=stub_adapter, + ) + + +def _state_at(t_now: float, joint_names: list[str]) -> CoordinatorState: + snap = JointStateSnapshot( + joint_positions={n: 0.0 for n in joint_names}, + joint_velocities={n: 0.0 for n in joint_names}, + joint_efforts={n: 0.0 for n in joint_names}, + timestamp=t_now, + ) + return CoordinatorState(joints=snap, t_now=t_now, dt=0.002) + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +def test_claim_shape(task, joints_29): + claim = task.claim() + assert claim.joints == frozenset(joints_29[:15]) + assert claim.priority == 50 + assert claim.mode == ControlMode.SERVO_POSITION + + +def test_inactive_returns_none(task, joints_29): + state = _state_at(100.0, joints_29) + assert task.compute(state) is None + + +def test_active_zero_cmd_routes_to_balance(task, joints_29, patched_ort): + task.start() + # Decimation=10 → run compute 10 times to force first inference. 
+ state = _state_at(100.0, joints_29) + result = None + for _ in range(10): + result = task.compute(state) + assert result is not None + assert len(result.positions) == 15 + assert patched_ort == ["balance"] + + +def test_nonzero_cmd_routes_to_walk(task, joints_29, patched_ort): + task.start() + task.set_velocity_command(0.5, 0.0, 0.0, t_now=100.0) + state = _state_at(100.0, joints_29) + for _ in range(10): + task.compute(state) + assert patched_ort == ["walk"] + + +def test_decimation_reemits_last_targets(task, joints_29, patched_ort): + """Between inference ticks, the task should repeat the last output.""" + task.start() + state = _state_at(100.0, joints_29) + # First 9 ticks pre-inference: no targets yet. + for _ in range(9): + assert task.compute(state) is None + # 10th tick: inference fires. + first = task.compute(state) + assert first is not None + assert len(patched_ort) == 1 + # Next 9 ticks: no inference, same targets echoed. + for _ in range(9): + echo = task.compute(state) + assert echo is not None + assert echo.positions == first.positions + assert len(patched_ort) == 1 + # 20th tick: second inference. + task.compute(state) + assert len(patched_ort) == 2 + + +def test_velocity_command_timeout(task, joints_29, patched_ort): + task.start() + task.set_velocity_command(0.5, 0.0, 0.0, t_now=100.0) + # Still inside the 1.0s timeout — walk. + state_inside = _state_at(100.5, joints_29) + for _ in range(10): + task.compute(state_inside) + # Past the timeout — command goes to zero → balance. + state_outside = _state_at(102.0, joints_29) + for _ in range(10): + task.compute(state_outside) + assert patched_ort == ["walk", "balance"] + + +def test_projected_gravity_identity_quat(): + g = GrootWBCTask._projected_gravity((1.0, 0.0, 0.0, 0.0)) + np.testing.assert_allclose(g, np.array([0.0, 0.0, -1.0]), atol=1e-6) + + +def test_projected_gravity_roll_90(): + """+90° roll around body-X: body-Y now points world-up, body-Z world-right. 
+ World gravity (0,0,-1) expressed in body frame is (0, -1, 0).""" + s = math.sin(math.pi / 4.0) + c = math.cos(math.pi / 4.0) + g = GrootWBCTask._projected_gravity((c, s, 0.0, 0.0)) + np.testing.assert_allclose(g, np.array([0.0, -1.0, 0.0]), atol=1e-6) + + +def test_projected_gravity_pitch_90(): + """+90° pitch around body-Y: body-X now points world-down, body-Z world-forward. + World gravity (0,0,-1) expressed in body frame is (+1, 0, 0).""" + s = math.sin(math.pi / 4.0) + c = math.cos(math.pi / 4.0) + g = GrootWBCTask._projected_gravity((c, 0.0, s, 0.0)) + np.testing.assert_allclose(g, np.array([1.0, 0.0, 0.0]), atol=1e-6) + + +def test_obs_build_layout(task): + """Verify the 86-dim obs respects the documented slot layout.""" + cmd = np.array([1.0, 0.5, 0.25], dtype=np.float32) + gyro = np.array([0.1, 0.2, 0.3], dtype=np.float32) + gravity = np.array([0.0, 0.0, -1.0], dtype=np.float32) + q = np.zeros(29, dtype=np.float32) + dq = np.ones(29, dtype=np.float32) + obs = task._build_obs(cmd=cmd, gyro=gyro, gravity=gravity, q=q, dq=dq) + assert obs.shape == (86,) + np.testing.assert_allclose(obs[0:3], cmd * np.array([2.0, 2.0, 0.5])) + assert obs[3] == pytest.approx(0.74) + np.testing.assert_array_equal(obs[4:7], np.zeros(3)) + np.testing.assert_allclose(obs[7:10], gyro * 0.5) + np.testing.assert_array_equal(obs[10:13], gravity) + # q - default_29 → legs/waist get nonzero offsets from DEFAULT_15, + # arms (indices 15..28 in DEFAULT_29) are zero, so obs[28:42] == 0. + np.testing.assert_array_equal(obs[28:42], np.zeros(14)) + np.testing.assert_allclose(obs[42:71], dq * 0.05) + np.testing.assert_array_equal(obs[71:86], np.zeros(15)) + + +def test_first_inference_fills_history(task, joints_29, patched_ort): + """First inference should tile current obs across all 6 history slots.""" + task.start() + state = _state_at(100.0, joints_29) + for _ in range(10): + task.compute(state) + # History has 6 identical 86-dim slices. 
+ buf = task._obs_buf[0] + assert buf.shape == (86 * 6,) + slice0 = buf[0:86] + for k in range(1, 6): + np.testing.assert_array_equal(buf[86 * k : 86 * (k + 1)], slice0) + + +def test_start_resets_state(task, joints_29, patched_ort): + task.start() + state = _state_at(100.0, joints_29) + for _ in range(10): + task.compute(state) + assert np.any(task._last_action != 0.0) + assert task._last_targets is not None + + task.stop() + assert task._last_targets is None + + task.start() + # After restart, tick counter is zero, last_action cleared, first-inference flag set. + assert task._tick_count == 0 + np.testing.assert_array_equal(task._last_action, np.zeros(15, dtype=np.float32)) + assert task._first_inference is True + + +def test_on_twist_routes_to_velocity_cmd(task): + msg = MagicMock() + msg.linear.x = 0.7 + msg.linear.y = -0.2 + msg.angular.z = 0.4 + task.on_twist(msg, t_now=12.34) + np.testing.assert_allclose(task._cmd, np.array([0.7, -0.2, 0.4], dtype=np.float32), atol=1e-6) + assert task._last_cmd_time == 12.34 + + +def test_joint_count_validation(patched_ort, stub_adapter, joints_29): + with pytest.raises(ValueError, match="15 joint names"): + GrootWBCTask( + name="bad", + config=GrootWBCTaskConfig( + balance_onnx="/fake/balance.onnx", + walk_onnx="/fake/walk.onnx", + joint_names=joints_29[:10], # wrong size + all_joint_names=joints_29, + ), + adapter=stub_adapter, + ) + with pytest.raises(ValueError, match="29 all_joint_names"): + GrootWBCTask( + name="bad", + config=GrootWBCTaskConfig( + balance_onnx="/fake/balance.onnx", + walk_onnx="/fake/walk.onnx", + joint_names=joints_29[:15], + all_joint_names=joints_29[:20], # wrong size + ), + adapter=stub_adapter, + ) + + +# --------------------------------------------------------------------------- +# Arming / dry-run state machine +# --------------------------------------------------------------------------- + + +def test_unarmed_holds_current_pose(unarmed_task, joints_29, patched_ort): + """Active but unarmed → 
compute() echoes current joint positions + every tick. Downstream PD with q_tgt == q_actual → damping only.""" + unarmed_task.start() + snap = JointStateSnapshot( + joint_positions={n: 0.0 for n in joints_29}, + joint_velocities={n: 0.0 for n in joints_29}, + joint_efforts={n: 0.0 for n in joints_29}, + timestamp=100.0, + ) + # Set some non-zero current positions for the 15 claimed joints. + for i, n in enumerate(joints_29[:15]): + snap.joint_positions[n] = 0.1 * (i + 1) + state = CoordinatorState(joints=snap, t_now=100.0, dt=0.002) + for _ in range(30): + out = unarmed_task.compute(state) + assert out is not None + # No inference while unarmed. + assert patched_ort == [] + # Output tracks current pose exactly. + np.testing.assert_allclose(out.positions, [0.1 * (i + 1) for i in range(15)], atol=1e-6) + + +def test_arm_no_ramp_goes_straight_to_policy(unarmed_task, joints_29, patched_ort): + """arm(0.0) → immediately armed → policy runs on the next decimation tick.""" + unarmed_task.start() + unarmed_task.arm(ramp_seconds=0.0) + state = _state_at(100.0, joints_29) + # First compute after arm(): snapshots ramp_start, flips armed=True (ramp=0). + unarmed_task.compute(state) + assert unarmed_task._armed + # 9 more ticks to hit decimation threshold (10th is inference). + for _ in range(9): + unarmed_task.compute(state) + assert patched_ort == ["balance"] + + +def test_arm_with_ramp_lerps_over_duration(unarmed_task, joints_29, patched_ort): + """arm(1.0) → lerp from current pose to default_15 over 1 second.""" + unarmed_task.start() + unarmed_task.arm(ramp_seconds=1.0) + # First tick: snapshot ramp_start (all zeros). + state0 = _state_at(0.0, joints_29) + out0 = unarmed_task.compute(state0) + assert out0 is not None + assert unarmed_task._arming + # alpha=0 → output == ramp_start (all zeros here). + np.testing.assert_allclose(out0.positions, [0.0] * 15, atol=1e-6) + # Halfway through: alpha=0.5. 
+ state_mid = _state_at(0.5, joints_29) + out_mid = unarmed_task.compute(state_mid) + default_15 = list(groot_wbc_task._DEFAULT_POSITIONS_29[:15]) + expected_mid = [0.5 * d for d in default_15] + np.testing.assert_allclose(out_mid.positions, expected_mid, atol=1e-6) + # End: alpha=1 → armed flips, output == default_15. + state_end = _state_at(1.0, joints_29) + unarmed_task.compute(state_end) + assert unarmed_task._armed + assert not unarmed_task._arming + # Policy has NOT run yet — ramp completion doesn't trigger inference. + assert patched_ort == [] + + +def test_dry_run_suppresses_output_but_runs_inference(task, joints_29, patched_ort): + """Dry-run: policy still computes (obs history stays hot), but + compute() returns None so the adapter sees no command.""" + task.start() # fixture has auto_arm=True, so armed immediately + task.set_dry_run(True) + state = _state_at(100.0, joints_29) + # 10 ticks → first inference fires under the hood, but output is None. + for _ in range(10): + out = task.compute(state) + assert out is None + # Policy DID run — obs buffer is hot. + assert patched_ort == ["balance"] + + +def test_dry_run_toggle_off_resumes_output(task, joints_29, patched_ort): + """Flipping dry_run from True → False resumes normal output.""" + task.start() + task.set_dry_run(True) + state = _state_at(100.0, joints_29) + for _ in range(10): + task.compute(state) + assert patched_ort == ["balance"] # ran during dry-run + task.set_dry_run(False) + # Next inference tick: output is non-None. 
+ for _ in range(10): + out = task.compute(state) + assert out is not None + assert len(out.positions) == 15 + + +def test_disarm_returns_to_hold_pose(unarmed_task, joints_29, patched_ort): + """Disarm after policy has run → compute() falls back to echoing pose.""" + unarmed_task.start() + unarmed_task.arm(ramp_seconds=0.0) + state = _state_at(100.0, joints_29) + for _ in range(10): + unarmed_task.compute(state) + assert patched_ort == ["balance"] + assert unarmed_task._armed + + unarmed_task.disarm() + assert not unarmed_task._armed + # Policy should NOT run again. + for _ in range(30): + unarmed_task.compute(state) + assert patched_ort == ["balance"] # still just one call diff --git a/dimos/hardware/whole_body/mujoco/__init__.py b/dimos/hardware/whole_body/mujoco/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/dimos/hardware/whole_body/mujoco/g1/__init__.py b/dimos/hardware/whole_body/mujoco/g1/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/dimos/hardware/whole_body/mujoco/g1/adapter.py b/dimos/hardware/whole_body/mujoco/g1/adapter.py new file mode 100644 index 0000000000..b5c1aef59d --- /dev/null +++ b/dimos/hardware/whole_body/mujoco/g1/adapter.py @@ -0,0 +1,207 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""MuJoCo simulation ``WholeBodyAdapter`` for the Unitree G1. 
+ +Delegates to the existing ``MujocoConnection`` subprocess (the same +infrastructure ``unitree-go2 --simulation`` uses), running in +"low-level passthrough" mode: the subprocess owns the MuJoCo world + +viewer; this adapter reads per-joint state and writes per-joint +commands through shared memory. + +That choice — reusing the battle-tested subprocess pattern — is what +lets ``dimos --simulation run unitree-g1-groot-wbc`` open the viewer +on macOS without the user prefixing ``mjpython``. The subprocess is +auto-spawned under ``mjpython`` on macOS by ``MujocoConnection`` +(``mujoco_connection.py:124``). +""" + +from __future__ import annotations + +import math +import time +from typing import TYPE_CHECKING, Any + +import numpy as np + +from dimos.core.global_config import global_config as _global_config +from dimos.hardware.whole_body.spec import ( + POS_STOP, + IMUState, + MotorCommand, + MotorState, +) +from dimos.robot.unitree.mujoco_connection import MujocoConnection +from dimos.utils.logging_config import setup_logger + +if TYPE_CHECKING: + from dimos.core.global_config import GlobalConfig + from dimos.hardware.whole_body.registry import WholeBodyAdapterRegistry + +logger = setup_logger() + +_NUM_MOTORS = 29 + + +class SimMujocoG1WholeBodyAdapter: + """Whole-body adapter backed by a ``MujocoConnection`` in low-level mode. + + The connection spawns the standard ``mujoco_process.py`` subprocess + (``dimos/simulation/mujoco/mujoco_process.py``), auto-selecting + ``mjpython`` on macOS, and passes ``control_mode="low_level"`` so + the subprocess skips its baked locomotion ONNX and instead reads + per-joint commands from shared memory. + + ``GlobalConfig.robot_model`` must be ``"unitree_g1"`` (the blueprint + sets this) so the subprocess loads the G1 MJCF. ``mujoco_room`` + controls which scene wraps the robot (default ``"office1"``; the + blueprint overrides to ``"empty"`` for a flat floor). 
+ + Args: + network_interface: Unused; kept for adapter-registry kwarg + symmetry with the DDS adapter. + domain_id: Unused; same reason. + cfg: Global config to pass to the subprocess. Defaults to the + process-wide ``global_config`` (what the CLI populates). + """ + + def __init__( + self, + network_interface: int | str = 0, + domain_id: int = 0, + cfg: GlobalConfig | None = None, + **_: Any, + ) -> None: + # Force the two MuJoCo-subprocess-relevant knobs on our own copy + # of the config, regardless of what the worker's ``global_config`` + # singleton says. The worker process starts fresh (forkserver + # spawn), so blueprint-level ``.global_config(robot_model=...)`` + # overrides applied in the main process do NOT propagate into + # the worker's singleton. We hard-pin G1 + empty scene here so + # the subprocess always loads the right model. + base = cfg if cfg is not None else _global_config + self._cfg = base.model_copy(update={"robot_model": "unitree_g1", "mujoco_room": "empty"}) + self._connection: MujocoConnection | None = None + self._connected = False + # Warn once if downstream consumers try to use the adapter + # before the first state packet lands in shm. + self._warned_no_state = False + + # ------------------------------------------------------------------ + # Lifecycle + # ------------------------------------------------------------------ + + def connect(self) -> bool: + try: + self._connection = MujocoConnection(self._cfg, control_mode="low_level") + self._connection.start() + # Block briefly until the child has actually produced a state + # packet, so the first read_motor_states() returns valid data + # (otherwise the coordinator's first tick sees zeros and the + # WBC task builds a junk obs). 
+ deadline = time.time() + 5.0 + while time.time() < deadline: + if self._connection.read_motor_states(_NUM_MOTORS) is not None: + break + time.sleep(0.05) + self._connected = True + logger.info("SimMujocoG1WholeBodyAdapter connected (subprocess ready)") + return True + except Exception as e: + logger.error(f"Failed to start MuJoCo G1 sim subprocess: {e}") + self._connected = False + return False + + def disconnect(self) -> None: + if self._connection is not None: + try: + self._connection.stop() + except Exception as e: # best-effort cleanup + logger.warning(f"MuJoCo sim subprocess stop raised: {e}") + self._connection = None + self._connected = False + + def is_connected(self) -> bool: + return self._connected and self._connection is not None + + # ------------------------------------------------------------------ + # IO (WholeBodyAdapter protocol) + # ------------------------------------------------------------------ + + def read_motor_states(self) -> list[MotorState]: + if not self._connected or self._connection is None: + return [MotorState()] * _NUM_MOTORS + arr = self._connection.read_motor_states(_NUM_MOTORS) + if arr is None: + if not self._warned_no_state: + logger.warning("MuJoCo subprocess has not produced any state yet") + self._warned_no_state = True + return [MotorState()] * _NUM_MOTORS + return [ + MotorState(q=float(arr[i, 0]), dq=float(arr[i, 1]), tau=float(arr[i, 2])) + for i in range(_NUM_MOTORS) + ] + + def read_imu(self) -> IMUState: + if not self._connected or self._connection is None: + return IMUState() + arr = self._connection.read_imu_sensor() + if arr is None or len(arr) < 10: + return IMUState() + w, x, y, z = (float(arr[0]), float(arr[1]), float(arr[2]), float(arr[3])) + gyro = (float(arr[4]), float(arr[5]), float(arr[6])) + accel = (float(arr[7]), float(arr[8]), float(arr[9])) + # Derive ZYX Euler from the quaternion — matches the real G1 adapter. 
+        sinr = 2.0 * (w * x + y * z)
+        cosr = 1.0 - 2.0 * (x * x + y * y)
+        roll = math.atan2(sinr, cosr)
+        sinp = 2.0 * (w * y - z * x)
+        pitch = math.copysign(math.pi / 2.0, sinp) if abs(sinp) >= 1.0 else math.asin(sinp)
+        siny = 2.0 * (w * z + x * y)
+        cosy = 1.0 - 2.0 * (y * y + z * z)
+        yaw = math.atan2(siny, cosy)
+        return IMUState(
+            quaternion=(w, x, y, z),
+            gyroscope=gyro,
+            accelerometer=accel,
+            rpy=(roll, pitch, yaw),
+        )
+
+    def write_motor_commands(self, commands: list[MotorCommand]) -> bool:
+        if not self._connected or self._connection is None:
+            return False
+        if len(commands) != _NUM_MOTORS:
+            logger.error(
+                f"SimMujocoG1WholeBodyAdapter: expected {_NUM_MOTORS} commands, got {len(commands)}"
+            )
+            return False
+        q = np.empty(_NUM_MOTORS, dtype=np.float32)
+        kp = np.empty(_NUM_MOTORS, dtype=np.float32)
+        kd = np.empty(_NUM_MOTORS, dtype=np.float32)
+        for i, cmd in enumerate(commands):
+            # POS_STOP ("no command") — substitute a zero position target
+            # (NOTE(review): current joint state isn't available here to hold).
+            q[i] = cmd.q if cmd.q != POS_STOP else 0.0
+            kp[i] = cmd.kp
+            kd[i] = cmd.kd
+        self._connection.write_motor_commands(q, kp, kd)
+        return True
+
+
+def register(registry: WholeBodyAdapterRegistry) -> None:
+    """Register with the whole-body adapter registry."""
+    registry.register("sim_mujoco_g1", SimMujocoG1WholeBodyAdapter)
+
+
+__all__ = ["SimMujocoG1WholeBodyAdapter"]
diff --git a/dimos/hardware/whole_body/unitree/g1/adapter.py b/dimos/hardware/whole_body/unitree/g1/adapter.py
index a520099b16..1f794d453d 100644
--- a/dimos/hardware/whole_body/unitree/g1/adapter.py
+++ b/dimos/hardware/whole_body/unitree/g1/adapter.py
@@ -69,10 +69,19 @@ class UnitreeG1LowLevelAdapter:
     Args:
         network_interface: DDS network interface name or ID (default: "eth0").
+        domain_id: DDS domain ID. Real robot uses 0; unitree_mujoco sim
+            defaults to 1. Changing domain lets the same adapter bind to
+            sim or real with no code change.
""" - def __init__(self, network_interface: int | str = 0, **_: object) -> None: + def __init__( + self, + network_interface: int | str = 0, + domain_id: int = 0, + **_: object, + ) -> None: self._network_interface = network_interface + self._domain_id = domain_id self._connected = False self._lock = threading.Lock() @@ -105,11 +114,16 @@ def connect(self) -> bool: from unitree_sdk2py.idl.unitree_hg.msg.dds_ import LowCmd_, LowState_ from unitree_sdk2py.utils.crc import CRC - # 1. Initialise DDS transport + # 1. Initialise DDS transport. NOTE: the cyclonedds Python + # wheel reads CYCLONEDDS_HOME at runtime — it must point at + # the local cyclonedds install (e.g. ~/cyclonedds/install) + # before this call or DDS topic creation later fails with + # PRECONDITION_NOT_MET. Add to your shell rc. logger.info( - f"Initializing DDS (G1 low-level) with interface {self._network_interface}..." + f"Initializing DDS (G1 low-level) with interface {self._network_interface} " + f"on domain {self._domain_id}..." ) - ChannelFactoryInitialize(0, self._network_interface) + ChannelFactoryInitialize(self._domain_id, self._network_interface) # 2. 
Create publisher / subscriber self._publisher = ChannelPublisher("rt/lowcmd", LowCmd_) @@ -150,8 +164,8 @@ def connect(self) -> bool: logger.info(f"G1 low-level adapter connected (mode_machine={self._mode_machine})") return True - except Exception as e: - logger.error(f"Failed to connect G1 low-level adapter: {e}") + except Exception: + logger.exception("Failed to connect G1 low-level adapter (full traceback):") self._connected = False return False diff --git a/dimos/robot/all_blueprints.py b/dimos/robot/all_blueprints.py index 028ab12439..00abd43bb9 100644 --- a/dimos/robot/all_blueprints.py +++ b/dimos/robot/all_blueprints.py @@ -68,6 +68,7 @@ "unitree-g1-dds-coordinator": "dimos.robot.unitree.g1.blueprints.basic.unitree_g1_dds_coordinator:unitree_g1_dds_coordinator", "unitree-g1-detection": "dimos.robot.unitree.g1.blueprints.perceptive.unitree_g1_detection:unitree_g1_detection", "unitree-g1-full": "dimos.robot.unitree.g1.blueprints.agentic.unitree_g1_full:unitree_g1_full", + "unitree-g1-groot-wbc": "dimos.robot.unitree.g1.blueprints.basic.unitree_g1_groot_wbc:unitree_g1_groot_wbc", "unitree-g1-joystick": "dimos.robot.unitree.g1.blueprints.basic.unitree_g1_joystick:unitree_g1_joystick", "unitree-g1-lowlevel": "dimos.robot.unitree.g1.blueprints.basic.unitree_g1_lowlevel:unitree_g1_lowlevel", "unitree-g1-playback": "dimos.robot.unitree.g1.blueprints.basic.unitree_g1_playback:unitree_g1_playback", diff --git a/dimos/robot/unitree/g1/blueprints/basic/unitree_g1_groot_wbc.py b/dimos/robot/unitree/g1/blueprints/basic/unitree_g1_groot_wbc.py new file mode 100644 index 0000000000..4037f15010 --- /dev/null +++ b/dimos/robot/unitree/g1/blueprints/basic/unitree_g1_groot_wbc.py @@ -0,0 +1,266 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unitree G1 GR00T whole-body-control blueprint. + +Runs the ControlCoordinator at 500 Hz with two tasks: + + - ``groot_wbc`` (priority 50) claims legs + waist (15 DOF) and runs + the GR00T balance / walk ONNX policies at 50 Hz. + - ``servo_arms`` (priority 10) claims the 14 arm joints and holds + them at a configured relaxed pose. No timeout — the task holds + until an external caller sends new arm targets. + +Velocity commands come from the dashboard's KeyboardControlPanel +(http://localhost:7779/, WASD captured in the browser DOM) and are +routed through ``WebsocketVisModule`` → LCM ``/g1/cmd_vel`` → +coordinator ``twist_command`` → ``GrootWBCTask.set_velocity_command``. + +Architecture: + dashboard WASD ──▶ WebsocketVisModule ──▶ LCM /g1/cmd_vel + │ + coordinator twist_command ──▶ GrootWBCTask + │ + ControlCoordinator ──joint_state──▶ LCM /coordinator/joint_state + ◀─joint_command── LCM /g1/joint_command + │ + WholeBodyAdapter: + --simulation → SimMujocoG1WholeBodyAdapter + (MujocoConnection subprocess, + low-level passthrough) + real hardware → UnitreeG1LowLevelAdapter (DDS) + +Usage: + dimos --simulation run unitree-g1-groot-wbc # MuJoCo viewer, browser opens auto + ROBOT_INTERFACE=en7 dimos run unitree-g1-groot-wbc # real robot (set CYCLONEDDS_HOME first) + +Environment: + ROBOT_INTERFACE DDS network interface for real robot (default "enp86s0"). + Ignored under --simulation. + DIMOS_DDS_DOMAIN DDS domain id for real robot (default 0). Ignored + under --simulation. 
+ CYCLONEDDS_HOME Required at runtime on real hw — must point at the + cyclonedds C install (e.g. ~/cyclonedds/install). + Ignored under --simulation. + GROOT_MODEL_DIR Directory containing balance.onnx + walk.onnx + (default "data/groot"). +""" + +from __future__ import annotations + +import os + +from dimos.control.components import ( + HardwareComponent, + HardwareType, + make_humanoid_joints, +) +from dimos.control.coordinator import TaskConfig, control_coordinator +from dimos.core.blueprints import autoconnect +from dimos.core.global_config import global_config +from dimos.core.transport import LCMTransport +from dimos.msgs.geometry_msgs import Twist +from dimos.msgs.sensor_msgs import JointState +from dimos.msgs.std_msgs.Bool import Bool as DimosBool +from dimos.web.websocket_vis.websocket_vis_module import websocket_vis + +_g1_joints = make_humanoid_joints("g1") +_g1_legs_waist = _g1_joints[:15] # indices 0..14 — legs (12) + waist (3) +_g1_arms = _g1_joints[15:] # indices 15..28 — left arm (7) + right arm (7) + +# Per-joint PD gains, 29 entries in DDS motor order. Lifted verbatim +# from g1-control-api/configs/g1_groot_wbc.yaml, which itself copies +# GR00T-WBC's g1_29dof_gear_wbc.yaml reference config. These gains +# were the ones the balance / walk ONNX policies were trained against +# — diverging from them on real hardware risks instability. 
+_G1_GROOT_KP = [ + 150.0, + 150.0, + 150.0, + 200.0, + 40.0, + 40.0, # left leg + 150.0, + 150.0, + 150.0, + 200.0, + 40.0, + 40.0, # right leg + 250.0, + 250.0, + 250.0, # waist + 100.0, + 100.0, + 40.0, + 40.0, + 20.0, + 20.0, + 20.0, # left arm + 100.0, + 100.0, + 40.0, + 40.0, + 20.0, + 20.0, + 20.0, # right arm +] +_G1_GROOT_KD = [ + 2.0, + 2.0, + 2.0, + 4.0, + 2.0, + 2.0, # left leg + 2.0, + 2.0, + 2.0, + 4.0, + 2.0, + 2.0, # right leg + 5.0, + 5.0, + 5.0, # waist + 5.0, + 5.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, # left arm + 5.0, + 5.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, # right arm +] + +# Relaxed arms-down pose. Values taken from +# g1_control/backends/groot_wbc_backend.py:DEFAULT_29[15:] (all zeros), +# which is the zero-offset pose the policy was trained against. +# Operators can override at runtime by publishing joint targets on the +# arms via the joint_command transport. +_ARM_DEFAULT_POSE = [ + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, # left arm (7 DOF) + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, # right arm (7 DOF) +] + +_adapter_type = "sim_mujoco_g1" if global_config.simulation else "unitree_g1" +_address = None if global_config.simulation else os.getenv("ROBOT_INTERFACE", "enp86s0") + +# Arming defaults: sim auto-arms (the MuJoCo subprocess holds the MJCF +# pose until first command, no ramp needed); real hardware comes up +# unarmed + dry-run so the operator can see computed commands before +# committing motor torques, then hit Activate in the dashboard for a +# 10 s ramp to the bent-knee default (mirrors g1-control-api). 
+_AUTO_ARM = global_config.simulation +_AUTO_DRY_RUN = not global_config.simulation +_DEFAULT_RAMP_SECONDS = 0.0 if global_config.simulation else 10.0 + +_g1_coordinator = ( + control_coordinator( + tick_rate=500.0, + publish_joint_state=True, + joint_state_frame_id="coordinator", + hardware=[ + HardwareComponent( + hardware_id="g1", + hardware_type=HardwareType.WHOLE_BODY, + joints=_g1_joints, + adapter_type=_adapter_type, + address=_address, + domain_id=int(os.getenv("DIMOS_DDS_DOMAIN", "0")), + auto_enable=True, + kp=_G1_GROOT_KP, + kd=_G1_GROOT_KD, + ), + ], + tasks=[ + TaskConfig( + name="groot_wbc", + type="groot_wbc", + joint_names=_g1_legs_waist, + priority=50, + model_path=os.getenv("GROOT_MODEL_DIR", "data/groot"), + hardware_id="g1", + auto_start=True, + auto_arm=_AUTO_ARM, + auto_dry_run=_AUTO_DRY_RUN, + default_ramp_seconds=_DEFAULT_RAMP_SECONDS, + ), + TaskConfig( + name="servo_arms", + type="servo", + joint_names=_g1_arms, + priority=10, + default_positions=_ARM_DEFAULT_POSE, + auto_start=True, + ), + ], + ) + .transports( + { + ("joint_state", JointState): LCMTransport("/coordinator/joint_state", JointState), + ("joint_command", JointState): LCMTransport("/g1/joint_command", JointState), + ("twist_command", Twist): LCMTransport("/g1/cmd_vel", Twist), + ("activate", DimosBool): LCMTransport("/g1/activate", DimosBool), + ("dry_run", DimosBool): LCMTransport("/g1/dry_run", DimosBool), + } + ) + .global_config( + # Picked up by MujocoConnection → mujoco_process.py when the blueprint + # is run with --simulation. robot_model selects which MJCF the sim + # child loads; mujoco_room wraps it in a flat floor (vs the default + # "office1" room used by the perceptive G1 sim blueprint). + robot_model="unitree_g1", + mujoco_room="empty", + ) +) + + +# WASD teleop via the web dashboard (http://localhost:7779/) served by +# WebsocketVisModule. 
The bundled React command-center at +# ``data/command_center.html`` includes a KeyboardControlPanel that +# captures W/S/A/D on keydown/keyup and emits ``move_command`` events +# which the module re-publishes on its ``cmd_vel`` port. We route that +# over LCM to the coordinator's ``twist_command`` port on /g1/cmd_vel. +# +# This replaces the pygame-based ``keyboard_teleop`` module because +# pygame's pygame.display.set_mode() calls NSWindow on macOS, and Cocoa +# rejects NSWindow creation from non-main threads — which is where +# dimos runs module code. A browser tab has no such constraint. +_g1_ws_vis = websocket_vis().transports( + { + ("cmd_vel", Twist): LCMTransport("/g1/cmd_vel", Twist), + ("activate", DimosBool): LCMTransport("/g1/activate", DimosBool), + ("dry_run", DimosBool): LCMTransport("/g1/dry_run", DimosBool), + }, +) + +unitree_g1_groot_wbc = autoconnect(_g1_coordinator, _g1_ws_vis) + +__all__ = ["unitree_g1_groot_wbc"] diff --git a/dimos/robot/unitree/mujoco_connection.py b/dimos/robot/unitree/mujoco_connection.py index f998ae1dd9..f51ddc20b3 100644 --- a/dimos/robot/unitree/mujoco_connection.py +++ b/dimos/robot/unitree/mujoco_connection.py @@ -60,7 +60,11 @@ class MujocoConnection: """MuJoCo simulator connection that runs in a separate subprocess.""" - def __init__(self, global_config: GlobalConfig) -> None: + def __init__( + self, + global_config: GlobalConfig, + control_mode: str = "high_level", + ) -> None: try: import mujoco # noqa: F401 except ImportError: @@ -76,6 +80,11 @@ def __init__(self, global_config: GlobalConfig) -> None: mjx_env.ensure_menagerie_exists() self.global_config = global_config + # "high_level" = subprocess runs its baked ONNX locomotion policy + # (Twist in, sensor streams out). "low_level" = subprocess + # bypasses the policy and applies per-joint commands from shm + # (used by the dimos ControlCoordinator's sim adapters). 
+ self.control_mode = control_mode self.process: subprocess.Popen[bytes] | None = None self.shm_data: ShmWriter | None = None self._last_video_seq = 0 @@ -125,7 +134,13 @@ def start(self) -> None: executable = sys.executable if sys.platform != "darwin" else "mjpython" self.process = subprocess.Popen( - [executable, str(LAUNCHER_PATH), config_pickle, shm_names_json], + [ + executable, + str(LAUNCHER_PATH), + config_pickle, + shm_names_json, + self.control_mode, + ], ) except Exception as e: @@ -326,6 +341,37 @@ def get_video_as_image() -> Image | None: return self._create_stream(get_video_as_image, VIDEO_FPS, "Video") + # --- Low-level passthrough (parent-side API, when control_mode="low_level") --- + + def write_motor_commands( + self, + q: NDArray[Any], + kp: NDArray[Any], + kd: NDArray[Any], + ) -> None: + """Write per-joint (q, kp, kd) into the subprocess shm. + + Only meaningful when the connection was started with + ``control_mode="low_level"``. The subprocess applies ``q`` to + ``data.ctrl``; ``kp`` / ``kd`` are currently advisory (the MJCF + carries baked position-actuator gains). 
+ """ + if self._is_cleaned_up or self.shm_data is None: + return + self.shm_data.write_joint_cmd(q, kp, kd) + + def read_motor_states(self, num_motors: int) -> NDArray[Any] | None: + """Return an (num_motors, 3) float32 array of (q, dq, tau), or None.""" + if self._is_cleaned_up or self.shm_data is None: + return None + return self.shm_data.read_joint_state(num_motors) + + def read_imu_sensor(self) -> NDArray[Any] | None: + """Return a 10-element [quat(4), gyro(3), accel(3)] float32, or None.""" + if self._is_cleaned_up or self.shm_data is None: + return None + return self.shm_data.read_imu() + def move(self, twist: Twist, duration: float = 0.0) -> bool: if self._is_cleaned_up or self.shm_data is None: return True diff --git a/dimos/simulation/mujoco/model.py b/dimos/simulation/mujoco/model.py index bc309b7307..e2fba3544b 100644 --- a/dimos/simulation/mujoco/model.py +++ b/dimos/simulation/mujoco/model.py @@ -56,8 +56,19 @@ def get_assets() -> dict[str, bytes]: def load_model( - input_device: InputController, robot: str, scene_xml: str + input_device: InputController, + robot: str, + scene_xml: str, + skip_controller: bool = False, ) -> tuple[mujoco.MjModel, mujoco.MjData]: + """Load a MuJoCo model + data for ``robot`` inside ``scene_xml``. + + When ``skip_controller=True``, the baked-in ONNX locomotion policy is + NOT installed as the MuJoCo control callback. Used by low-level + passthrough mode where an external caller (e.g. the dimos + ControlCoordinator via shared memory) drives ``data.ctrl`` each + tick. 
+ """ mujoco.set_mjcb_control(None) xml_string = get_model_xml(robot, scene_xml) @@ -76,6 +87,9 @@ def load_model( n_substeps = round(ctrl_dt / sim_dt) model.opt.timestep = sim_dt + if skip_controller: + return model, data + params = { "policy_path": (_get_data_dir() / f"{robot}_policy.onnx").as_posix(), "default_angles": np.array(model.keyframe("home").qpos[7:]), diff --git a/dimos/simulation/mujoco/mujoco_process.py b/dimos/simulation/mujoco/mujoco_process.py index 27217afadd..96e85285a8 100755 --- a/dimos/simulation/mujoco/mujoco_process.py +++ b/dimos/simulation/mujoco/mujoco_process.py @@ -40,9 +40,10 @@ VIDEO_WIDTH, ) from dimos.simulation.mujoco.depth_camera import depth_image_to_point_cloud -from dimos.simulation.mujoco.model import load_model, load_scene_xml +from dimos.simulation.mujoco.model import get_assets, load_model, load_scene_xml from dimos.simulation.mujoco.person_on_track import PersonPositionController from dimos.simulation.mujoco.shared_memory import ShmReader +from dimos.utils.data import get_data from dimos.utils.logging_config import setup_logger logger = setup_logger() @@ -72,13 +73,62 @@ def stop(self) -> None: pass -def _run_simulation(config: GlobalConfig, shm: ShmReader) -> None: +def _find_sensor_slice(model: mujoco.MjModel, *names: str, dim: int = 3) -> slice | None: + """Return the first matching sensor slice across ``names``, or None.""" + for n in names: + sid = mujoco.mj_name2id(model, mujoco.mjtObj.mjOBJ_SENSOR, n) + if sid >= 0: + adr = int(model.sensor_adr[sid]) + return slice(adr, adr + dim) + return None + + +def _load_g1_gear_wbc_lowlevel() -> tuple[mujoco.MjModel, mujoco.MjData]: + """Load GR00T's ``g1_gear_wbc.xml`` for low-level passthrough mode. + + This MJCF is the one the GR00T balance/walk ONNX policies were + trained against. 
Critically, it uses ``<motor>`` (torque) actuators + — NOT ``<position>`` — so the subprocess does the PD itself with the + per-joint kp/kd coming in over shm, matching the gains in + ``g1_gear_wbc.yaml`` that shaped the policy during training. + + Position-actuator alternatives (dimos's bundled ``unitree_g1.xml`` + at kp=75 or menagerie's ``unitree_g1/scene.xml`` at kp=500) don't + match the trained gains (hips=150, knees=200, ankles=40, waist=250) + and produce violent instability when driven by the policy. + + The XML references meshes by bare filename (``meshdir`` stripped + when bundled); ``get_assets()`` already injects menagerie's G1 mesh + bytes under those names. + """ + xml_path = get_data("mujoco_sim") / "g1_gear_wbc.xml" + with open(xml_path) as f: + xml_str = f.read() + model = mujoco.MjModel.from_xml_string(xml_str, assets=get_assets()) + data = mujoco.MjData(model) + return model, data + + +def _run_simulation(config: GlobalConfig, shm: ShmReader, control_mode: str = "high_level") -> None: robot_name = config.robot_model or "unitree_go1" if robot_name == "unitree_go2": robot_name = "unitree_go1" controller = MockController(shm) - model, data = load_model(controller, robot=robot_name, scene_xml=load_scene_xml(config)) + skip_controller = control_mode == "low_level" + if skip_controller and robot_name == "unitree_g1": + # Low-level G1: use GR00T's training MJCF (torque actuators) and + # run PD in this subprocess. The dimos-bundled and menagerie + # MJCFs are position-actuator variants whose baked kp does NOT + # match the policy's trained gains. 
+ model, data = _load_g1_gear_wbc_lowlevel() + else: + model, data = load_model( + controller, + robot=robot_name, + scene_xml=load_scene_xml(config), + skip_controller=skip_controller, + ) if model is None or data is None: raise ValueError("Failed to load MuJoCo model: model or data is None") @@ -87,7 +137,9 @@ def _run_simulation(config: GlobalConfig, shm: ShmReader) -> None: case "unitree_go1": z = 0.3 case "unitree_g1": - z = 0.8 + # Match g1_gear_wbc.xml's pelvis pos. Was 0.8 — overrode the + # MJCF and dropped the robot 7 mm at the first mj_step. + z = 0.793 case _: z = 0 @@ -97,15 +149,46 @@ def _run_simulation(config: GlobalConfig, shm: ShmReader) -> None: mujoco.mj_forward(model, data) - camera_id = mujoco.mj_name2id(model, mujoco.mjtObj.mjOBJ_CAMERA, "head_camera") - lidar_camera_id = mujoco.mj_name2id(model, mujoco.mjtObj.mjOBJ_CAMERA, "lidar_front_camera") - - person_position_controller = PersonPositionController(model) - - lidar_left_camera_id = mujoco.mj_name2id(model, mujoco.mjtObj.mjOBJ_CAMERA, "lidar_left_camera") - lidar_right_camera_id = mujoco.mj_name2id( - model, mujoco.mjtObj.mjOBJ_CAMERA, "lidar_right_camera" - ) + # Camera / person machinery only exists in the high-level scenes + # (scene_office1 etc.). Low-level mode uses a minimal robot scene + # (menagerie), so skip those lookups entirely. 
+ camera_id = lidar_camera_id = lidar_left_camera_id = lidar_right_camera_id = -1 + person_position_controller: PersonPositionController | None = None + if not skip_controller: + camera_id = mujoco.mj_name2id(model, mujoco.mjtObj.mjOBJ_CAMERA, "head_camera") + lidar_camera_id = mujoco.mj_name2id(model, mujoco.mjtObj.mjOBJ_CAMERA, "lidar_front_camera") + lidar_left_camera_id = mujoco.mj_name2id( + model, mujoco.mjtObj.mjOBJ_CAMERA, "lidar_left_camera" + ) + lidar_right_camera_id = mujoco.mj_name2id( + model, mujoco.mjtObj.mjOBJ_CAMERA, "lidar_right_camera" + ) + person_position_controller = PersonPositionController(model) + + # Low-level passthrough precomputes: actuator→qpos/qvel maps and IMU + # sensor slices so the per-tick hot path is just array copies. + imu_gyro_slice = imu_accel_slice = None + act_qposadr = act_dofadr = None + num_motors = 0 + if skip_controller: + # Menagerie uses "imu-pelvis-*" with hyphens; bundled MJX variant + # uses "gyro_pelvis"/"accelerometer_pelvis" with underscores. + # Try both so the low-level path works against either MJCF. + imu_gyro_slice = _find_sensor_slice( + model, "imu-pelvis-angular-velocity", "gyro_pelvis", dim=3 + ) + imu_accel_slice = _find_sensor_slice( + model, "imu-pelvis-linear-acceleration", "accelerometer_pelvis", dim=3 + ) + num_motors = int(model.nu) + act_qposadr = np.array( + [int(model.jnt_qposadr[int(model.actuator_trnid[i, 0])]) for i in range(num_motors)], + dtype=np.intp, + ) + act_dofadr = np.array( + [int(model.jnt_dofadr[int(model.actuator_trnid[i, 0])]) for i in range(num_motors)], + dtype=np.intp, + ) shm.signal_ready() @@ -136,14 +219,55 @@ def _run_simulation(config: GlobalConfig, shm: ShmReader) -> None: m_viewer.cam.azimuth = config.mujoco_camera_position_float[4] m_viewer.cam.elevation = config.mujoco_camera_position_float[5] + # Low-level startup: the subprocess comes up ~2 s before the + # coordinator starts ticking. 
Without this flag, the robot would + # free-fall into a sprawl during those 2 s and the first PD tick + # would yank it at kp=150-200 from the fallen heap back toward + # the default bent-knee pose — a startup seizure. + controller_ready = False + while m_viewer.is_running() and not shm.should_stop(): step_start = time.time() - # Step simulation - for _ in range(config.mujoco_steps_per_frame): - mujoco.mj_step(model, data) - - person_position_controller.tick(data) + # Low-level passthrough: read per-joint (q_target, kp, kd) from + # shm, compute PD torque, write to data.ctrl. The MJCF has + # torque-mode actuators, so this subprocess plays the + # role that onboard motor drivers play on real hardware. + # Using shm-sourced kp/kd (not MJCF-baked gains) is the whole + # point: the GR00T policy was trained against a specific + # per-joint PD, and any deviation destabilises it. + if skip_controller: + assert act_qposadr is not None and act_dofadr is not None + cmd = shm.read_joint_cmd(num_motors) + if cmd is not None: + controller_ready = True + q = data.qpos[act_qposadr].astype(np.float32) + dq = data.qvel[act_dofadr].astype(np.float32) + q_tgt = cmd[:, 0] + kp = cmd[:, 1] + kd = cmd[:, 2] + data.ctrl[:num_motors] = kp * (q_tgt - q) - kd * dq + + # Step simulation. In low-level mode we step once per outer + # iteration so sim-time advances in lock-step with wall-time + # — the coordinator is writing new PD targets at ~500 Hz and + # we need fresh (q, dq) → PD each step, not a stale PD held + # across 7 substeps (which made physics run 7× real-time and + # PD react to 14 ms-old state, the seizure we just debugged). + # High-level mode keeps the substeps-per-frame speedup because + # its ONNX controller lives inside mj_step via mjcb_control. + if skip_controller and not controller_ready: + # mj_forward runs kinematics but not dynamics — robot + # stays at MJCF initial pose until the coordinator's + # first command arrives. 
+ mujoco.mj_forward(model, data) + else: + steps = 1 if skip_controller else config.mujoco_steps_per_frame + for _ in range(steps): + mujoco.mj_step(model, data) + + if person_position_controller is not None: + person_position_controller.tick(data) m_viewer.sync() @@ -152,8 +276,40 @@ def _run_simulation(config: GlobalConfig, shm: ShmReader) -> None: quat = data.qpos[3:7].copy() # (w, x, y, z) shm.write_odom(pos, quat, time.time()) + # Low-level passthrough: export per-joint state + IMU to shm. + if skip_controller: + assert act_qposadr is not None and act_dofadr is not None + q_out = data.qpos[act_qposadr].astype(np.float32) + dq_out = data.qvel[act_dofadr].astype(np.float32) + tau_out = data.actuator_force[:num_motors].astype(np.float32) + shm.write_joint_state(q_out, dq_out, tau_out) + # Base orientation from the free joint (qpos[3:7] is + # w,x,y,z per MuJoCo convention) — no framequat sensor + # needed, which menagerie's G1 doesn't ship with. + quat = data.qpos[3:7].astype(np.float32) + gyro = ( + data.sensordata[imu_gyro_slice].astype(np.float32) + if imu_gyro_slice is not None + else np.zeros(3, dtype=np.float32) + ) + accel = ( + data.sensordata[imu_accel_slice].astype(np.float32) + if imu_accel_slice is not None + else np.zeros(3, dtype=np.float32) + ) + shm.write_imu(quat, gyro, accel) + current_time = time.time() + # In low-level mode the robot scene has no head / lidar cameras, + # so the video + lidar streams are skipped entirely. Odom + + # joint_state + imu above are all the rerun layer needs. 
+ if skip_controller: + time_until_next_step = model.opt.timestep - (time.time() - step_start) + if time_until_next_step > 0: + time.sleep(time_until_next_step) + continue + # Video rendering if current_time - last_video_time >= video_interval: rgb_renderer.update_scene(data, camera=camera_id, scene_option=scene_option) @@ -226,7 +382,8 @@ def _run_simulation(config: GlobalConfig, shm: ShmReader) -> None: if time_until_next_step > 0: time.sleep(time_until_next_step) - person_position_controller.stop() + if person_position_controller is not None: + person_position_controller.stop() if __name__ == "__main__": @@ -240,9 +397,10 @@ def signal_handler(_signum: int, _frame: Any) -> None: global_config = pickle.loads(base64.b64decode(sys.argv[1])) shm_names = json.loads(sys.argv[2]) + control_mode = sys.argv[3] if len(sys.argv) > 3 else "high_level" shm = ShmReader(shm_names) try: - _run_simulation(global_config, shm) + _run_simulation(global_config, shm, control_mode=control_mode) finally: shm.cleanup() diff --git a/dimos/simulation/mujoco/shared_memory.py b/dimos/simulation/mujoco/shared_memory.py index bd96ad2025..1f061d5033 100644 --- a/dimos/simulation/mujoco/shared_memory.py +++ b/dimos/simulation/mujoco/shared_memory.py @@ -41,6 +41,12 @@ _seq_size = 8 * 8 # 8 int64 values for different data types # Control buffer: ready flag + stop flag _control_size = 2 * 4 # 2 int32 values +# Low-level passthrough: per-joint command (q, kp, kd) for up to 32 motors. 
+_LOWLEVEL_MAX_MOTORS = 32 +_joint_cmd_size = _LOWLEVEL_MAX_MOTORS * 3 * 4 # float32 (q, kp, kd) +_joint_state_size = _LOWLEVEL_MAX_MOTORS * 3 * 4 # float32 (q, dq, tau) +# IMU: quat (w,x,y,z) + gyro(3) + accel(3) float32 +_imu_size = 10 * 4 _shm_sizes = { "video": _video_size, @@ -53,6 +59,9 @@ "lidar_len": 4, "seq": _seq_size, "control": _control_size, + "joint_cmd": _joint_cmd_size, + "joint_state": _joint_state_size, + "imu": _imu_size, } @@ -76,6 +85,9 @@ class ShmSet: lidar_len: SharedMemory seq: SharedMemory control: SharedMemory + joint_cmd: SharedMemory + joint_state: SharedMemory + imu: SharedMemory @classmethod def from_names(cls, shm_names: dict[str, str]) -> "ShmSet": @@ -173,6 +185,42 @@ def read_command(self) -> tuple[NDArray[Any], NDArray[Any]] | None: return linear, angular return None + # --- Low-level passthrough (child/subprocess side) --- + + def read_joint_cmd(self, num_motors: int) -> NDArray[Any] | None: + """Return (num_motors, 3) array of (q, kp, kd), or ``None`` if no new cmd.""" + seq = self._get_seq(7) + if seq > 0: + arr: NDArray[Any] = np.ndarray( + (_LOWLEVEL_MAX_MOTORS, 3), dtype=np.float32, buffer=self.shm.joint_cmd.buf + ) + return arr[:num_motors].copy() + return None + + def write_joint_state(self, q: NDArray[Any], dq: NDArray[Any], tau: NDArray[Any]) -> None: + n = len(q) + if n > _LOWLEVEL_MAX_MOTORS: + n = _LOWLEVEL_MAX_MOTORS + arr: NDArray[Any] = np.ndarray( + (_LOWLEVEL_MAX_MOTORS, 3), dtype=np.float32, buffer=self.shm.joint_state.buf + ) + arr[:n, 0] = q[:n] + arr[:n, 1] = dq[:n] + arr[:n, 2] = tau[:n] + self._increment_seq(5) + + def write_imu( + self, + quat: NDArray[Any], + gyro: NDArray[Any], + accel: NDArray[Any], + ) -> None: + arr: NDArray[Any] = np.ndarray((10,), dtype=np.float32, buffer=self.shm.imu.buf) + arr[0:4] = quat + arr[4:7] = gyro + arr[7:10] = accel + self._increment_seq(6) + def _increment_seq(self, index: int) -> None: seq_array: NDArray[Any] = np.ndarray((8,), dtype=np.int64, buffer=self.shm.seq.buf) 
seq_array[index] += 1 @@ -237,6 +285,43 @@ def write_command(self, linear: NDArray[Any], angular: NDArray[Any]) -> None: cmd_array[3:6] = angular self._increment_seq(3) + # --- Low-level passthrough (parent side) --- + + def write_joint_cmd( + self, + q: NDArray[Any], + kp: NDArray[Any], + kd: NDArray[Any], + ) -> None: + n = len(q) + if n > _LOWLEVEL_MAX_MOTORS: + n = _LOWLEVEL_MAX_MOTORS + arr: NDArray[Any] = np.ndarray( + (_LOWLEVEL_MAX_MOTORS, 3), dtype=np.float32, buffer=self.shm.joint_cmd.buf + ) + arr[:n, 0] = q[:n] + arr[:n, 1] = kp[:n] + arr[:n, 2] = kd[:n] + self._increment_seq(7) + + def read_joint_state(self, num_motors: int) -> NDArray[Any] | None: + """Return (num_motors, 3) array of (q, dq, tau), or ``None`` if no state yet.""" + seq = self._get_seq(5) + if seq > 0: + arr: NDArray[Any] = np.ndarray( + (_LOWLEVEL_MAX_MOTORS, 3), dtype=np.float32, buffer=self.shm.joint_state.buf + ) + return arr[:num_motors].copy() + return None + + def read_imu(self) -> NDArray[Any] | None: + """Return 10-element float32 [quat(4), gyro(3), accel(3)], or ``None``.""" + seq = self._get_seq(6) + if seq > 0: + arr: NDArray[Any] = np.ndarray((10,), dtype=np.float32, buffer=self.shm.imu.buf) + return arr.copy() + return None + def read_lidar(self) -> tuple[PointCloud2 | None, int]: seq = self._get_seq(4) if seq > 0: diff --git a/dimos/web/command-center-extension/src/ActivationPanel.tsx b/dimos/web/command-center-extension/src/ActivationPanel.tsx new file mode 100644 index 0000000000..51ccd283fb --- /dev/null +++ b/dimos/web/command-center-extension/src/ActivationPanel.tsx @@ -0,0 +1,86 @@ +import * as React from "react"; + +import Button from "./Button"; + +interface ActivationPanelProps { + onArm: () => void; + onDisarm: () => void; + onSetDryRun: (enabled: boolean) => void; + initialDryRun?: boolean; +} + +/** + * Dashboard control for arming/disarming a locomotion-policy task + * (e.g. the G1 GR00T WBC). 
+ * + * The panel is UI-only — it does not know whether the task actually + * accepted the arm/disarm request. The coordinator logs the + * transition server-side; for a future iteration we could subscribe + * to a state-echo stream and reflect the real machine-state here. + * + * Defaults to dry-run ON so hitting Arm on real hardware does NOT + * immediately command motors — the operator toggles dry-run off + * after visually verifying computed targets are sensible. + */ +export default function ActivationPanel({ + onArm, + onDisarm, + onSetDryRun, + initialDryRun = true, +}: ActivationPanelProps): React.ReactElement { + const [armed, setArmed] = React.useState(false); + const [dryRun, setDryRun] = React.useState(initialDryRun); + + const handleArmToggle = () => { + if (armed) { + onDisarm(); + setArmed(false); + } else { + onArm(); + setArmed(true); + } + }; + + const handleDryRunToggle = () => { + const next = !dryRun; + onSetDryRun(next); + setDryRun(next); + }; + + return ( +
+    <div className="activation-panel">
+      <div className="activation-panel__title">Policy</div>
+      <Button onClick={handleArmToggle}>{armed ? "Disarm" : "Arm"}</Button>
+      <Button onClick={handleDryRunToggle}>Dry run: {dryRun ? "ON" : "OFF"}</Button>
+      <div className="activation-panel__hint">
+        Arm ramps current pose → default over ~10 s. Dry run keeps policy computing but
+        suppresses commands so you can verify targets in server logs first.
+      </div>
+    </div>
+ ); +} diff --git a/dimos/web/command-center-extension/src/App.tsx b/dimos/web/command-center-extension/src/App.tsx index dc0c90e7ea..df2e5d1a20 100644 --- a/dimos/web/command-center-extension/src/App.tsx +++ b/dimos/web/command-center-extension/src/App.tsx @@ -1,5 +1,6 @@ import * as React from "react"; +import ActivationPanel from "./ActivationPanel"; import Connection from "./Connection"; import ExplorePanel from "./ExplorePanel"; import GpsButton from "./GpsButton"; @@ -78,6 +79,18 @@ export default function App(): React.ReactElement { connectionRef.current?.stopMoveCommand(); }, []); + const handleArm = React.useCallback(() => { + connectionRef.current?.arm(); + }, []); + + const handleDisarm = React.useCallback(() => { + connectionRef.current?.disarm(); + }, []); + + const handleSetDryRun = React.useCallback((enabled: boolean) => { + connectionRef.current?.setDryRun(enabled); + }, []); + const handleReturnHome = React.useCallback(() => { connectionRef.current?.worldClick(0, 0); }, []); @@ -122,6 +135,11 @@ export default function App(): React.ReactElement { onSendMoveCommand={handleSendMoveCommand} onStopMoveCommand={handleStopMoveCommand} /> + ); diff --git a/dimos/web/command-center-extension/src/Connection.ts b/dimos/web/command-center-extension/src/Connection.ts index 7a23c6b98c..6e7f1588b5 100644 --- a/dimos/web/command-center-extension/src/Connection.ts +++ b/dimos/web/command-center-extension/src/Connection.ts @@ -104,6 +104,18 @@ export default class Connection { this.socket.emit("move_command", twist); } + arm(): void { + this.socket.emit("arm"); + } + + disarm(): void { + this.socket.emit("disarm"); + } + + setDryRun(enabled: boolean): void { + this.socket.emit("set_dry_run", { enabled }); + } + disconnect(): void { this.socket.disconnect(); } diff --git a/dimos/web/websocket_vis/websocket_vis_module.py b/dimos/web/websocket_vis/websocket_vis_module.py index 2c3ad3009b..300acaf3eb 100644 --- a/dimos/web/websocket_vis/websocket_vis_module.py +++ 
b/dimos/web/websocket_vis/websocket_vis_module.py @@ -54,6 +54,7 @@ from dimos.mapping.types import LatLon from dimos.msgs.geometry_msgs import PoseStamped, Twist, TwistStamped, Vector3 from dimos.msgs.nav_msgs import OccupancyGrid, Path +from dimos.msgs.std_msgs.Bool import Bool as DimosBool from dimos.utils.logging_config import setup_logger from .optimized_costmap import OptimizedCostmapEncoder @@ -96,6 +97,13 @@ class WebsocketVisModule(Module): stop_explore_cmd: Out[Bool] cmd_vel: Out[Twist] movecmd_stamped: Out[TwistStamped] + # Arming / dry-run for locomotion-policy tasks (e.g. GrootWBCTask). + # Uses dimos.msgs.std_msgs.Bool to match the coordinator's + # ``activate`` / ``dry_run`` In[Bool] ports, rather than + # dimos_lcm.std_msgs.Bool used by ``explore_cmd`` — the LCM wire + # format is identical; what matters for autoconnect is type parity. + activate: Out[DimosBool] + dry_run: Out[DimosBool] def __init__( self, @@ -327,6 +335,38 @@ async def clear_gps_goals(sid: str) -> None: await self.sio.emit("gps_travel_goal_points", self.gps_goal_points) logger.info("GPS goal points cleared and updated clients") + @self.sio.event # type: ignore[untyped-decorator] + async def arm(sid: str, data: dict[str, Any] | None = None) -> None: + """Dashboard → arm the locomotion policy (with ramp).""" + if self.activate and self.activate.transport: + logger.info("Dashboard requested arm") + self.activate.publish(DimosBool(data=True)) + else: + logger.warning("arm requested but activate transport is not configured") + + @self.sio.event # type: ignore[untyped-decorator] + async def disarm(sid: str, data: dict[str, Any] | None = None) -> None: + """Dashboard → disarm; task falls back to hold-current-pose.""" + if self.activate and self.activate.transport: + logger.info("Dashboard requested disarm") + self.activate.publish(DimosBool(data=False)) + else: + logger.warning("disarm requested but activate transport is not configured") + + @self.sio.event # type: 
ignore[untyped-decorator] + async def set_dry_run(sid: str, data: dict[str, Any]) -> None: + """Dashboard → toggle dry-run on the locomotion policy. + + Payload: ``{"enabled": bool}``. Task still computes but + coordinator sends nothing to the adapter when enabled. + """ + if self.dry_run and self.dry_run.transport: + enabled = bool(data.get("enabled", False)) + logger.info(f"Dashboard set dry_run = {enabled}") + self.dry_run.publish(DimosBool(data=enabled)) + else: + logger.warning("set_dry_run requested but dry_run transport is not configured") + @self.sio.event # type: ignore[untyped-decorator] async def move_command(sid: str, data: dict[str, Any]) -> None: # Publish Twist if transport is configured