Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 68 additions & 24 deletions egomimic/robot/rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from egomimic.pl_utils.pl_model import ModelWrapper
from egomimic.pl_utils.pl_data_utils import build_tokenized_collate
from egomimic.rldb.embodiment.embodiment import get_embodiment
from egomimic.utils.hydra_utils import find_run_snapshot_path, load_run_snapshot
from egomimic.rldb.embodiment.eva import Eva
from egomimic.robot.eva.eva_kinematics import EvaMinkKinematicsSolver
from egomimic.utils.egomimicUtils import (
Expand Down Expand Up @@ -275,46 +276,91 @@ def __init__(
else:
with open(annotation_path, "r") as f:
self.annotation = f.read().strip()
self.collate_fn = build_tokenized_collate(
max_length=128,
model_name="google/paligemma-3b-mix-224",
sampling_mode="first",
annotation_key="annotations",
default_prompt=self.annotation,
)
self.collate_fn = self._build_collate_from_checkpoint_cfg(self.annotation)

LOCAL_WEIGHT_PATH = "/home/robot/robot_ws/egomimic/algo/pi_checkpoints/pi05_base_pytorch"

@classmethod
def _patch_checkpoint_paths(cls, ckpt_path):
"""Rewrite pytorch_weight_path in the checkpoint's saved config
to point to the local base model weights."""
def _load_checkpoint_cfg(cls, ckpt_path):
"""Load the saved hydra config tree from a checkpoint as a plain dict."""
import torch as _torch
from omegaconf import OmegaConf, DictConfig
ckpt = _torch.load(ckpt_path, map_location="cpu", weights_only=False)
ht = ckpt.get("hyper_parameters", {}).get("config_tree")
if ht is None:
return ckpt_path
return None, ckpt
if isinstance(ht, DictConfig):
cfg = OmegaConf.to_container(ht, resolve=True)
else:
cfg = ht
return cfg, ckpt

@classmethod
def _patch_checkpoint_paths(cls, ckpt_path):
"""Rewrite pytorch_weight_path in the checkpoint's saved config
to point to the local base model weights. Returns (patched_path, cfg)."""
import torch as _torch
from omegaconf import OmegaConf
cfg, ckpt = cls._load_checkpoint_cfg(ckpt_path)
if cfg is None:
return ckpt_path, None
# Navigate to pytorch_weight_path in the config
robomimic = cfg.get("model", {}).get("robomimic_model", {})
config = robomimic.get("config", {})
old_path = config.get("pytorch_weight_path")
if old_path is None or old_path == cls.LOCAL_WEIGHT_PATH:
return ckpt_path
return ckpt_path, cfg
print(f"[rollout] Patching pytorch_weight_path: {old_path} -> {cls.LOCAL_WEIGHT_PATH}")
config["pytorch_weight_path"] = cls.LOCAL_WEIGHT_PATH
ckpt["hyper_parameters"]["config_tree"] = OmegaConf.create(cfg)
patched_path = ckpt_path + ".patched"
_torch.save(ckpt, patched_path)
print(f"[rollout] Patched checkpoint saved to {patched_path}")
return patched_path
return patched_path, cfg

def _build_collate_from_checkpoint_cfg(self, default_prompt):
"""Build a tokenized collate_fn using the build_tokenized_collate flags
saved in the checkpoint's hydra config (under the top-level ``data:``
block). This mirrors the pi prompt formatting the model was trained on
(Task / Embodiment / Control mode / State blocks)."""
data_cfg = (self._ckpt_cfg or {}).get("data", {}) or {}
# TODO: remove (debug — verify build_tokenized_collate flags from ckpt)
print(
f"[rollout][debug] ckpt_cfg top-level keys: "
f"{list((self._ckpt_cfg or {}).keys())}"
)
print(
f"[rollout][debug] data_cfg flags: "
f"proprio={data_cfg.get('proprio')}, "
f"embodiment_label={data_cfg.get('embodiment_label')}, "
f"control_mode={data_cfg.get('control_mode')}, "
f"state_num_bins={data_cfg.get('state_num_bins')}, "
f"model_name={data_cfg.get('model_name')}"
)
return build_tokenized_collate(
max_length=128,
model_name=data_cfg.get("model_name", "google/paligemma-3b-mix-224"),
sampling_mode="first",
annotation_key="annotations",
default_prompt=default_prompt,
proprio_keys=data_cfg.get("proprio_keys"),
state_num_bins=data_cfg.get("state_num_bins", 256),
proprio=bool(data_cfg.get("proprio", False)),
embodiment_label=bool(data_cfg.get("embodiment_label", False)),
control_mode=data_cfg.get("control_mode"),
)

def _load_policy(self):
patched_path = self._patch_checkpoint_paths(self.policy_path)
patched_path, _ = self._patch_checkpoint_paths(self.policy_path)
# The .ckpt only stores the model subtree (see trainHydra._build_model_config_tree),
# so load the full hydra run-snapshot (with the data: block) from disk.
snapshot_path = find_run_snapshot_path(self.policy_path)
if snapshot_path is None:
print(f"[rollout] WARNING: no .hydra/config.yaml found near {self.policy_path}")
self._ckpt_cfg = None
else:
print(f"[rollout] Loaded hydra config from {snapshot_path}")
self._ckpt_cfg = load_run_snapshot(self.policy_path)
policy = ModelWrapper.load_from_checkpoint(
patched_path, weights_only=False, map_location="cpu"
)
Expand Down Expand Up @@ -372,6 +418,7 @@ def rollout_step(self, i, obs):
for transform in self.transform_list:
transform_list_batch = transform.transform(transform_list_batch)
transform_list_batch = self.collate_fn([transform_list_batch])
print(f"[rollout][debug] sampled_prompt: {transform_list_batch.get('sampled_prompt')}") # TODO: remove
if self.arm == "both":
embodiment_name = "eva_bimanual"
elif self.arm == "right":
Expand Down Expand Up @@ -500,14 +547,17 @@ def process_obs_for_transform_list(self, obs):
left_cmd_ee_pose = torch.from_numpy(left_xyzwxyz).view(1, 7).repeat(45, 1)
data["left.cmd_ee_pose"] = left_cmd_ee_pose

# `embodiment` is consumed by build_tokenized_collate's _embodiment_name,
# which calls int() on it to look up the embodiment name; it must be
# the integer id, not the string. The string lives in metadata.robot_name.
if self.arm == "both":
data["embodiment"] = "eva_bimanual"
data["embodiment"] = self.embodiment_id
data["metadata.robot_name"] = "eva_bimanual"
elif self.arm == "right":
data["embodiment"] = "eva_right_arm"
data["embodiment"] = self.embodiment_id
data["metadata.robot_name"] = "eva_right_arm"
elif self.arm == "left":
data["embodiment"] = "eva_left_arm"
data["embodiment"] = self.embodiment_id
data["metadata.robot_name"] = "eva_left_arm"

if self.annotation is not None:
Expand All @@ -531,13 +581,7 @@ def load_annotation(self, annotation_path):
with open(annotation_path, "r") as f:
self.annotation = f.read().strip()
if self.collate_fn is default_collate:
self.collate_fn = build_tokenized_collate(
max_length=128,
model_name="google/paligemma-3b-mix-224",
sampling_mode="first",
annotation_key="annotations",
default_prompt=self.annotation,
)
self.collate_fn = self._build_collate_from_checkpoint_cfg(self.annotation)
print(f"[rollout] Loaded new annotation from {annotation_path}: '{self.annotation}'")
return True

Expand Down
48 changes: 47 additions & 1 deletion egomimic/utils/hydra_utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import os
from pathlib import Path
from typing import Any

from hydra import compose, initialize_config_dir
from hydra.core.global_hydra import GlobalHydra
from omegaconf import DictConfig
from omegaconf import DictConfig, OmegaConf

HYDRA_CONFIG_DIR = os.path.join(
os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "hydra_configs"
Expand Down Expand Up @@ -74,6 +75,51 @@ def find_hydra_config_dir(config_path: str | Path) -> str:
)


def find_run_snapshot_path(start_path: str | Path) -> str | None:
"""Walk up from ``start_path`` looking for a sibling ``.hydra/config.yaml``.

Hydra writes the fully-composed config snapshot for a training run into
``<run_dir>/.hydra/config.yaml``. Given a path inside that run (e.g. a
checkpoint at ``<run_dir>/checkpoints/foo.ckpt``), this returns the
snapshot path, or ``None`` if no ancestor has one.
"""
d = os.path.dirname(os.path.abspath(str(start_path)))
seen: set[str] = set()
while d and d not in seen:
seen.add(d)
candidate = os.path.join(d, ".hydra", "config.yaml")
if os.path.isfile(candidate):
return candidate
parent = os.path.dirname(d)
if parent == d:
break
d = parent
return None


def load_run_snapshot(
start_path: str | Path,
*,
resolve: bool = False,
) -> dict[str, Any] | None:
"""Load the fully-composed hydra config snapshot for a training run.

Walks up from ``start_path`` (typically a checkpoint path) to find the
nearest ``<run_dir>/.hydra/config.yaml`` and loads it as a plain dict.

``resolve`` defaults to ``False`` because run snapshots often contain
interpolations like ``${data.dataset.data_schematic}`` that reference
runtime-only nodes and would crash on resolution. Set ``resolve=True``
only when the caller knows all interpolations are self-contained.

Returns ``None`` if no snapshot is found in any ancestor.
"""
snapshot = find_run_snapshot_path(start_path)
if snapshot is None:
return None
return OmegaConf.to_container(OmegaConf.load(snapshot), resolve=resolve)


def load_config_from_path(
config_path: str | Path,
overrides: list[str] | None = None,
Expand Down
Loading