Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 55 additions & 13 deletions egomimic/algo/pi.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,12 @@ def __init__(

self.num_steps = getattr(self.config, "num_sampling_steps", 10)
self.is_6dof = kwargs.get("is_6dof", True)
# Number of stochastic action-chunk samples drawn per eval batch for the
# reverse-KL metric. >1 enables it (each sample is a full flow-matching
# rollout, so this multiplies eval sampling cost); 1 disables it.
# Disabled by default — each extra sample adds a full sampling pass per
# eval batch, which dominates validation time on large flow models.
self.rkl_samples = kwargs.get("reverse_kl_samples", 1)

self.action_converters = action_converters

Expand Down Expand Up @@ -461,22 +467,58 @@ def forward_eval(self, batch):
num_steps=self.num_steps,
)

predictions = OrderedDict()
ref = _batch[ac_key]
B, T, D = ref.shape

converter = self.action_registry.get(embodiment_id, ac_key)
pred_actions_orig = converter.from32(pred_actions)

pred = pred_actions_orig[:, :T, :D]
predictions[ac_key] = pred

unnorm_actions = self.norm_stats.unnormalize(predictions, embodiment_id)
for key in unnorm_actions:
unnorm_preds[f"{embodiment_name}_{key}"] = unnorm_actions[key]
pred = self._postprocess_sampled_actions(
pred_actions, _batch, embodiment_id, ac_key
)
unnorm_preds[f"{embodiment_name}_{ac_key}"] = pred

return unnorm_preds

def _postprocess_sampled_actions(self, pred_actions, _batch, embodiment_id, ac_key):
"""Turn raw ``sample_actions`` output into an unnormalized action chunk:
``from32`` -> slice to ``(B, T, D)`` -> ``norm_stats.unnormalize``. Shared
by ``forward_eval`` and ``sample_action_chunks`` so the reverse-KL samples
go through the identical pipeline as the headline prediction."""
ref = _batch[ac_key]
B, T, D = ref.shape
converter = self.action_registry.get(embodiment_id, ac_key)
pred_actions_orig = converter.from32(pred_actions)
pred = pred_actions_orig[:, :T, :D]
unnorm = self.norm_stats.unnormalize(OrderedDict({ac_key: pred}), embodiment_id)
return unnorm[ac_key]

@torch.no_grad()
def sample_action_chunks(self, _batch, embodiment_id, M):
"""Draw ``M`` independent stochastic action chunks for one embodiment's
batch and return them stacked as ``(M, B, T, D)``, unnormalized, on
``self.device``. Each ``sample_actions`` call with ``noise=None`` draws
fresh Gaussian noise, so the ``M`` chunks are independent policy samples.

``_batch`` must be the normalized batch element (same obs space as
``forward_eval``); do not pass an unnormalized batch.
"""
proprio_keys = self.proprio_keys[embodiment_id]
lang_keys = self.lang_keys[embodiment_id]
ac_key = self.ac_keys[embodiment_id]
camera_keys = self.camera_keys.get(embodiment_id, self.pi_cam_keys)
embodiment_name = get_embodiment(embodiment_id).lower()
processed_obs, _ = self._robomimic_to_pi_data(
_batch, camera_keys, proprio_keys, lang_keys, ac_key, embodiment_name
)
samples = []
for _ in range(int(M)):
pred_actions = self.nets["policy"].sample_actions(
device=self.device,
observation=processed_obs,
noise=None,
num_steps=self.num_steps,
)
pred = self._postprocess_sampled_actions(
pred_actions, _batch, embodiment_id, ac_key
)
samples.append(pred.unsqueeze(0))
return torch.cat(samples, dim=0)

@override
def compute_losses(self, predictions, batch):
"""
Expand Down
6 changes: 4 additions & 2 deletions egomimic/eval/eval_act.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ class ACTEvalVideo(EvalVideo):
and draws predicted/ground-truth trajectories on the visualization image.
"""

def compute_metrics_and_viz(self, batch):
def compute_metrics_and_viz(self, batch, do_viz=True):
algo = self.model
preds = algo.forward_eval(batch)
# ground truth normalized; unnormalize for direct comparison.
Expand All @@ -29,7 +29,9 @@ def compute_metrics_and_viz(self, batch):
preds[ac_key][:, -1].cpu(), batch[ac_key][:, -1].cpu()
)

ims = {algo.embodiment_id: self._visualize_preds(preds, batch)}
ims = {}
if do_viz:
ims = {algo.embodiment_id: self._visualize_preds(preds, batch)}
return metrics, ims

def _visualize_preds(self, predictions, batch):
Expand Down
7 changes: 4 additions & 3 deletions egomimic/eval/eval_hpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class HPTEvalVideo(EvalVideo):
and the viz video.
"""

def compute_metrics_and_viz(self, batch):
def compute_metrics_and_viz(self, batch, do_viz=True):
algo = self.model
preds = algo.forward_eval(batch)

Expand Down Expand Up @@ -191,8 +191,9 @@ def compute_metrics_and_viz(self, batch):
preds_for_viz = dict(preds)
preds_for_viz[main_pred_key] = pred_batch_viz[ac_key]

ims = self._visualize_preds(preds_for_viz, gt_batch_viz)
images_dict[embodiment_id] = ims
if do_viz:
ims = self._visualize_preds(preds_for_viz, gt_batch_viz)
images_dict[embodiment_id] = ims

if total_loss is not None and n_loss_embodiments > 0:
metrics["Valid/action_loss"] = total_loss / n_loss_embodiments
Expand Down
4 changes: 2 additions & 2 deletions egomimic/eval/eval_latent.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ def on_validation_start(self):
exist_ok=True,
)

def compute_metrics_and_viz(self, batch):
def compute_metrics_and_viz(self, batch, do_viz=True):
algo = self.model
metrics = {}
images_dict = {}
Expand Down Expand Up @@ -282,7 +282,7 @@ def compute_metrics_and_viz(self, batch):
unnorm_batch[ac_key][:, -1].cpu(),
)

if self.viz_func is not None:
if do_viz and self.viz_func is not None:
images_dict[embodiment_id] = self._visualize_preds(
unnorm_preds, unnorm_batch
)
Expand Down
89 changes: 68 additions & 21 deletions egomimic/eval/eval_pi.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,28 @@

from egomimic.eval.eval_video import EvalVideo
from egomimic.rldb.embodiment.embodiment import Embodiment, get_embodiment
from egomimic.utils.egomimicUtils import (
frechet_gaussian_over_time,
reverse_kl_from_samples,
)
from egomimic.utils.pose_utils import bimanual_cartesian_layout


class PIEvalVideo(EvalVideo):
"""
Eval class for PI models. Per embodiment, computes:
- val loss (flow-matching loss, same as training; also aggregated as ``Valid/action_loss``)
- paired/final MSE in the model's native wrist frame
- paired/final MSE in cam frame, when a ``transform_lists`` entry is configured
- paired MSE in the model's native wrist frame (full vector + xyz/ypr split)
- final-timestep MSE in the native frame (end-of-chunk, longest horizon)
- native-frame Fréchet Gaussian (avg/min/max) over the single prediction
- native-frame reverse KL from ``rkl_samples`` stochastic samples, when
the algo's ``reverse_kl_samples > 1``
- paired + final MSE in cam frame, when a ``transform_lists`` entry is configured
The revert transform is applied once and reused for both the cam-frame MSE
and the viz video.
"""

def compute_metrics_and_viz(self, batch):
def compute_metrics_and_viz(self, batch, do_viz=True):
algo = self.model
preds = algo.forward_eval(batch)

Expand Down Expand Up @@ -73,19 +81,63 @@ def _split_mse(pred_t, gt_t):
metrics[f"Valid/{pred_key}_paired_mse_avg"] = mse(
preds[pred_key].cpu(), _batch[ac_key].cpu()
)
# Last-timestep-only MSE: the end of the chunk is the
# longest-horizon (hardest) prediction, so this reads as a
# worst-end signal vs the chunk-wide ``paired`` average.
metrics[f"Valid/{pred_key}_final_mse_avg"] = mse(
preds[pred_key][:, -1].cpu(), _batch[ac_key][:, -1].cpu()
preds[pred_key][:, -1].cpu().contiguous(),
_batch[ac_key][:, -1].cpu().contiguous(),
)
xyz_p, ypr_p = _split_mse(preds[pred_key], _batch[ac_key])
if xyz_p is not None:
metrics[f"Valid/{pred_key}_xyz_paired_mse_avg"] = xyz_p
metrics[f"Valid/{pred_key}_ypr_paired_mse_avg"] = ypr_p
xyz_f, ypr_f = _split_mse(
preds[pred_key][:, -1:], _batch[ac_key][:, -1:]
)
if xyz_f is not None:
metrics[f"Valid/{pred_key}_xyz_final_mse_avg"] = xyz_f
metrics[f"Valid/{pred_key}_ypr_final_mse_avg"] = ypr_f

# Distributional metrics (native frame only). Fréchet compares
# the time-distribution shape of the single prediction; reverse
# KL needs M independent stochastic samples and is gated on the
# algo's ``rkl_samples`` (8x extra sampling per batch).
fd = frechet_gaussian_over_time(preds[pred_key], _batch[ac_key])
metrics[f"Valid/{pred_key}_frechet_gauss_avg"] = fd.mean().item()
metrics[f"Valid/{pred_key}_frechet_gauss_min"] = fd.min().item()
metrics[f"Valid/{pred_key}_frechet_gauss_max"] = fd.max().item()

if algo.rkl_samples and algo.rkl_samples > 1:
M = int(algo.rkl_samples)
gt_tensor = _batch[ac_key].to(algo.device)
# Feed the ORIGINAL normalized batch element, not the loop's
# unnormalized ``_batch`` — ``norm_stats.unnormalize`` also
# denormalizes proprio obs keys, so sampling must run on the
# normalized obs (same as ``forward_eval``).
samples = algo.sample_action_chunks(
batch[embodiment_id], embodiment_id, M
)
rkl = reverse_kl_from_samples(samples, gt_tensor)
metrics[f"Valid/{pred_key}_reverse_kl_M{M}"] = rkl.item()

# Best-of-K coverage from the SAME M samples (no extra
# sampling): per-sample paired MSE to GT, reduced over the
# chunk. ``bestof`` = does the policy produce a good action
# in M tries (multimodal coverage); ``mean`` = avg sample
# quality; ``worstof`` = how bad the worst draw is;
# ``diversity`` = mean per-element std across samples.
per_sample_mse = (
((samples - gt_tensor.unsqueeze(0)) ** 2)
.flatten(start_dim=2)
.mean(dim=2)
) # (M, B)
metrics[f"Valid/{pred_key}_bestof{M}_paired_mse"] = (
per_sample_mse.min(dim=0).values.mean().item()
)
metrics[f"Valid/{pred_key}_mean{M}_paired_mse"] = (
per_sample_mse.mean().item()
)
metrics[f"Valid/{pred_key}_worstof{M}_paired_mse"] = (
per_sample_mse.max(dim=0).values.mean().item()
)
metrics[f"Valid/{pred_key}_sample_diversity_M{M}"] = (
samples.std(dim=0).mean().item()
)

transform_list = self.transform_lists.get(embodiment_name)
gt_batch_viz = _batch
Expand All @@ -100,9 +152,9 @@ def _split_mse(pred_t, gt_t):
gt_batch_viz = {**_batch, **gt_t}
pred_batch_viz = {**_batch, **pred_t}

# ``.contiguous()`` because ``apply_transform`` returns CPU tensors,
# so ``.cpu()`` here is a no-op and ``[:, -1]`` leaves a non-contiguous
# view that torchmetrics' MSE doesn't accept.
# ``.contiguous()`` because ``apply_transform`` returns CPU
# tensors, so ``.cpu()`` here is a no-op and the merged views can
# be non-contiguous, which torchmetrics' MSE doesn't accept.
metrics[f"Valid/{pred_key}_cam_paired_mse_avg"] = mse(
pred_batch_viz[ac_key].cpu().contiguous(),
gt_batch_viz[ac_key].cpu().contiguous(),
Expand All @@ -117,18 +169,13 @@ def _split_mse(pred_t, gt_t):
if xyz_cp is not None:
metrics[f"Valid/{pred_key}_cam_xyz_paired_mse_avg"] = xyz_cp
metrics[f"Valid/{pred_key}_cam_ypr_paired_mse_avg"] = ypr_cp
xyz_cf, ypr_cf = _split_mse(
pred_batch_viz[ac_key][:, -1:], gt_batch_viz[ac_key][:, -1:]
)
if xyz_cf is not None:
metrics[f"Valid/{pred_key}_cam_xyz_final_mse_avg"] = xyz_cf
metrics[f"Valid/{pred_key}_cam_ypr_final_mse_avg"] = ypr_cf

preds_for_viz = dict(preds)
preds_for_viz[pred_key] = pred_batch_viz[ac_key]

ims = self._visualize_preds(preds_for_viz, gt_batch_viz)
images_dict[embodiment_id] = ims
if do_viz:
ims = self._visualize_preds(preds_for_viz, gt_batch_viz)
images_dict[embodiment_id] = ims

if total_loss is not None and n_loss_embodiments > 0:
metrics["Valid/action_loss"] = total_loss / n_loss_embodiments
Expand Down
23 changes: 19 additions & 4 deletions egomimic/eval/eval_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,16 @@ def __init__(
transform_lists: dict | None = None,
one_video_per_task: bool = False,
max_frames_per_task: int | None = 1000,
viz_every_n_epochs: int = 1,
):
super().__init__()
self.trainer = None
self.model = None
self.viz_func = viz_func
# Validation metrics log every validation; viz video rendering only
# happens on epochs divisible by this (viz is far more expensive than
# the metrics). <=0 disables viz entirely.
self.viz_every_n_epochs = viz_every_n_epochs
# Per-embodiment list[Transform] applied once during eval to project
# the model's wrist-frame actions back into cam (head) frame. Reused for
# both cam-frame MSE and the viz video so we don't transform twice.
Expand All @@ -55,22 +60,29 @@ def __init__(
def video_dir(self):
return os.path.join(self.root_dir(), "videos")

def _should_viz(self) -> bool:
if not self.viz_every_n_epochs or self.viz_every_n_epochs <= 0:
return False
return (self.trainer.current_epoch % self.viz_every_n_epochs) == 0

@abstractmethod
def compute_metrics_and_viz(self, batch):
def compute_metrics_and_viz(self, batch, do_viz=True):
"""
Run the model's eval forward and compute metrics and visualization frames.

Args:
batch (dict): processed batch produced by the algo's
`process_batch_for_training`.
do_viz (bool): when False, skip producing viz frames (the expensive
part); metrics are still computed and returned.
Returns:
metrics (dict[str, torch.Tensor | float])
images_dict (dict[embodiment_id, np.ndarray (B, H, W, 3)])
"""
raise NotImplementedError

def on_validation_start(self):
if self.trainer.is_global_zero:
if self.trainer.is_global_zero and self._should_viz():
os.makedirs(
os.path.join(self.video_dir(), f"epoch_{self.trainer.current_epoch}"),
exist_ok=True,
Expand All @@ -82,6 +94,8 @@ def _sanitize_task(task: str) -> str:
return re.sub(r"[^\w.-]+", "_", str(task)).strip("_") or "unknown"

def on_validation_end(self):
if not self._should_viz():
return
for key, buffer in self.val_image_buffer.items():
if self.one_video_per_task:
embodiment_id, task = key
Expand Down Expand Up @@ -113,15 +127,16 @@ def on_validation_end(self):
self.val_image_buffer[key] = []

def on_validation_step(self, batch, batch_idx, dataloader_idx=0):
metrics, images_dict = self.compute_metrics_and_viz(batch)
do_viz = self._should_viz()
metrics, images_dict = self.compute_metrics_and_viz(batch, do_viz=do_viz)

device = self.trainer.lightning_module.device
metrics = {
k: (v.to(device) if torch.is_tensor(v) else torch.tensor(v, device=device))
for k, v in metrics.items()
}

## images is now a dict
## images is now a dict (empty when do_viz is False)
for embodiment_id, images in images_dict.items():
os.makedirs(
os.path.join(
Expand Down
Loading