From 0e521b636446ad733c6220e8be7009d48e53c290 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tobias=20J=C3=BClg?= Date: Tue, 20 Jan 2026 08:21:43 +0100 Subject: [PATCH] fix(fr3,panda): robot state --- extensions/rcs_fr3/src/rcs_fr3/envs.py | 12 ++++++++++-- extensions/rcs_panda/src/rcs_panda/envs.py | 21 +++++++++++++++++++-- 2 files changed, 29 insertions(+), 4 deletions(-) diff --git a/extensions/rcs_fr3/src/rcs_fr3/envs.py b/extensions/rcs_fr3/src/rcs_fr3/envs.py index c80e8ef5..7d8bd28f 100644 --- a/extensions/rcs_fr3/src/rcs_fr3/envs.py +++ b/extensions/rcs_fr3/src/rcs_fr3/envs.py @@ -14,6 +14,7 @@ def __init__(self, env): self.unwrapped: RobotEnv assert isinstance(self.unwrapped.robot, hw.Franka), "Robot must be a hw.Franka instance." self.hw_robot = cast(hw.Franka, self.unwrapped.robot) + self._robot_state_keys: list[str] | None = None def step(self, action: Any) -> tuple[dict[str, Any], SupportsFloat, bool, bool, dict]: try: @@ -30,10 +31,17 @@ def step(self, action: Any) -> tuple[dict[str, Any], SupportsFloat, bool, bool, def get_obs(self, obs: dict | None = None) -> dict[str, Any]: if obs is None: obs = dict(self.unwrapped.get_obs()) - # robot_state = cast(hw.FrankaState, self.unwrapped.robot.get_state()) - # obs["robot_state"] = vars(robot_state.robot_state) + robot_state = cast(hw.FrankaState, self.unwrapped.robot.get_state()) + obs["robot_state"] = self._rs2dict(robot_state.robot_state) return obs + def _rs2dict(self, state: hw.RobotState): + if self._robot_state_keys is None: + self._robot_state_keys = [ + attr for attr in dir(state) if not attr.startswith("__") and not callable(getattr(state, attr)) + ] + return {key: getattr(state, key) for key in self._robot_state_keys} + def reset( self, seed: int | None = None, options: dict[str, Any] | None = None ) -> tuple[dict[str, Any], dict[str, Any]]: diff --git a/extensions/rcs_panda/src/rcs_panda/envs.py b/extensions/rcs_panda/src/rcs_panda/envs.py index 562c118d..58a80212 100644 --- a/extensions/rcs_panda/src/rcs_panda/envs.py +++ b/extensions/rcs_panda/src/rcs_panda/envs.py @@ -14,16 +14,33 @@ def __init__(self, env): self.unwrapped: RobotEnv assert isinstance(self.unwrapped.robot, hw.Franka), "Robot must be a hw.Franka instance." self.hw_robot = cast(hw.Franka, self.unwrapped.robot) + self._robot_state_keys: list[str] | None = None def step(self, action: Any) -> tuple[dict[str, Any], SupportsFloat, bool, bool, dict]: try: - return super().step(action) + obs, reward, terminated, truncated, info = super().step(action) + obs = self.get_obs(obs) + return obs, reward, terminated, truncated, info except hw.exceptions.FrankaControlException as e: _logger.error("FrankaControlException: %s", e) self.hw_robot.automatic_error_recovery() # TODO: this does not work if some wrappers are in between # PandaHW and RobotEnv - return dict(self.unwrapped.get_obs()), 0, False, True, {} + return self.get_obs(), 0, False, True, {} + + def get_obs(self, obs: dict | None = None) -> dict[str, Any]: + if obs is None: + obs = dict(self.unwrapped.get_obs()) + robot_state = cast(hw.FrankaState, self.unwrapped.robot.get_state()) + obs["robot_state"] = self._rs2dict(robot_state.robot_state) + return obs + + def _rs2dict(self, state: hw.RobotState): + if self._robot_state_keys is None: + self._robot_state_keys = [ + attr for attr in dir(state) if not attr.startswith("__") and not callable(getattr(state, attr)) + ] + return {key: getattr(state, key) for key in self._robot_state_keys} def reset( self, seed: int | None = None, options: dict[str, Any] | None = None