diff --git a/egomimic/algo/pi.py b/egomimic/algo/pi.py index 0415ee478..2ebb3d59e 100644 --- a/egomimic/algo/pi.py +++ b/egomimic/algo/pi.py @@ -137,6 +137,12 @@ def __init__( self.num_steps = getattr(self.config, "num_sampling_steps", 10) self.is_6dof = kwargs.get("is_6dof", True) + # Number of stochastic action-chunk samples drawn per eval batch for the + # reverse-KL metric. >1 enables it (each sample is a full flow-matching + # rollout, so this multiplies eval sampling cost); 1 disables it. + # Disabled by default — each extra sample adds a full sampling pass per + # eval batch, which dominates validation time on large flow models. + self.rkl_samples = kwargs.get("reverse_kl_samples", 1) self.action_converters = action_converters @@ -461,22 +467,58 @@ def forward_eval(self, batch): num_steps=self.num_steps, ) - predictions = OrderedDict() - ref = _batch[ac_key] - B, T, D = ref.shape - - converter = self.action_registry.get(embodiment_id, ac_key) - pred_actions_orig = converter.from32(pred_actions) - - pred = pred_actions_orig[:, :T, :D] - predictions[ac_key] = pred - - unnorm_actions = self.norm_stats.unnormalize(predictions, embodiment_id) - for key in unnorm_actions: - unnorm_preds[f"{embodiment_name}_{key}"] = unnorm_actions[key] + pred = self._postprocess_sampled_actions( + pred_actions, _batch, embodiment_id, ac_key + ) + unnorm_preds[f"{embodiment_name}_{ac_key}"] = pred return unnorm_preds + def _postprocess_sampled_actions(self, pred_actions, _batch, embodiment_id, ac_key): + """Turn raw ``sample_actions`` output into an unnormalized action chunk: + ``from32`` -> slice to ``(B, T, D)`` -> ``norm_stats.unnormalize``. 
Shared + by ``forward_eval`` and ``sample_action_chunks`` so the reverse-KL samples + go through the identical pipeline as the headline prediction.""" + ref = _batch[ac_key] + B, T, D = ref.shape + converter = self.action_registry.get(embodiment_id, ac_key) + pred_actions_orig = converter.from32(pred_actions) + pred = pred_actions_orig[:, :T, :D] + unnorm = self.norm_stats.unnormalize(OrderedDict({ac_key: pred}), embodiment_id) + return unnorm[ac_key] + + @torch.no_grad() + def sample_action_chunks(self, _batch, embodiment_id, M): + """Draw ``M`` independent stochastic action chunks for one embodiment's + batch and return them stacked as ``(M, B, T, D)``, unnormalized, on + ``self.device``. Each ``sample_actions`` call with ``noise=None`` draws + fresh Gaussian noise, so the ``M`` chunks are independent policy samples. + + ``_batch`` must be the normalized batch element (same obs space as + ``forward_eval``); do not pass an unnormalized batch. + """ + proprio_keys = self.proprio_keys[embodiment_id] + lang_keys = self.lang_keys[embodiment_id] + ac_key = self.ac_keys[embodiment_id] + camera_keys = self.camera_keys.get(embodiment_id, self.pi_cam_keys) + embodiment_name = get_embodiment(embodiment_id).lower() + processed_obs, _ = self._robomimic_to_pi_data( + _batch, camera_keys, proprio_keys, lang_keys, ac_key, embodiment_name + ) + samples = [] + for _ in range(int(M)): + pred_actions = self.nets["policy"].sample_actions( + device=self.device, + observation=processed_obs, + noise=None, + num_steps=self.num_steps, + ) + pred = self._postprocess_sampled_actions( + pred_actions, _batch, embodiment_id, ac_key + ) + samples.append(pred.unsqueeze(0)) + return torch.cat(samples, dim=0) + @override def compute_losses(self, predictions, batch): """ diff --git a/egomimic/eval/eval_act.py b/egomimic/eval/eval_act.py index 350a2ce85..41386f7a2 100644 --- a/egomimic/eval/eval_act.py +++ b/egomimic/eval/eval_act.py @@ -11,7 +11,7 @@ class ACTEvalVideo(EvalVideo): and draws 
predicted/ground-truth trajectories on the visualization image. """ - def compute_metrics_and_viz(self, batch): + def compute_metrics_and_viz(self, batch, do_viz=True): algo = self.model preds = algo.forward_eval(batch) # ground truth normalized; unnormalize for direct comparison. @@ -29,7 +29,9 @@ def compute_metrics_and_viz(self, batch): preds[ac_key][:, -1].cpu(), batch[ac_key][:, -1].cpu() ) - ims = {algo.embodiment_id: self._visualize_preds(preds, batch)} + ims = {} + if do_viz: + ims = {algo.embodiment_id: self._visualize_preds(preds, batch)} return metrics, ims def _visualize_preds(self, predictions, batch): diff --git a/egomimic/eval/eval_hpt.py b/egomimic/eval/eval_hpt.py index 87c9188cf..9f41b5c4e 100644 --- a/egomimic/eval/eval_hpt.py +++ b/egomimic/eval/eval_hpt.py @@ -23,7 +23,7 @@ class HPTEvalVideo(EvalVideo): and the viz video. """ - def compute_metrics_and_viz(self, batch): + def compute_metrics_and_viz(self, batch, do_viz=True): algo = self.model preds = algo.forward_eval(batch) @@ -191,8 +191,9 @@ def compute_metrics_and_viz(self, batch): preds_for_viz = dict(preds) preds_for_viz[main_pred_key] = pred_batch_viz[ac_key] - ims = self._visualize_preds(preds_for_viz, gt_batch_viz) - images_dict[embodiment_id] = ims + if do_viz: + ims = self._visualize_preds(preds_for_viz, gt_batch_viz) + images_dict[embodiment_id] = ims if total_loss is not None and n_loss_embodiments > 0: metrics["Valid/action_loss"] = total_loss / n_loss_embodiments diff --git a/egomimic/eval/eval_latent.py b/egomimic/eval/eval_latent.py index 2d663b3fc..f5e5a785a 100644 --- a/egomimic/eval/eval_latent.py +++ b/egomimic/eval/eval_latent.py @@ -221,7 +221,7 @@ def on_validation_start(self): exist_ok=True, ) - def compute_metrics_and_viz(self, batch): + def compute_metrics_and_viz(self, batch, do_viz=True): algo = self.model metrics = {} images_dict = {} @@ -282,7 +282,7 @@ def compute_metrics_and_viz(self, batch): unnorm_batch[ac_key][:, -1].cpu(), ) - if self.viz_func is not None: 
+ if do_viz and self.viz_func is not None: images_dict[embodiment_id] = self._visualize_preds( unnorm_preds, unnorm_batch ) diff --git a/egomimic/eval/eval_pi.py b/egomimic/eval/eval_pi.py index a508f7943..7245c94dd 100644 --- a/egomimic/eval/eval_pi.py +++ b/egomimic/eval/eval_pi.py @@ -5,6 +5,10 @@ from egomimic.eval.eval_video import EvalVideo from egomimic.rldb.embodiment.embodiment import Embodiment, get_embodiment +from egomimic.utils.egomimicUtils import ( + frechet_gaussian_over_time, + reverse_kl_from_samples, +) from egomimic.utils.pose_utils import bimanual_cartesian_layout @@ -12,13 +16,17 @@ class PIEvalVideo(EvalVideo): """ Eval class for PI models. Per embodiment, computes: - val loss (flow-matching loss, same as training; also aggregated as ``Valid/action_loss``) - - paired/final MSE in the model's native wrist frame - - paired/final MSE in cam frame, when a ``transform_lists`` entry is configured + - paired MSE in the model's native wrist frame (full vector + xyz/ypr split) + - final-timestep MSE in the native frame (end-of-chunk, longest horizon) + - native-frame Fréchet Gaussian (avg/min/max) over the single prediction + - native-frame reverse KL from ``rkl_samples`` stochastic samples, when + the algo's ``reverse_kl_samples > 1`` + - paired + final MSE in cam frame, when a ``transform_lists`` entry is configured The revert transform is applied once and reused for both the cam-frame MSE and the viz video. """ - def compute_metrics_and_viz(self, batch): + def compute_metrics_and_viz(self, batch, do_viz=True): algo = self.model preds = algo.forward_eval(batch) @@ -73,19 +81,63 @@ def _split_mse(pred_t, gt_t): metrics[f"Valid/{pred_key}_paired_mse_avg"] = mse( preds[pred_key].cpu(), _batch[ac_key].cpu() ) + # Last-timestep-only MSE: the end of the chunk is the + # longest-horizon (hardest) prediction, so this reads as a + # worst-end signal vs the chunk-wide ``paired`` average. 
metrics[f"Valid/{pred_key}_final_mse_avg"] = mse( - preds[pred_key][:, -1].cpu(), _batch[ac_key][:, -1].cpu() + preds[pred_key][:, -1].cpu().contiguous(), + _batch[ac_key][:, -1].cpu().contiguous(), ) xyz_p, ypr_p = _split_mse(preds[pred_key], _batch[ac_key]) if xyz_p is not None: metrics[f"Valid/{pred_key}_xyz_paired_mse_avg"] = xyz_p metrics[f"Valid/{pred_key}_ypr_paired_mse_avg"] = ypr_p - xyz_f, ypr_f = _split_mse( - preds[pred_key][:, -1:], _batch[ac_key][:, -1:] - ) - if xyz_f is not None: - metrics[f"Valid/{pred_key}_xyz_final_mse_avg"] = xyz_f - metrics[f"Valid/{pred_key}_ypr_final_mse_avg"] = ypr_f + + # Distributional metrics (native frame only). Fréchet compares + # the time-distribution shape of the single prediction; reverse + # KL needs M independent stochastic samples and is gated on the + # algo's ``rkl_samples`` (M extra sampling passes per batch). + fd = frechet_gaussian_over_time(preds[pred_key], _batch[ac_key]) + metrics[f"Valid/{pred_key}_frechet_gauss_avg"] = fd.mean().item() + metrics[f"Valid/{pred_key}_frechet_gauss_min"] = fd.min().item() + metrics[f"Valid/{pred_key}_frechet_gauss_max"] = fd.max().item() + + if algo.rkl_samples and algo.rkl_samples > 1: + M = int(algo.rkl_samples) + gt_tensor = _batch[ac_key].to(algo.device) + # Feed the ORIGINAL normalized batch element, not the loop's + # unnormalized ``_batch`` — ``norm_stats.unnormalize`` also + # denormalizes proprio obs keys, so sampling must run on the + # normalized obs (same as ``forward_eval``). + samples = algo.sample_action_chunks( + batch[embodiment_id], embodiment_id, M + ) + rkl = reverse_kl_from_samples(samples, gt_tensor) + metrics[f"Valid/{pred_key}_reverse_kl_M{M}"] = rkl.item() + + # Best-of-K coverage from the SAME M samples (no extra + # sampling): per-sample paired MSE to GT, reduced over the + # chunk.
``bestof`` = does the policy produce a good action + # in M tries (multimodal coverage); ``mean`` = avg sample + # quality; ``worstof`` = how bad the worst draw is; + # ``diversity`` = mean per-element std across samples. + per_sample_mse = ( + ((samples - gt_tensor.unsqueeze(0)) ** 2) + .flatten(start_dim=2) + .mean(dim=2) + ) # (M, B) + metrics[f"Valid/{pred_key}_bestof{M}_paired_mse"] = ( + per_sample_mse.min(dim=0).values.mean().item() + ) + metrics[f"Valid/{pred_key}_mean{M}_paired_mse"] = ( + per_sample_mse.mean().item() + ) + metrics[f"Valid/{pred_key}_worstof{M}_paired_mse"] = ( + per_sample_mse.max(dim=0).values.mean().item() + ) + metrics[f"Valid/{pred_key}_sample_diversity_M{M}"] = ( + samples.std(dim=0).mean().item() + ) transform_list = self.transform_lists.get(embodiment_name) gt_batch_viz = _batch @@ -100,9 +152,9 @@ def _split_mse(pred_t, gt_t): gt_batch_viz = {**_batch, **gt_t} pred_batch_viz = {**_batch, **pred_t} - # ``.contiguous()`` because ``apply_transform`` returns CPU tensors, - # so ``.cpu()`` here is a no-op and ``[:, -1]`` leaves a non-contiguous - # view that torchmetrics' MSE doesn't accept. + # ``.contiguous()`` because ``apply_transform`` returns CPU + # tensors, so ``.cpu()`` here is a no-op and the merged views can + # be non-contiguous, which torchmetrics' MSE doesn't accept. 
metrics[f"Valid/{pred_key}_cam_paired_mse_avg"] = mse( pred_batch_viz[ac_key].cpu().contiguous(), gt_batch_viz[ac_key].cpu().contiguous(), @@ -117,18 +169,13 @@ def _split_mse(pred_t, gt_t): if xyz_cp is not None: metrics[f"Valid/{pred_key}_cam_xyz_paired_mse_avg"] = xyz_cp metrics[f"Valid/{pred_key}_cam_ypr_paired_mse_avg"] = ypr_cp - xyz_cf, ypr_cf = _split_mse( - pred_batch_viz[ac_key][:, -1:], gt_batch_viz[ac_key][:, -1:] - ) - if xyz_cf is not None: - metrics[f"Valid/{pred_key}_cam_xyz_final_mse_avg"] = xyz_cf - metrics[f"Valid/{pred_key}_cam_ypr_final_mse_avg"] = ypr_cf preds_for_viz = dict(preds) preds_for_viz[pred_key] = pred_batch_viz[ac_key] - ims = self._visualize_preds(preds_for_viz, gt_batch_viz) - images_dict[embodiment_id] = ims + if do_viz: + ims = self._visualize_preds(preds_for_viz, gt_batch_viz) + images_dict[embodiment_id] = ims if total_loss is not None and n_loss_embodiments > 0: metrics["Valid/action_loss"] = total_loss / n_loss_embodiments diff --git a/egomimic/eval/eval_video.py b/egomimic/eval/eval_video.py index 77e293ffe..cffe239ad 100644 --- a/egomimic/eval/eval_video.py +++ b/egomimic/eval/eval_video.py @@ -24,11 +24,16 @@ def __init__( transform_lists: dict | None = None, one_video_per_task: bool = False, max_frames_per_task: int | None = 1000, + viz_every_n_epochs: int = 1, ): super().__init__() self.trainer = None self.model = None self.viz_func = viz_func + # Validation metrics log every validation; viz video rendering only + # happens on epochs divisible by this (viz is far more expensive than + # the metrics). <=0 disables viz entirely. + self.viz_every_n_epochs = viz_every_n_epochs # Per-embodiment list[Transform] applied once during eval to project # the model's wrist-frame actions back into cam (head) frame. Reused for # both cam-frame MSE and the viz video so we don't transform twice. 
@@ -55,14 +60,21 @@ def __init__( def video_dir(self): return os.path.join(self.root_dir(), "videos") + def _should_viz(self) -> bool: + if not self.viz_every_n_epochs or self.viz_every_n_epochs <= 0: + return False + return (self.trainer.current_epoch % self.viz_every_n_epochs) == 0 + @abstractmethod - def compute_metrics_and_viz(self, batch): + def compute_metrics_and_viz(self, batch, do_viz=True): """ Run the model's eval forward and compute metrics and visualization frames. Args: batch (dict): processed batch produced by the algo's `process_batch_for_training`. + do_viz (bool): when False, skip producing viz frames (the expensive + part); metrics are still computed and returned. Returns: metrics (dict[str, torch.Tensor | float]) images_dict (dict[embodiment_id, np.ndarray (B, H, W, 3)]) @@ -70,7 +82,7 @@ def compute_metrics_and_viz(self, batch): raise NotImplementedError def on_validation_start(self): - if self.trainer.is_global_zero: + if self.trainer.is_global_zero and self._should_viz(): os.makedirs( os.path.join(self.video_dir(), f"epoch_{self.trainer.current_epoch}"), exist_ok=True, @@ -82,6 +94,8 @@ def _sanitize_task(task: str) -> str: return re.sub(r"[^\w.-]+", "_", str(task)).strip("_") or "unknown" def on_validation_end(self): + if not self._should_viz(): + return for key, buffer in self.val_image_buffer.items(): if self.one_video_per_task: embodiment_id, task = key @@ -113,7 +127,8 @@ def on_validation_end(self): self.val_image_buffer[key] = [] def on_validation_step(self, batch, batch_idx, dataloader_idx=0): - metrics, images_dict = self.compute_metrics_and_viz(batch) + do_viz = self._should_viz() + metrics, images_dict = self.compute_metrics_and_viz(batch, do_viz=do_viz) device = self.trainer.lightning_module.device metrics = { @@ -121,7 +136,7 @@ def on_validation_step(self, batch, batch_idx, dataloader_idx=0): for k, v in metrics.items() } - ## images is now a dict + ## images is now a dict (empty when do_viz is False) for embodiment_id, images 
in images_dict.items(): os.makedirs( os.path.join( diff --git a/egomimic/eval/test_val_loss_equivalence.py b/egomimic/eval/test_val_loss_equivalence.py new file mode 100644 index 000000000..aeb19f2c3 --- /dev/null +++ b/egomimic/eval/test_val_loss_equivalence.py @@ -0,0 +1,111 @@ +"""Validation-loss equivalence between the Pi and HPT evaluators. + +Both flow-matching configs (``pi0.5_bc_eva``, ``hpt_bc_flow_eva``) compute the *same* +validation loss: flow-matching velocity MSE, i.e. the training loss under ``no_grad``. +HPT's ``bc_flow`` head is an ``FMPolicy`` (``hpt_bc_flow_eva.yaml``), so its +``compute_loss`` matches Pi's ``forward``. The loss-*value* equivalence is proven in +``egomimic/models/test_fm_loss_equivalence.py`` (the shared flow-matching formula). + +This module pins the *validation* side. ``PIEvalVideo`` and ``HPTEvalVideo`` +(``egomimic/eval/eval_pi.py`` and ``eval_hpt.py``) carry a duplicated per-embodiment +loss-aggregation convention: + + metrics["Valid/{emb}_loss"] = preds["{emb}_loss"] # per embodiment + metrics["Valid/action_loss"] = mean of the per-embodiment losses + +Given the same algo predictions, both evaluators must emit identical ``Valid/*loss`` +metrics. We drive the *real* ``compute_metrics_and_viz`` of both with a mocked algo +that returns only the loss keys (no sampled-action keys), which exercises exactly the +loss-aggregation path and skips the per-key MSE / Fréchet / reverse-KL / viz machinery. +""" + +import types + +import torch + +from egomimic.eval.eval_hpt import HPTEvalVideo +from egomimic.eval.eval_pi import PIEvalVideo +from egomimic.rldb.embodiment.embodiment import get_embodiment, get_embodiment_id + +EVA = get_embodiment_id("eva_bimanual") +ARIA = get_embodiment_id("aria_bimanual") + + +def _mock_algo(losses_by_emb): + """Algo stand-in whose ``forward_eval`` returns only ``{emb}_loss`` keys. 
+ + Without sampled-action prediction keys, every ``if pred_key in preds`` branch in + both evaluators is skipped, so only the loss aggregation runs. The HPT-only + attributes are set so its optional (aux / shared / reverse-KL) branches no-op. + """ + preds = { + f"{get_embodiment(eid).lower()}_loss": torch.tensor(v, dtype=torch.float32) + for eid, v in losses_by_emb.items() + } + return types.SimpleNamespace( + forward_eval=lambda batch: preds, + norm_stats=types.SimpleNamespace(unnormalize=lambda b, eid: b), + ac_keys={eid: "actions_cartesian" for eid in losses_by_emb}, + shared_ac_key=None, + auxiliary_ac_keys={}, + rkl_samples=1, + device=torch.device("cpu"), + ) + + +def _loss_metrics(evaluator_cls, losses_by_emb): + evaluator = evaluator_cls() + evaluator.model = _mock_algo(losses_by_emb) + # _batch contents are unused once the sampled-action keys are absent. + batch = {eid: {} for eid in losses_by_emb} + metrics, _ = evaluator.compute_metrics_and_viz(batch, do_viz=False) + return {k: v for k, v in metrics.items() if k.endswith("loss")} + + +def _assert_metrics_equal(a, b): + assert set(a) == set(b), (set(a), set(b)) + for k in a: + torch.testing.assert_close(torch.as_tensor(a[k]), torch.as_tensor(b[k])) + + +def test_pi_hpt_val_loss_metrics_match_single_embodiment(): + losses = {EVA: 0.7} + pi = _loss_metrics(PIEvalVideo, losses) + hpt = _loss_metrics(HPTEvalVideo, losses) + + # Same keys + values across the two evaluators ... + _assert_metrics_equal(pi, hpt) + # ... and the expected convention: per-embodiment loss + the aggregate. 
+ assert set(pi) == {"Valid/eva_bimanual_loss", "Valid/action_loss"} + torch.testing.assert_close(pi["Valid/eva_bimanual_loss"], torch.tensor(0.7)) + torch.testing.assert_close(pi["Valid/action_loss"], torch.tensor(0.7)) + + +def test_pi_hpt_val_loss_metrics_match_multi_embodiment(): + losses = {EVA: 0.4, ARIA: 1.0} + pi = _loss_metrics(PIEvalVideo, losses) + hpt = _loss_metrics(HPTEvalVideo, losses) + + _assert_metrics_equal(pi, hpt) + assert set(pi) == { + "Valid/eva_bimanual_loss", + "Valid/aria_bimanual_loss", + "Valid/action_loss", + } + # Valid/action_loss is the mean over embodiments. + torch.testing.assert_close(pi["Valid/action_loss"], torch.tensor((0.4 + 1.0) / 2)) + + +def test_no_action_loss_when_no_per_embodiment_loss(): + """If forward_eval emits no ``{emb}_loss``, neither evaluator emits Valid/action_loss.""" + + def _no_loss_algo(): + algo = _mock_algo({EVA: 0.0}) + algo.forward_eval = lambda batch: {} # drop the loss key + return algo + + for cls in (PIEvalVideo, HPTEvalVideo): + evaluator = cls() + evaluator.model = _no_loss_algo() + metrics, _ = evaluator.compute_metrics_and_viz({EVA: {}}, do_viz=False) + assert "Valid/action_loss" not in metrics diff --git a/egomimic/hydra_configs/data/mecka_pi.yaml b/egomimic/hydra_configs/data/mecka_pi.yaml index 2bd23921a..d0286979d 100644 --- a/egomimic/hydra_configs/data/mecka_pi.yaml +++ b/egomimic/hydra_configs/data/mecka_pi.yaml @@ -16,7 +16,7 @@ train_datasets: filters: _target_: egomimic.rldb.filters.DatasetFilter filter_lambdas: - - "lambda row: row['lab'] == 'mecka' and row['task'] == 'fold_clothes'" + - "lambda row: row['lab'] == 'mecka'" mode: train valid_datasets: diff --git a/egomimic/hydra_configs/evaluator/eval_hpt.yaml b/egomimic/hydra_configs/evaluator/eval_hpt.yaml index 901be3074..bc822e9af 100644 --- a/egomimic/hydra_configs/evaluator/eval_hpt.yaml +++ b/egomimic/hydra_configs/evaluator/eval_hpt.yaml @@ -3,3 +3,8 @@ defaults: - _self_ _target_: egomimic.eval.eval_hpt.HPTEvalVideo + +# 
Validation metrics log every validation (controlled by trainer +# check_val_every_n_epoch); viz video rendering only runs on epochs divisible by +# this. Intended to be a multiple of check_val_every_n_epoch so viz lands on a val epoch; NOTE(review): the evaluator tests trainer.current_epoch % this == 0 while Lightning validates when (current_epoch + 1) % check_val_every_n_epoch == 0 -- confirm the two ever coincide. +viz_every_n_epochs: 200 diff --git a/egomimic/hydra_configs/evaluator/eval_pi.yaml b/egomimic/hydra_configs/evaluator/eval_pi.yaml index 3b4485c24..4cd64bc6b 100644 --- a/egomimic/hydra_configs/evaluator/eval_pi.yaml +++ b/egomimic/hydra_configs/evaluator/eval_pi.yaml @@ -8,6 +8,11 @@ _target_: egomimic.eval.eval_pi.PIEvalVideo # in on_validation_end. Opt-in for eval-only runs that span multiple tasks. one_video_per_task: false +# Validation metrics log every validation (controlled by trainer +# check_val_every_n_epoch); viz video rendering only runs on epochs divisible by +# this. Intended to be a multiple of check_val_every_n_epoch so viz lands on a val epoch; NOTE(review): the evaluator tests trainer.current_epoch % this == 0 while Lightning validates when (current_epoch + 1) % check_val_every_n_epoch == 0 -- confirm the two ever coincide. +viz_every_n_epochs: 100 + # Cap each (embodiment, task) video at this many frames. Only takes effect # when one_video_per_task=true. Set null to disable. max_frames_per_task: 1000 diff --git a/egomimic/hydra_configs/model/pi0.5_base.yaml b/egomimic/hydra_configs/model/pi0.5_base.yaml index b3fa8f362..37977d48a 100644 --- a/egomimic/hydra_configs/model/pi0.5_base.yaml +++ b/egomimic/hydra_configs/model/pi0.5_base.yaml @@ -39,6 +39,12 @@ robomimic_model: state_num_bins: 256 control_mode: null + # Stochastic samples per eval batch for the reverse-KL metric (>1 enables it; + # each sample is a full flow-matching rollout). 1 disables it. Off by default + # because each extra sample adds a full sampling pass per eval batch and + # dominates validation time on large flow models; raise (e.g. 8) to enable.
+ reverse_kl_samples: 1 + train_image_augs: _target_: torchvision.transforms.Compose transforms: @@ -52,7 +58,7 @@ robomimic_model: size: 224 interpolation: 3 -enable_grad_norm: false +enable_grad_norm: true optimizer: _target_: torch.optim.AdamW @@ -66,7 +72,7 @@ scheduler: _target_: transformers.optimization.get_cosine_with_min_lr_schedule_with_warmup _partial_: true num_warmup_steps: 1000 - num_training_steps: 200000 + num_training_steps: 150000 num_cycles: 0.5 min_lr: 1e-5 diff --git a/egomimic/hydra_configs/train_zarr_cartesian.yaml b/egomimic/hydra_configs/train_zarr_cartesian.yaml index d325ad8ce..66e15c59f 100644 --- a/egomimic/hydra_configs/train_zarr_cartesian.yaml +++ b/egomimic/hydra_configs/train_zarr_cartesian.yaml @@ -15,11 +15,21 @@ description: test ckpt_path: null mode: train +# When mode=eval, keep the configured logger (e.g. wandb) so validation metrics +# are pushed to the dashboard instead of only printing. Set false to restore the +# historical "eval disables logger" behavior. +eval_logger: true + # Optional second evaluator that runs against the train_viz dataloader. # Set via override, e.g. `+train_viz_evaluator=train_viz_pi`, together with a # data config that defines `train_viz_datasets`. train_viz_evaluator: null +# When true, run a single validation pass (metrics + viz) at epoch 0 before any +# parameter updates, as a pre-fit baseline. Uses the in-memory init weights, so +# this is the untrained/pretrained-backbone baseline. Skipped on SLURM requeue. 
+val_at_start: false + hydra: run: # Dir should be experiment_name/description_{timestamp} diff --git a/egomimic/hydra_configs/train_zarr_cartesian_pi.yaml b/egomimic/hydra_configs/train_zarr_cartesian_pi.yaml index c625cbd71..eabda697c 100644 --- a/egomimic/hydra_configs/train_zarr_cartesian_pi.yaml +++ b/egomimic/hydra_configs/train_zarr_cartesian_pi.yaml @@ -5,3 +5,8 @@ defaults: - override trainer: ddp_pi - override evaluator: eval_pi - override hydra/launcher: submitit_pace + - _self_ + +# Pi runs get a pre-fit baseline (val metrics + viz) at epoch 0 by default. +# Override with `val_at_start=false` to skip. +val_at_start: true diff --git a/egomimic/hydra_configs/trainer/default.yaml b/egomimic/hydra_configs/trainer/default.yaml index 0c5a48331..528a38611 100644 --- a/egomimic/hydra_configs/trainer/default.yaml +++ b/egomimic/hydra_configs/trainer/default.yaml @@ -11,9 +11,9 @@ devices: 1 # mixed precision for extra speed-up precision: bf16 limit_train_batches: 100 -limit_val_batches: 80 +limit_val_batches: 200 # perform a validation loop every N training epochs -check_val_every_n_epoch: 50 +check_val_every_n_epoch: 25 # set True to to ensure deterministic results # makes training slower but gives more reproducibility than just setting seeds diff --git a/egomimic/models/test_fm_loss_equivalence.py b/egomimic/models/test_fm_loss_equivalence.py new file mode 100644 index 000000000..da505ffd9 --- /dev/null +++ b/egomimic/models/test_fm_loss_equivalence.py @@ -0,0 +1,250 @@ +"""Equivalence tests for the Pi and HPT training losses. 
+ +Both model families train with the *same* flow-matching velocity-prediction loss, +but the math lives in two separate implementations: + + Pi : external/openpi/src/openpi/models_pytorch/pi0_pytorch.py + ``PI0Pytorch.forward`` (interpolation/target/MSE, ~lines 326-373) + HPT: egomimic/models/fm_policy.py + ``FMPolicy.predict`` + ``DenoisingPolicy.loss_fn`` + +Given actions ``a``, noise ``e ~ N(0, 1)`` and time ``t ~ Beta(1.5, 1.0)*0.999+0.001`` +both must compute: + + x_t = t * e + (1 - t) * a # noised sample fed to the velocity model + u_t = e - a # target (constant) velocity + loss = MSE(v_t, u_t) # v_t is the model's velocity prediction + +If these conventions ever drift apart (e.g. a flipped sign on ``u_t`` or a swapped +interpolation), train/val losses stop being comparable across the two families. These +tests pin the convention by running the *real* loss code of both paths with identical +inputs and a shared, fixed velocity output ``V`` (the differing neural nets are stubbed +out so only the loss formulation is compared). +""" + +import types + +import numpy as np +import pytest +import torch +import torch.nn as nn +import torch.nn.functional as F + +from egomimic.models.fm_policy import FMPolicy + +try: + import openpi.models_pytorch.pi0_pytorch as pim + + _HAS_OPENPI = True +except Exception: # pragma: no cover - exercised only when openpi is absent + pim = None + _HAS_OPENPI = False + +requires_openpi = pytest.mark.skipif(not _HAS_OPENPI, reason="openpi not importable") + +# Small shapes; H is an arbitrary stubbed hidden width for the Pi network surface. +B, T, D, H = 3, 4, 7, 5 +EMB = "eva_bimanual" + + +def _rand(shape, seed): + """Deterministic float32 tensor that does not depend on torch's RNG. + + The HPT path is driven by monkeypatching ``torch.randn`` (see below), so inputs + must be built with a torch-RNG-independent source. 
+ """ + rng = np.random.default_rng(seed) + return torch.from_numpy(rng.standard_normal(size=shape).astype(np.float32)) + + +def _fixed_inputs(): + """Shared (actions, noise, time, V) ground truth for an equivalence comparison.""" + actions = _rand((B, T, D), seed=1) + noise = _rand((B, T, D), seed=2) + rng = np.random.default_rng(3) + time = torch.from_numpy(rng.uniform(0.05, 0.95, size=(B,)).astype(np.float32)) + velocity = _rand((B, T, D), seed=4) # the shared, fixed model output v_t + return actions, noise, time, velocity + + +def _fm_reference(actions, noise, time): + """The agreed flow-matching convention, written once as the spec.""" + time_expanded = time[:, None, None] + x_t = time_expanded * noise + (1.0 - time_expanded) * actions + u_t = noise - actions + return x_t, u_t + + +class _StubVelocityModel(nn.Module): + """Stands in for the HPT velocity net: returns a fixed ``V`` and captures inputs.""" + + def __init__(self, velocity): + super().__init__() + self.velocity = velocity + self.captured = {} + + def forward(self, x_t, time, global_cond): + self.captured["x_t"] = x_t.detach().clone() + self.captured["time"] = time.detach().clone() + return self.velocity + + +def _run_hpt(actions, noise, time, velocity, monkeypatch): + """Run the real ``FMPolicy.predict`` + ``loss_fn`` with injected noise/time. + + ``predict`` samples noise/time internally, so we monkeypatch the two RNG sources + to inject the shared ground-truth values, then read back what the model received. + """ + model = _StubVelocityModel(velocity) + policy = FMPolicy(model=model, action_horizon=T, infer_ac_dims={EMB: D}) + policy.eval() + + # Inject the shared noise: predict calls ``torch.randn(actions.shape, ...)``. + monkeypatch.setattr(torch, "randn", lambda *a, **k: noise) + # Inject the shared time: predict draws ``Beta(a, b).sample(...)`` then applies + # ``t = raw*0.999 + 0.001``. Hand back ``raw`` so the transform yields ``time``. 
+ raw = (time - 0.001) / 0.999 + monkeypatch.setattr( + torch.distributions, + "Beta", + lambda a, b: types.SimpleNamespace(sample=lambda shape: raw), + ) + + pred, target = policy.predict(actions, global_cond=torch.zeros(B, H)) + loss = policy.loss_fn(pred, target) + return { + "x_t": model.captured["x_t"], + "time": model.captured["time"], + "pred": pred, + "target": target, + "loss": loss, + } + + +def _run_pi(actions, noise, time, velocity): + """Run the real ``PI0Pytorch.forward`` loss math with the transformer stubbed out. + + We bypass ``__init__`` (which builds PaliGemma) and supply only the network surface + ``forward`` touches, so the genuine interpolation/target/MSE lines execute while the + expensive model is replaced by a fixed ``V``. ``forward`` returns the per-element MSE. + """ + pi = pim.PI0Pytorch.__new__(pim.PI0Pytorch) + nn.Module.__init__(pi) # set up nn.Module's attribute machinery + captured = {} + + pi.config = types.SimpleNamespace(action_horizon=T) + pi._preprocess_observation = lambda observation, train=True: (None,) * 5 + pi.embed_prefix = lambda *a: ( + torch.zeros(B, 1, H), + torch.ones(B, 1, dtype=torch.bool), + torch.zeros(B, 1), + ) + + def _embed_suffix(state, noisy_actions, timestep): + captured["x_t"] = noisy_actions.detach().clone() # x_t computed by forward + return ( + torch.zeros(B, T, H), + torch.ones(B, T, dtype=torch.bool), + torch.zeros(B, T), + None, + ) + + pi.embed_suffix = _embed_suffix + # forward checks q_proj.weight.dtype against bfloat16; float32 skips the cast branch. 
+ pi.paligemma_with_expert = types.SimpleNamespace( + paligemma=types.SimpleNamespace( + language_model=types.SimpleNamespace( + layers=[ + types.SimpleNamespace( + self_attn=types.SimpleNamespace( + q_proj=types.SimpleNamespace(weight=torch.zeros(1)) + ) + ) + ] + ) + ), + forward=lambda **kw: ((None, torch.cat(kw["inputs_embeds"], dim=1)), None), + ) + pi._prepare_attention_masks_4d = lambda att_2d_masks: att_2d_masks + pi._apply_checkpoint = lambda func, *a, **k: func(*a, **k) + pi.action_out_proj = lambda suffix_out: velocity # fixed v_t + + loss_elementwise = pi.forward(None, actions, noise=noise, time=time) + return {"x_t": captured["x_t"], "loss_elementwise": loss_elementwise} + + +def test_hpt_predict_matches_fm_reference(monkeypatch): + actions, noise, time, velocity = _fixed_inputs() + out = _run_hpt(actions, noise, time, velocity, monkeypatch) + x_ref, u_ref = _fm_reference(actions, noise, time) + + torch.testing.assert_close(out["time"], time) + torch.testing.assert_close(out["x_t"], x_ref) + torch.testing.assert_close(out["target"], u_ref) # u_t = noise - actions + torch.testing.assert_close(out["pred"], velocity) + torch.testing.assert_close(out["loss"], F.mse_loss(velocity, u_ref)) + + +@requires_openpi +def test_pi_forward_matches_fm_reference(): + actions, noise, time, velocity = _fixed_inputs() + out = _run_pi(actions, noise, time, velocity) + x_ref, u_ref = _fm_reference(actions, noise, time) + + torch.testing.assert_close(out["x_t"], x_ref) + # PI returns per-element MSE(u_t, v_t) with reduction="none". 
+ torch.testing.assert_close(out["loss_elementwise"], (u_ref - velocity) ** 2) + torch.testing.assert_close( + out["loss_elementwise"].mean(), F.mse_loss(velocity, u_ref) + ) + + +@requires_openpi +def test_pi_and_hpt_losses_match(monkeypatch): + """The headline check: identical inputs + identical v_t => identical loss.""" + actions, noise, time, velocity = _fixed_inputs() + hpt = _run_hpt(actions, noise, time, velocity, monkeypatch) + pi = _run_pi(actions, noise, time, velocity) + + # Same noised sample (interpolation convention) ... + torch.testing.assert_close(pi["x_t"], hpt["x_t"]) + # ... and same scalar loss (target + MSE convention). + torch.testing.assert_close(pi["loss_elementwise"].mean(), hpt["loss"]) + + +@requires_openpi +def test_validation_loss_value_matches(monkeypatch): + """The per-embodiment validation loss reduces to the same scalar for both. + + ``forward_eval`` turns the flow-matching loss into one scalar ``{emb}_loss``: + Pi (egomimic/algo/pi.py) does ``forward(obs, act).mean()``; HPT + (egomimic/algo/hpt.py) does ``compute_loss(...)`` which, for the ``bc_flow`` + FMPolicy head, is ``loss_fn`` == mean MSE. Same inputs => same val loss. 
+ """ + actions, noise, time, velocity = _fixed_inputs() + hpt = _run_hpt(actions, noise, time, velocity, monkeypatch) + pi = _run_pi(actions, noise, time, velocity) + + _, u_ref = _fm_reference(actions, noise, time) + expected = F.mse_loss(velocity, u_ref) # mean velocity MSE + + torch.testing.assert_close(hpt["loss"], expected) # HPT compute_loss (loss_fn mean) + torch.testing.assert_close( + pi["loss_elementwise"].mean(), expected + ) # Pi forward().mean() + + +@requires_openpi +def test_time_sampler_matches(): + """Pi's ``sample_time`` and HPT's beta branch draw from the same distribution.""" + bare_pi = pim.PI0Pytorch.__new__(pim.PI0Pytorch) + dev = torch.device("cpu") + + torch.manual_seed(0) + t_pi = bare_pi.sample_time(B, dev) + + # HPT (fm_policy.predict, beta branch): Beta(1.5, 1.0).sample()*0.999 + 0.001. + torch.manual_seed(0) + t_hpt = torch.distributions.Beta(1.5, 1.0).sample((B,)) * 0.999 + 0.001 + + torch.testing.assert_close(t_pi, t_hpt) diff --git a/egomimic/trainHydra.py b/egomimic/trainHydra.py index bf62b5a62..2f6e2769f 100644 --- a/egomimic/trainHydra.py +++ b/egomimic/trainHydra.py @@ -189,11 +189,17 @@ def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: else: raise ValueError("Config must specify either `mode` or `train`/`eval` booleans") - # In eval mode, apply trainer overrides from the eval object and disable logger + # In eval mode, apply trainer overrides from the eval object. By default the + # logger is disabled (metrics only print / land in callback_metrics); set + # `eval_logger=true` to keep the configured logger (e.g. wandb) so the + # Valid/* metrics from PIEvalVideo are pushed to the dashboard. Videos are + # always written to disk regardless of the logger. 
if mode == "eval": eval_obj: Eval = hydra.utils.instantiate(cfg.evaluator) + keep_logger = cfg.get("eval_logger", False) log.info( - "Eval mode: applying trainer overrides from eval config, disabling logger" + "Eval mode: applying trainer overrides from eval config; " + + ("keeping logger" if keep_logger else "disabling logger") ) with open_dict(cfg): for k, v in eval_obj.override_dict.items(): @@ -201,7 +207,8 @@ def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: cfg.trainer.devices = 1 cfg.trainer.num_nodes = 1 cfg.trainer.num_sanity_val_steps = 0 - cfg.logger = None + if not keep_logger: + cfg.logger = None log.info("Instantiating loggers...") logger: List[Logger] = instantiate_loggers(cfg.get("logger")) @@ -230,10 +237,11 @@ def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: log.info("Logging hyperparameters!") log_hyperparameters(object_dict) - if ( + is_requeue = bool( os.environ.get("SLURM_JOB_ID") and os.environ.get("SLURM_RESTART_COUNT", "0") != "0" - ): + ) + if is_requeue: last_ckpt_path = os.path.join( trainer.default_root_dir, "checkpoints", "last.ckpt" ) @@ -253,6 +261,11 @@ def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: train_viz_eval_obj.trainer = trainer train_viz_eval_obj.model = model.model model.train_viz_evaluator = train_viz_eval_obj + if cfg.get("val_at_start", False) and not is_requeue: + log.info( + "val_at_start: running validation + viz at epoch 0 (pre-fit baseline)" + ) + trainer.validate(model=model, datamodule=datamodule) log.info("Starting training!") trainer.fit( model=model,