Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 58 additions & 8 deletions egomimic/algo/pi.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,11 @@
_to_minus1_1,
)
from egomimic.rldb.embodiment.embodiment import get_embodiment, get_embodiment_id
from egomimic.utils.action_utils import ConverterRegistry
from egomimic.utils.action_utils import (
ConverterRegistry,
PI05_CARTESIAN_ACTION_ENCODING_LEGACY,
PI05_CARTESIAN_ACTION_ENCODING_RAW_ROT_6D,
)

logger = logging.getLogger(__name__)
# Ensure logger propagates to root logger and has appropriate level
Expand Down Expand Up @@ -70,6 +74,7 @@ def __init__(
state_num_bins: int = 256,
control_mode: dict[str, str] | None = None,
proprio_keys_for_prompt: list[str] | None = None,
action_encoding: str = PI05_CARTESIAN_ACTION_ENCODING_LEGACY,
**kwargs,
):
self.nets = nn.ModuleDict()
Expand Down Expand Up @@ -103,6 +108,7 @@ def __init__(
"pi_cam_keys", ["base_0_rgb", "left_wrist_0_rgb", "right_wrist_0_rgb"]
)
self.config = config
self.action_encoding = action_encoding

self.ac_keys = ac_keys

Expand Down Expand Up @@ -291,6 +297,21 @@ def _tokenize_prompts(self, prompts: list[str]) -> dict:
"token_ar_mask": attention_mask.clone().requires_grad_(False),
}

def _action_stats(self, embodiment_id: int, ac_key: str) -> dict:
try:
return self.norm_stats.norm_stats[embodiment_id][ac_key]
except KeyError as exc:
raise KeyError(
f"Missing norm stats for action key {ac_key!r} "
f"and embodiment id {embodiment_id}"
) from exc

def _unnormalize_action(self, action: torch.Tensor, embodiment_id: int, ac_key: str):
return self.norm_stats.unnormalize(
{ac_key: action.clone(), "embodiment": embodiment_id},
embodiment_id,
)[ac_key].to(action.device)

@override
def process_batch_for_training(self, batch):
"""
Expand Down Expand Up @@ -451,12 +472,28 @@ def forward_eval(self, batch):
B, T, D = ref.shape

converter = self.action_registry.get(embodiment_id, ac_key)
pred_actions_orig = converter.from32(pred_actions)

pred = pred_actions_orig[:, :T, :D]
predictions[ac_key] = pred

unnorm_actions = self.norm_stats.unnormalize(predictions, embodiment_id)
if (
self.action_encoding
== PI05_CARTESIAN_ACTION_ENCODING_RAW_ROT_6D
):
pred_actions_orig = converter.from32_raw_rotation(
pred_actions,
stats=self._action_stats(embodiment_id, ac_key),
norm_mode=self.norm_stats.norm_mode,
unnormalize_non_rotation=True,
)
unnorm_actions = {ac_key: pred_actions_orig[:, :T, :D]}
elif self.action_encoding == PI05_CARTESIAN_ACTION_ENCODING_LEGACY:
pred_actions_orig = converter.from32(pred_actions)
pred = pred_actions_orig[:, :T, :D]
predictions[ac_key] = pred
unnorm_actions = self.norm_stats.unnormalize(
predictions, embodiment_id
)
else:
raise ValueError(
f"Unsupported PI0.5 action_encoding: {self.action_encoding!r}"
)
for key in unnorm_actions:
unnorm_preds[f"{embodiment_name}_{key}"] = unnorm_actions[key]

Expand Down Expand Up @@ -531,7 +568,20 @@ def _robomimic_to_pi_data(

emb_id = get_embodiment_id(embodiment) # embodiment is a name string
converter = self.action_registry.get(emb_id, ac_key)
action32 = converter.to32(action)
if self.action_encoding == PI05_CARTESIAN_ACTION_ENCODING_RAW_ROT_6D:
raw_action = self._unnormalize_action(action, emb_id, ac_key)
action32 = converter.to32_raw_rotation(
raw_action,
normalized_actions=action,
stats=self._action_stats(emb_id, ac_key),
norm_mode=self.norm_stats.norm_mode,
)
elif self.action_encoding == PI05_CARTESIAN_ACTION_ENCODING_LEGACY:
action32 = converter.to32(action)
else:
raise ValueError(
f"Unsupported PI0.5 action_encoding: {self.action_encoding!r}"
)

# OpenPI expects a fixed camera tuple. Human datasets only provide
# `base_0_rgb`, so duplicate that view into the missing wrist slots and
Expand Down
1 change: 1 addition & 0 deletions egomimic/hydra_configs/model/pi0.5_bc_eva.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ robomimic_model:
ac_keys:
eva_bimanual: "actions_cartesian"
domains: ["eva_bimanual"]
action_encoding: "cartesian_ypr_raw_rot6d"

action_converters:
rules:
Expand Down
214 changes: 207 additions & 7 deletions egomimic/utils/action_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
from typing import Dict, Tuple
from typing import Any, Dict, Tuple

import torch

PI05_CARTESIAN_ACTION_ENCODING_RAW_ROT_6D = "cartesian_ypr_raw_rot6d"
PI05_CARTESIAN_ACTION_ENCODING_LEGACY = "legacy_normalized_ypr_rot6d"

# Bimanual robot Cartesian layout: [x, y, z, yaw, pitch, roll, gripper] x 2.
ROBOT_BIMANUAL_CARTESIAN_ROT_DIMS = (3, 4, 5, 10, 11, 12)
ROBOT_BIMANUAL_CARTESIAN_NON_ROT_DIMS = (0, 1, 2, 6, 7, 8, 9, 13)


# ---------- registry that stores *objects* ----------
class ConverterRegistry:
Expand Down Expand Up @@ -43,6 +50,77 @@ def _pad32(x: torch.Tensor) -> torch.Tensor:
return x[..., :32]


def _stat_tensor(stats: dict[str, Any], key: str, ref: torch.Tensor) -> torch.Tensor:
value = torch.as_tensor(stats[key], device=ref.device, dtype=torch.float32)
return value.to(dtype=ref.dtype if ref.is_floating_point() else torch.float32)


def _apply_norm_one(
tensor: torch.Tensor,
stats: dict[str, Any],
norm_mode: str,
) -> torch.Tensor:
if norm_mode == "zscore":
mean = _stat_tensor(stats, "mean", tensor)
std = _stat_tensor(stats, "std", tensor)
return (tensor - mean) / (std + 1e-6)
if norm_mode == "minmax":
mn = _stat_tensor(stats, "min", tensor)
mx = _stat_tensor(stats, "max", tensor)
return 2.0 * ((tensor - mn) / (mx - mn + 1e-6)) - 1.0
if norm_mode == "quantile":
q1 = _stat_tensor(stats, "quantile_1", tensor)
q99 = _stat_tensor(stats, "quantile_99", tensor)
return 2.0 * ((tensor - q1) / (q99 - q1 + 1e-6)) - 1.0
raise ValueError(f"Invalid normalization mode: {norm_mode}")


def _apply_unnorm_one(
tensor: torch.Tensor,
stats: dict[str, Any],
norm_mode: str,
) -> torch.Tensor:
if norm_mode == "zscore":
mean = _stat_tensor(stats, "mean", tensor)
std = _stat_tensor(stats, "std", tensor)
return tensor * (std + 1e-6) + mean
if norm_mode == "minmax":
mn = _stat_tensor(stats, "min", tensor)
mx = _stat_tensor(stats, "max", tensor)
return (tensor + 1) * 0.5 * (mx - mn + 1e-6) + mn
if norm_mode == "quantile":
q1 = _stat_tensor(stats, "quantile_1", tensor)
q99 = _stat_tensor(stats, "quantile_99", tensor)
return (tensor + 1) * 0.5 * (q99 - q1 + 1e-6) + q1
raise ValueError(f"Invalid normalization mode: {norm_mode}")


def _normalize_robot_bimanual_non_rot(
raw_actions: torch.Tensor,
stats: dict[str, Any],
norm_mode: str,
) -> torch.Tensor:
normalized = raw_actions.clone()
all_dims = _apply_norm_one(raw_actions, stats, norm_mode)
normalized[..., ROBOT_BIMANUAL_CARTESIAN_NON_ROT_DIMS] = all_dims[
..., ROBOT_BIMANUAL_CARTESIAN_NON_ROT_DIMS
]
return normalized


def _unnormalize_robot_bimanual_non_rot(
model_actions: torch.Tensor,
stats: dict[str, Any],
norm_mode: str,
) -> torch.Tensor:
raw_actions = model_actions.clone()
all_dims = _apply_unnorm_one(model_actions, stats, norm_mode)
raw_actions[..., ROBOT_BIMANUAL_CARTESIAN_NON_ROT_DIMS] = all_dims[
..., ROBOT_BIMANUAL_CARTESIAN_NON_ROT_DIMS
]
return raw_actions


def _ypr_to_matrix(ypr: torch.Tensor, degrees: bool = False) -> torch.Tensor:
if degrees:
ypr = ypr * (torch.pi / 180.0)
Expand Down Expand Up @@ -137,6 +215,34 @@ def to32(self, actions: torch.Tensor) -> torch.Tensor:
def from32(self, actions32: torch.Tensor) -> torch.Tensor:
raise NotImplementedError

def to32_raw_rotation(
self,
raw_actions: torch.Tensor,
*,
normalized_actions: torch.Tensor | None = None,
stats: dict[str, Any] | None = None,
norm_mode: str = "quantile",
) -> torch.Tensor:
"""Pack actions with raw YPR rotations and normalized non-rotation dims."""
del normalized_actions, stats, norm_mode
raise NotImplementedError(
f"{type(self).__name__} does not support raw-rotation action encoding"
)

def from32_raw_rotation(
self,
actions32: torch.Tensor,
*,
stats: dict[str, Any] | None = None,
norm_mode: str = "quantile",
unnormalize_non_rotation: bool = False,
) -> torch.Tensor:
"""Decode actions whose 6D rotation columns represent raw YPR rotations."""
del stats, norm_mode, unnormalize_non_rotation
raise NotImplementedError(
f"{type(self).__name__} does not support raw-rotation action decoding"
)


# ============================================================
# ROBOT CONVERTERS
Expand Down Expand Up @@ -210,7 +316,7 @@ class RobotBimanualCartesianEuler(BaseActionConverter):
32-pack: left block 0..9, right block 10..19
"""

def to32(self, actions: torch.Tensor) -> torch.Tensor:
def to20(self, actions: torch.Tensor) -> torch.Tensor:
actions = _ensure_bsd(actions)
if actions.shape[-1] != 14:
raise ValueError(f"RobotBimanual: expected 14-dim, got {actions.shape[-1]}")
Expand All @@ -228,12 +334,19 @@ def to32(self, actions: torch.Tensor) -> torch.Tensor:
R_c1, R_c2 = R_R[..., 0], R_R[..., 1]
right_block = torch.cat([R_xyz, R_c1, R_c2, R_g], dim=-1) # (B,S,10)

return _pad32(torch.cat([left_block, right_block], dim=-1)) # (B,S,20+) -> 32
return torch.cat([left_block, right_block], dim=-1) # (B,S,20)

def from32(self, actions32: torch.Tensor) -> torch.Tensor:
actions32 = _ensure_bsd(actions32)
Lb = actions32[..., 0:10]
Rb = actions32[..., 10:20]
def to32(self, actions: torch.Tensor) -> torch.Tensor:
return _pad32(self.to20(actions))

def from20(self, actions20: torch.Tensor) -> torch.Tensor:
actions20 = _ensure_bsd(actions20)
if actions20.shape[-1] < 20:
raise ValueError(
f"RobotBimanual: expected at least 20 dims, got {actions20.shape[-1]}"
)
Lb = actions20[..., 0:10]
Rb = actions20[..., 10:20]

# left
L_xyz, L_c1, L_c2, L_g = Lb[..., 0:3], Lb[..., 3:6], Lb[..., 6:9], Lb[..., 9:10]
Expand All @@ -249,6 +362,93 @@ def from32(self, actions32: torch.Tensor) -> torch.Tensor:
R7 = torch.cat([R_xyz, R_ypr, R_g], dim=-1)
return torch.cat([L7, R7], dim=-1) # (B,S,14)

def from32(self, actions32: torch.Tensor) -> torch.Tensor:
return self.from20(actions32)

def to20_raw_rotation(
self,
raw_actions: torch.Tensor,
*,
normalized_actions: torch.Tensor | None = None,
stats: dict[str, Any] | None = None,
norm_mode: str = "quantile",
) -> torch.Tensor:
raw_actions = _ensure_bsd(raw_actions)
if raw_actions.shape[-1] != 14:
raise ValueError(
f"RobotBimanual: expected 14-dim, got {raw_actions.shape[-1]}"
)
if normalized_actions is None:
if stats is None:
raise ValueError("stats are required when normalized_actions is omitted")
model_actions = _normalize_robot_bimanual_non_rot(
raw_actions, stats, norm_mode
)
else:
normalized_actions = _ensure_bsd(normalized_actions).to(raw_actions.device)
if normalized_actions.shape != raw_actions.shape:
raise ValueError(
"normalized_actions must match raw_actions shape; got "
f"{tuple(normalized_actions.shape)} vs {tuple(raw_actions.shape)}"
)
model_actions = raw_actions.clone()
model_actions[..., ROBOT_BIMANUAL_CARTESIAN_NON_ROT_DIMS] = (
normalized_actions[..., ROBOT_BIMANUAL_CARTESIAN_NON_ROT_DIMS]
)
return self.to20(model_actions)

def from20_raw_rotation(
self,
actions20: torch.Tensor,
*,
stats: dict[str, Any] | None = None,
norm_mode: str = "quantile",
unnormalize_non_rotation: bool = False,
) -> torch.Tensor:
model_actions = self.from20(actions20)
if not unnormalize_non_rotation:
return model_actions
if model_actions.shape[-1] != 14:
raise ValueError(
"RobotBimanual raw-rotation decoding expected 14D Cartesian actions; "
f"got {model_actions.shape[-1]} dims"
)
if stats is None:
raise ValueError("stats are required to unnormalize non-rotation dims")
return _unnormalize_robot_bimanual_non_rot(model_actions, stats, norm_mode)

def to32_raw_rotation(
self,
raw_actions: torch.Tensor,
*,
normalized_actions: torch.Tensor | None = None,
stats: dict[str, Any] | None = None,
norm_mode: str = "quantile",
) -> torch.Tensor:
return _pad32(
self.to20_raw_rotation(
raw_actions,
normalized_actions=normalized_actions,
stats=stats,
norm_mode=norm_mode,
)
)

def from32_raw_rotation(
self,
actions32: torch.Tensor,
*,
stats: dict[str, Any] | None = None,
norm_mode: str = "quantile",
unnormalize_non_rotation: bool = False,
) -> torch.Tensor:
return self.from20_raw_rotation(
actions32,
stats=stats,
norm_mode=norm_mode,
unnormalize_non_rotation=unnormalize_non_rotation,
)


# ============================================================
# HUMAN CONVERTERS
Expand Down
Loading
Loading