diff --git a/egomimic/algo/pi.py b/egomimic/algo/pi.py index abc7ede62..952d28c93 100644 --- a/egomimic/algo/pi.py +++ b/egomimic/algo/pi.py @@ -25,7 +25,11 @@ _to_minus1_1, ) from egomimic.rldb.embodiment.embodiment import get_embodiment, get_embodiment_id -from egomimic.utils.action_utils import ConverterRegistry +from egomimic.utils.action_utils import ( + ConverterRegistry, + PI05_CARTESIAN_ACTION_ENCODING_LEGACY, + PI05_CARTESIAN_ACTION_ENCODING_RAW_ROT_6D, +) logger = logging.getLogger(__name__) # Ensure logger propagates to root logger and has appropriate level @@ -70,6 +74,7 @@ def __init__( state_num_bins: int = 256, control_mode: dict[str, str] | None = None, proprio_keys_for_prompt: list[str] | None = None, + action_encoding: str = PI05_CARTESIAN_ACTION_ENCODING_LEGACY, **kwargs, ): self.nets = nn.ModuleDict() @@ -103,6 +108,7 @@ def __init__( "pi_cam_keys", ["base_0_rgb", "left_wrist_0_rgb", "right_wrist_0_rgb"] ) self.config = config + self.action_encoding = action_encoding self.ac_keys = ac_keys @@ -291,6 +297,21 @@ def _tokenize_prompts(self, prompts: list[str]) -> dict: "token_ar_mask": attention_mask.clone().requires_grad_(False), } + def _action_stats(self, embodiment_id: int, ac_key: str) -> dict: + try: + return self.norm_stats.norm_stats[embodiment_id][ac_key] + except KeyError as exc: + raise KeyError( + f"Missing norm stats for action key {ac_key!r} " + f"and embodiment id {embodiment_id}" + ) from exc + + def _unnormalize_action(self, action: torch.Tensor, embodiment_id: int, ac_key: str): + return self.norm_stats.unnormalize( + {ac_key: action.clone(), "embodiment": embodiment_id}, + embodiment_id, + )[ac_key].to(action.device) + @override def process_batch_for_training(self, batch): """ @@ -451,12 +472,28 @@ def forward_eval(self, batch): B, T, D = ref.shape converter = self.action_registry.get(embodiment_id, ac_key) - pred_actions_orig = converter.from32(pred_actions) - - pred = pred_actions_orig[:, :T, :D] - predictions[ac_key] = pred - - unnorm_actions = self.norm_stats.unnormalize(predictions, embodiment_id) + if ( + self.action_encoding + == PI05_CARTESIAN_ACTION_ENCODING_RAW_ROT_6D + ): + pred_actions_orig = converter.from32_raw_rotation( + pred_actions, + stats=self._action_stats(embodiment_id, ac_key), + norm_mode=self.norm_stats.norm_mode, + unnormalize_non_rotation=True, + ) + unnorm_actions = {ac_key: pred_actions_orig[:, :T, :D]} + elif self.action_encoding == PI05_CARTESIAN_ACTION_ENCODING_LEGACY: + pred_actions_orig = converter.from32(pred_actions) + pred = pred_actions_orig[:, :T, :D] + predictions[ac_key] = pred + unnorm_actions = self.norm_stats.unnormalize( + predictions, embodiment_id + ) + else: + raise ValueError( + f"Unsupported PI0.5 action_encoding: {self.action_encoding!r}" + ) for key in unnorm_actions: unnorm_preds[f"{embodiment_name}_{key}"] = unnorm_actions[key] @@ -531,7 +568,20 @@ def _robomimic_to_pi_data( emb_id = get_embodiment_id(embodiment) # embodiment is a name string converter = self.action_registry.get(emb_id, ac_key) - action32 = converter.to32(action) + if self.action_encoding == PI05_CARTESIAN_ACTION_ENCODING_RAW_ROT_6D: + raw_action = self._unnormalize_action(action, emb_id, ac_key) + action32 = converter.to32_raw_rotation( + raw_action, + normalized_actions=action, + stats=self._action_stats(emb_id, ac_key), + norm_mode=self.norm_stats.norm_mode, + ) + elif self.action_encoding == PI05_CARTESIAN_ACTION_ENCODING_LEGACY: + action32 = converter.to32(action) + else: + raise ValueError( + f"Unsupported PI0.5 action_encoding: {self.action_encoding!r}" + ) # OpenPI expects a fixed camera tuple. Human datasets only provide # `base_0_rgb`, so duplicate that view into the missing wrist slots and diff --git a/egomimic/hydra_configs/model/pi0.5_bc_eva.yaml b/egomimic/hydra_configs/model/pi0.5_bc_eva.yaml index b0b6a6a10..83ffad26d 100644 --- a/egomimic/hydra_configs/model/pi0.5_bc_eva.yaml +++ b/egomimic/hydra_configs/model/pi0.5_bc_eva.yaml @@ -10,6 +10,7 @@ robomimic_model: ac_keys: eva_bimanual: "actions_cartesian" domains: ["eva_bimanual"] + action_encoding: "cartesian_ypr_raw_rot6d" action_converters: rules: diff --git a/egomimic/utils/action_utils.py b/egomimic/utils/action_utils.py index 75c4fac11..57602f5a9 100644 --- a/egomimic/utils/action_utils.py +++ b/egomimic/utils/action_utils.py @@ -1,7 +1,14 @@ -from typing import Dict, Tuple +from typing import Any, Dict, Tuple import torch +PI05_CARTESIAN_ACTION_ENCODING_RAW_ROT_6D = "cartesian_ypr_raw_rot6d" +PI05_CARTESIAN_ACTION_ENCODING_LEGACY = "legacy_normalized_ypr_rot6d" + +# Bimanual robot Cartesian layout: [x, y, z, yaw, pitch, roll, gripper] x 2. +ROBOT_BIMANUAL_CARTESIAN_ROT_DIMS = (3, 4, 5, 10, 11, 12) +ROBOT_BIMANUAL_CARTESIAN_NON_ROT_DIMS = (0, 1, 2, 6, 7, 8, 9, 13) + # ---------- registry that stores *objects* ---------- class ConverterRegistry: @@ -43,6 +50,77 @@ def _pad32(x: torch.Tensor) -> torch.Tensor: return x[..., :32] +def _stat_tensor(stats: dict[str, Any], key: str, ref: torch.Tensor) -> torch.Tensor: + value = torch.as_tensor(stats[key], device=ref.device, dtype=torch.float32) + return value.to(dtype=ref.dtype if ref.is_floating_point() else torch.float32) + + +def _apply_norm_one( + tensor: torch.Tensor, + stats: dict[str, Any], + norm_mode: str, +) -> torch.Tensor: + if norm_mode == "zscore": + mean = _stat_tensor(stats, "mean", tensor) + std = _stat_tensor(stats, "std", tensor) + return (tensor - mean) / (std + 1e-6) + if norm_mode == "minmax": + mn = _stat_tensor(stats, "min", tensor) + mx = _stat_tensor(stats, "max", tensor) + return 2.0 * ((tensor - mn) / (mx - mn + 1e-6)) - 1.0 + if norm_mode == "quantile": + q1 = _stat_tensor(stats, "quantile_1", tensor) + q99 = _stat_tensor(stats, "quantile_99", tensor) + return 2.0 * ((tensor - q1) / (q99 - q1 + 1e-6)) - 1.0 + raise ValueError(f"Invalid normalization mode: {norm_mode}") + + +def _apply_unnorm_one( + tensor: torch.Tensor, + stats: dict[str, Any], + norm_mode: str, +) -> torch.Tensor: + if norm_mode == "zscore": + mean = _stat_tensor(stats, "mean", tensor) + std = _stat_tensor(stats, "std", tensor) + return tensor * (std + 1e-6) + mean + if norm_mode == "minmax": + mn = _stat_tensor(stats, "min", tensor) + mx = _stat_tensor(stats, "max", tensor) + return (tensor + 1) * 0.5 * (mx - mn + 1e-6) + mn + if norm_mode == "quantile": + q1 = _stat_tensor(stats, "quantile_1", tensor) + q99 = _stat_tensor(stats, "quantile_99", tensor) + return (tensor + 1) * 0.5 * (q99 - q1 + 1e-6) + q1 + raise ValueError(f"Invalid normalization mode: {norm_mode}") + + +def _normalize_robot_bimanual_non_rot( + raw_actions: torch.Tensor, + stats: dict[str, Any], + norm_mode: str, +) -> torch.Tensor: + normalized = raw_actions.clone() + all_dims = _apply_norm_one(raw_actions, stats, norm_mode) + normalized[..., ROBOT_BIMANUAL_CARTESIAN_NON_ROT_DIMS] = all_dims[ + ..., ROBOT_BIMANUAL_CARTESIAN_NON_ROT_DIMS + ] + return normalized + + +def _unnormalize_robot_bimanual_non_rot( + model_actions: torch.Tensor, + stats: dict[str, Any], + norm_mode: str, +) -> torch.Tensor: + raw_actions = model_actions.clone() + all_dims = _apply_unnorm_one(model_actions, stats, norm_mode) + raw_actions[..., ROBOT_BIMANUAL_CARTESIAN_NON_ROT_DIMS] = all_dims[ + ..., ROBOT_BIMANUAL_CARTESIAN_NON_ROT_DIMS + ] + return raw_actions + + def _ypr_to_matrix(ypr: torch.Tensor, degrees: bool = False) -> torch.Tensor: if degrees: ypr = ypr * (torch.pi / 180.0) @@ -137,6 +215,34 @@ def to32(self, actions: torch.Tensor) -> torch.Tensor: def from32(self, actions32: torch.Tensor) -> torch.Tensor: raise NotImplementedError + def to32_raw_rotation( + self, + raw_actions: torch.Tensor, + *, + normalized_actions: torch.Tensor | None = None, + stats: dict[str, Any] | None = None, + norm_mode: str = "quantile", + ) -> torch.Tensor: + """Pack actions with raw YPR rotations and normalized non-rotation dims.""" + del normalized_actions, stats, norm_mode + raise NotImplementedError( + f"{type(self).__name__} does not support raw-rotation action encoding" + ) + + def from32_raw_rotation( + self, + actions32: torch.Tensor, + *, + stats: dict[str, Any] | None = None, + norm_mode: str = "quantile", + unnormalize_non_rotation: bool = False, + ) -> torch.Tensor: + """Decode actions whose 6D rotation columns represent raw YPR rotations.""" + del stats, norm_mode, unnormalize_non_rotation + raise NotImplementedError( + f"{type(self).__name__} does not support raw-rotation action decoding" + ) + # ============================================================ # ROBOT CONVERTERS @@ -210,7 +316,7 @@ class RobotBimanualCartesianEuler(BaseActionConverter): 32-pack: left block 0..9, right block 10..19 """ - def to32(self, actions: torch.Tensor) -> torch.Tensor: + def to20(self, actions: torch.Tensor) -> torch.Tensor: actions = _ensure_bsd(actions) if actions.shape[-1] != 14: raise ValueError(f"RobotBimanual: expected 14-dim, got {actions.shape[-1]}") @@ -228,12 +334,19 @@ def to32(self, actions: torch.Tensor) -> torch.Tensor: R_c1, R_c2 = R_R[..., 0], R_R[..., 1] right_block = torch.cat([R_xyz, R_c1, R_c2, R_g], dim=-1) # (B,S,10) - return _pad32(torch.cat([left_block, right_block], dim=-1)) # (B,S,20+) -> 32 + return torch.cat([left_block, right_block], dim=-1) # (B,S,20) - def from32(self, actions32: torch.Tensor) -> torch.Tensor: - actions32 = _ensure_bsd(actions32) - Lb = actions32[..., 0:10] - Rb = actions32[..., 10:20] + def to32(self, actions: torch.Tensor) -> torch.Tensor: + return _pad32(self.to20(actions)) + + def from20(self, actions20: torch.Tensor) -> torch.Tensor: + actions20 = _ensure_bsd(actions20) + if actions20.shape[-1] < 20: + raise ValueError( + f"RobotBimanual: expected at least 20 dims, got {actions20.shape[-1]}" + ) + Lb = actions20[..., 0:10] + Rb = actions20[..., 10:20] # left L_xyz, L_c1, L_c2, L_g = Lb[..., 0:3], Lb[..., 3:6], Lb[..., 6:9], Lb[..., 9:10] @@ -249,6 +362,93 @@ def from32(self, actions32: torch.Tensor) -> torch.Tensor: R7 = torch.cat([R_xyz, R_ypr, R_g], dim=-1) return torch.cat([L7, R7], dim=-1) # (B,S,14) + def from32(self, actions32: torch.Tensor) -> torch.Tensor: + return self.from20(actions32) + + def to20_raw_rotation( + self, + raw_actions: torch.Tensor, + *, + normalized_actions: torch.Tensor | None = None, + stats: dict[str, Any] | None = None, + norm_mode: str = "quantile", + ) -> torch.Tensor: + raw_actions = _ensure_bsd(raw_actions) + if raw_actions.shape[-1] != 14: + raise ValueError( + f"RobotBimanual: expected 14-dim, got {raw_actions.shape[-1]}" + ) + if normalized_actions is None: + if stats is None: + raise ValueError("stats are required when normalized_actions is omitted") + model_actions = _normalize_robot_bimanual_non_rot( + raw_actions, stats, norm_mode + ) + else: + normalized_actions = _ensure_bsd(normalized_actions).to(raw_actions.device) + if normalized_actions.shape != raw_actions.shape: + raise ValueError( + "normalized_actions must match raw_actions shape; got " + f"{tuple(normalized_actions.shape)} vs {tuple(raw_actions.shape)}" + ) + model_actions = raw_actions.clone() + model_actions[..., ROBOT_BIMANUAL_CARTESIAN_NON_ROT_DIMS] = ( + normalized_actions[..., ROBOT_BIMANUAL_CARTESIAN_NON_ROT_DIMS] + ) + return self.to20(model_actions) + + def from20_raw_rotation( + self, + actions20: torch.Tensor, + *, + stats: dict[str, Any] | None = None, + norm_mode: str = "quantile", + unnormalize_non_rotation: bool = False, + ) -> torch.Tensor: + model_actions = self.from20(actions20) + if not unnormalize_non_rotation: + return model_actions + if model_actions.shape[-1] != 14: + raise ValueError( + "RobotBimanual raw-rotation decoding expected 14D Cartesian actions; " + f"got {model_actions.shape[-1]} dims" + ) + if stats is None: + raise ValueError("stats are required to unnormalize non-rotation dims") + return _unnormalize_robot_bimanual_non_rot(model_actions, stats, norm_mode) + + def to32_raw_rotation( + self, + raw_actions: torch.Tensor, + *, + normalized_actions: torch.Tensor | None = None, + stats: dict[str, Any] | None = None, + norm_mode: str = "quantile", + ) -> torch.Tensor: + return _pad32( + self.to20_raw_rotation( + raw_actions, + normalized_actions=normalized_actions, + stats=stats, + norm_mode=norm_mode, + ) + ) + + def from32_raw_rotation( + self, + actions32: torch.Tensor, + *, + stats: dict[str, Any] | None = None, + norm_mode: str = "quantile", + unnormalize_non_rotation: bool = False, + ) -> torch.Tensor: + return self.from20_raw_rotation( + actions32, + stats=stats, + norm_mode=norm_mode, + unnormalize_non_rotation=unnormalize_non_rotation, + ) + # ============================================================ # HUMAN CONVERTERS diff --git a/egomimic/utils/test_pi05_action_encoding.py b/egomimic/utils/test_pi05_action_encoding.py new file mode 100644 index 000000000..71f3d04d8 --- /dev/null +++ b/egomimic/utils/test_pi05_action_encoding.py @@ -0,0 +1,112 @@ +import torch +import pytest + +from egomimic.utils.action_utils import BaseActionConverter, RobotBimanualCartesianEuler + + +def _stats_for() -> dict[str, torch.Tensor]: + # Deliberately asymmetric per-dim ranges exercise non-rotation normalization. + q1 = torch.tensor( + [-1.0, -2.0, -3.0, -torch.pi, -1.0, -1.0, 0.0] * 2, + dtype=torch.float32, + ) + q99 = torch.tensor( + [1.0, 2.0, 3.0, torch.pi, 1.0, 1.0, 1.0] * 2, + dtype=torch.float32, + ) + return { + "mean": torch.zeros_like(q1), + "std": torch.ones_like(q1), + "min": q1, + "max": q99, + "quantile_1": q1, + "quantile_99": q99, + } + + +def _normalize(raw: torch.Tensor, stats: dict[str, torch.Tensor], norm_mode: str) -> torch.Tensor: + if norm_mode == "zscore": + return (raw - stats["mean"]) / (stats["std"] + 1e-6) + if norm_mode == "minmax": + return 2.0 * ((raw - stats["min"]) / (stats["max"] - stats["min"] + 1e-6)) - 1.0 + if norm_mode == "quantile": + return 2.0 * ( + (raw - stats["quantile_1"]) + / (stats["quantile_99"] - stats["quantile_1"] + 1e-6) + ) - 1.0 + raise AssertionError(f"unexpected norm mode {norm_mode}") + + +@pytest.mark.parametrize("norm_mode", ["zscore", "minmax", "quantile"]) +def test_raw_rotation_encoding_round_trips_robot_bimanual_actions(norm_mode): + converter = RobotBimanualCartesianEuler() + raw = torch.tensor( + [ + [ + [ + 0.25, + -0.5, + 1.0, + 0.3, + -0.2, + 0.1, + 0.75, + -0.2, + 0.4, + -1.2, + -0.4, + 0.25, + -0.15, + 0.2, + ] + ] + ], + dtype=torch.float32, + ) + stats = _stats_for() + + packed = converter.to32_raw_rotation(raw, stats=stats, norm_mode=norm_mode) + decoded = converter.from32_raw_rotation( + packed, + stats=stats, + norm_mode=norm_mode, + unnormalize_non_rotation=True, + ) + + torch.testing.assert_close(decoded, raw, atol=1e-5, rtol=1e-5) + + +def test_raw_rotation_encoding_preserves_yaw_wrap_continuity(): + converter = RobotBimanualCartesianEuler() + eps = 1e-4 + raw = torch.zeros(1, 2, 14, dtype=torch.float32) + raw[:, 0, 3] = torch.pi - eps + raw[:, 1, 3] = -torch.pi + eps + stats = _stats_for() + normalized = _normalize(raw, stats, "quantile") + + legacy_packed = converter.to32(normalized) + fixed_packed = converter.to32_raw_rotation( + raw, + normalized_actions=normalized, + stats=stats, + norm_mode="quantile", + ) + + legacy_rot_distance = torch.linalg.norm( + legacy_packed[0, 0, 3:9] - legacy_packed[0, 1, 3:9] + ) + fixed_rot_distance = torch.linalg.norm( + fixed_packed[0, 0, 3:9] - fixed_packed[0, 1, 3:9] + ) + + assert legacy_rot_distance > 2.0 + assert fixed_rot_distance < 1e-3 + + +def test_base_converter_rejects_raw_rotation_encoding(): + converter = BaseActionConverter() + actions = torch.zeros(1, 1, 14) + + with pytest.raises(NotImplementedError, match="raw-rotation action encoding"): + converter.to32_raw_rotation(actions)