From 7ec3de8f87aa32cd52650c5fd18716aa19656060 Mon Sep 17 00:00:00 2001 From: Aidan Gao Date: Tue, 16 Jun 2026 14:51:06 -0400 Subject: [PATCH] json filtering for viz --- egomimic/eval/eval_video.py | 71 ++++++++++++++- .../hydra_configs/evaluator/eval_hpt.yaml | 3 +- egomimic/hydra_configs/evaluator/eval_pi.yaml | 1 + egomimic/rldb/filters.py | 27 ++++++ egomimic/trainHydra.py | 90 +++++++++++++++++++ 5 files changed, 190 insertions(+), 2 deletions(-) diff --git a/egomimic/eval/eval_video.py b/egomimic/eval/eval_video.py index 58f7b61c7..7da2b9354 100644 --- a/egomimic/eval/eval_video.py +++ b/egomimic/eval/eval_video.py @@ -21,12 +21,22 @@ def __init__( viz_func: dict = None, transform_lists: dict | None = None, viz_every_n_epochs: int = 1, + viz_episodes: str | None = None, ): super().__init__() self.trainer = None self.model = None self.viz_func = viz_func + self.limit_val_batches = limit_val_batches self.viz_every_n_epochs = viz_every_n_epochs + # Path to a JSON list of episode hashes to visualize. When set, viz is + # produced exclusively from these curated episodes (the inline + # per-val-batch viz is disabled); each listed episode is rendered for up + # to ``limit_val_batches`` batches. The per-episode dataloaders are built + # in trainHydra and assigned to ``viz_dataloaders`` + # (embodiment_name -> {episode_hash: loader}). + self.viz_episodes = viz_episodes + self.viz_dataloaders = {} # Per-embodiment list[Transform] applied once during eval to project # the model's wrist-frame actions back into cam (head) frame. Reused for # both cam-frame MSE and the viz video so we don't transform twice. @@ -51,6 +61,59 @@ def _should_viz(self) -> bool: return False return (self.trainer.current_epoch % self.viz_every_n_epochs) == 0 + def _use_viz_episodes(self) -> bool: + return bool(self.viz_episodes) and bool(self.viz_dataloaders) + + @staticmethod + def _to_device(batch, device): + return { + k: (v.to(device) if torch.is_tensor(v) else v) for k, v in batch.items() + } + + def _write_episode_video(self, embodiment_id, episode_hash, frames): + out_dir = os.path.join( + self.video_dir(), + f"epoch_{self.trainer.current_epoch}", + str(get_embodiment(embodiment_id)), + ) + os.makedirs(out_dir, exist_ok=True) + path = os.path.join(out_dir, f"{episode_hash}.mp4") + tvio.write_video(path, torch.stack(frames), fps=30, video_codec="h264") + + def _run_viz_episode_pass(self): + """Render one video per curated episode (rank 0 only). + + Each episode's dataloader is iterated for up to ``limit_val_batches`` + batches, reusing ``compute_metrics_and_viz`` for the model forward + viz + drawing (metrics are discarded here; they are logged on the normal + validation loader). ``self.model`` is the raw algo, so the forward does + not trigger DDP collectives — safe to run on rank 0 alone. + """ + if not self.trainer.is_global_zero: + return + algo = self.model + device = self.trainer.lightning_module.device + max_batches = self.limit_val_batches + with torch.no_grad(): + for embodiment_name, loaders in self.viz_dataloaders.items(): + for episode_hash, loader in loaders.items(): + frames_by_emb = {} + for i, raw in enumerate(loader): + if max_batches and i >= max_batches: + break + batch = {embodiment_name: self._to_device(raw, device)} + batch = algo.process_batch_for_training(batch) + _, images_dict = self.compute_metrics_and_viz( + batch, do_viz=True + ) + for emb_id, images in images_dict.items(): + frames_by_emb.setdefault(emb_id, []).extend( + torch.from_numpy(images) + ) + for emb_id, frames in frames_by_emb.items(): + if frames: + self._write_episode_video(emb_id, episode_hash, frames) + @abstractmethod def compute_metrics_and_viz(self, batch, do_viz=True): """ @@ -75,6 +138,9 @@ def on_validation_start(self): def on_validation_end(self): if not self._should_viz(): return + if self._use_viz_episodes(): + self._run_viz_episode_pass() + return for key, buffer in self.val_image_buffer.items(): os.makedirs( os.path.join( @@ -98,7 +164,10 @@ def on_validation_end(self): self.val_image_buffer[key] = [] def on_validation_step(self, batch, batch_idx, dataloader_idx=0): - do_viz = self._should_viz() + # When curated viz episodes are configured, viz is produced by a + # dedicated pass in on_validation_end; the inline per-val-batch viz is + # disabled here while metrics still log every validation. + do_viz = self._should_viz() and not self._use_viz_episodes() metrics, images_dict = self.compute_metrics_and_viz(batch, do_viz=do_viz) device = self.trainer.lightning_module.device diff --git a/egomimic/hydra_configs/evaluator/eval_hpt.yaml b/egomimic/hydra_configs/evaluator/eval_hpt.yaml index 5dfc9ce99..e9a4b7a02 100644 --- a/egomimic/hydra_configs/evaluator/eval_hpt.yaml +++ b/egomimic/hydra_configs/evaluator/eval_hpt.yaml @@ -4,4 +4,5 @@ defaults: _target_: egomimic.eval.eval_hpt.HPTEvalVideo -viz_every_n_epochs: 200 \ No newline at end of file +viz_every_n_epochs: 200 +viz_episodes: null \ No newline at end of file diff --git a/egomimic/hydra_configs/evaluator/eval_pi.yaml b/egomimic/hydra_configs/evaluator/eval_pi.yaml index 09cad51bf..5f34b2b39 100644 --- a/egomimic/hydra_configs/evaluator/eval_pi.yaml +++ b/egomimic/hydra_configs/evaluator/eval_pi.yaml @@ -5,6 +5,7 @@ defaults: _target_: egomimic.eval.eval_pi.PIEvalVideo viz_every_n_epochs: 200 +viz_episodes: null # Per-embodiment revert transform. Applied once during validation to project # the model's wrist-frame action chunks back to cam (head) frame, then reused diff --git a/egomimic/rldb/filters.py b/egomimic/rldb/filters.py index 395cd5f75..d44200b5e 100644 --- a/egomimic/rldb/filters.py +++ b/egomimic/rldb/filters.py @@ -39,6 +39,33 @@ def matches(self, row: Mapping[str, Any]) -> bool: return True +class EpisodeHashFilter(DatasetFilter): + """Keep only episodes whose ``episode_hash`` is in a curated set. + + Used to scope a viz dataset to a list of episode hashes (e.g. a curation + ``kept_hashes.json``) without materializing a long lambda-string filter. An + optional ``base`` filter is ANDed in so any existing embodiment/quality + constraints from the source dataset config are preserved. + """ + + def __init__(self, hashes, base: DatasetFilter | None = None) -> None: + super().__init__(filter_lambdas=None) + self.hashes = set(hashes) + self.base = base + + def __repr__(self) -> str: + return f"EpisodeHashFilter(n_hashes={len(self.hashes)}, base={self.base!r})" + + def matches(self, row: Mapping[str, Any]) -> bool: + if row.get("is_deleted", False): + return False + if row.get("episode_hash") not in self.hashes: + return False + if self.base is not None and not self.base.matches(row): + return False + return True + + class ScaleAnnotationDatasetFilter(DatasetFilter): def __init__( self, project_name: str, filter_lambdas: Sequence[str] | None = None diff --git a/egomimic/trainHydra.py b/egomimic/trainHydra.py index 57f864f5c..48f85f88c 100644 --- a/egomimic/trainHydra.py +++ b/egomimic/trainHydra.py @@ -1,4 +1,5 @@ import copy +import json import os import signal from typing import Any, Dict, List, Optional, Tuple @@ -11,10 +12,12 @@ from lightning.pytorch.plugins.environments import SLURMEnvironment from omegaconf import DictConfig, OmegaConf, open_dict from tabulate import tabulate +from torch.utils.data import DataLoader import egomimic.utils.hydra_resolvers # noqa: F401 -- registers OmegaConf resolvers from egomimic.eval.eval import Eval from egomimic.pl_utils.pl_model import ModelWrapper +from egomimic.rldb.filters import DatasetFilter, EpisodeHashFilter from egomimic.rldb.zarr.utils import set_global_seed from egomimic.rldb.zarr.zarr_dataset_multi import MultiDataset from egomimic.utils.aws.aws_data_utils import load_env @@ -62,6 +65,84 @@ def _log_dataset_frame_counts(train_datasets: dict, valid_datasets: dict) -> Non log.info("Dataset frame counts:\n" + table) +def _build_viz_episode_dataloaders(cfg, datamodule, norm_stats): + """Build per-episode viz dataloaders from ``cfg.evaluator.viz_episodes``. + + ``viz_episodes`` is a path to a JSON file holding either a flat list of + episode hashes (applied to every valid embodiment) or a dict + ``embodiment_name -> list[hash]``. For each valid dataset, the matching + episodes are resolved (``mode="total"``) and one dataloader is built per + episode; the evaluator caps each episode's viz pass at ``limit_val_batches``. + Returns ``{embodiment_name: {episode_hash: DataLoader}}`` ({} when unset). + """ + viz_episodes = OmegaConf.select(cfg, "evaluator.viz_episodes", default=None) + if not viz_episodes: + return {} + + with open(viz_episodes) as f: + data = json.load(f) + hashes_by_name = ( + {name: set(v) for name, v in data.items()} if isinstance(data, dict) else None + ) + all_hashes = None if isinstance(data, dict) else set(data) + + viz_dataloaders = {} + for name, vcfg in cfg.data.valid_datasets.items(): + if vcfg is None: + continue + hashes = ( + all_hashes if all_hashes is not None else hashes_by_name.get(name, set()) + ) + if not hashes: + continue + resolver = hydra.utils.instantiate(vcfg.resolver) + # Scope to this embodiment by the SQL ``embodiment`` field (the valid + # dataset key, matching what the data configs already filter on) so a + # curated hash from another embodiment isn't loaded with this dataset's + # keymap. Deliberately NOT the val dataset's own filter — curated viz + # episodes need not belong to the val split. + emb_filter = DatasetFilter( + filter_lambdas=[f"lambda row: row.get('embodiment') == {name!r}"] + ) + viz_filter = EpisodeHashFilter(hashes, base=emb_filter) + try: + viz_ds = MultiDataset._from_resolver( + resolver, filters=viz_filter, mode="total" + ) + except ValueError as e: + # resolve() raises when the filter matches no episodes (e.g. a + # flat list that doesn't include any of this embodiment's hashes). + log.warning( + f"viz_episodes: no curated episodes resolved for <{name}> " + f"({len(hashes)} hashes requested): {e}" + ) + continue + viz_ds.set_norm_stats_from(norm_stats) + + params = dict(datamodule.valid_dataloader_params.get(name, {})) + params.pop("shuffle", None) + loaders = {} + for episode_hash, child in viz_ds.datasets.items(): + single = MultiDataset( + datasets={episode_hash: child}, + mode="total", + norm_mode=norm_stats.norm_mode, + ) + single.set_norm_stats_from(norm_stats) + loaders[episode_hash] = DataLoader( + single, + shuffle=False, + collate_fn=datamodule.collate_fn, + **params, + ) + if loaders: + viz_dataloaders[name] = loaders + log.info( + f"viz_episodes: built {len(loaders)} per-episode viz loaders for <{name}>" + ) + return viz_dataloaders + + @task_wrapper def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: """Trains the model. Can additionally evaluate on a testset, using best weights obtained during @@ -149,6 +230,13 @@ def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: for ds in datamodule.valid_datasets.values(): ds.set_norm_stats_from(norm_stats) + # Curated per-episode viz loaders (empty unless evaluator.viz_episodes is set). + viz_episode_loaders = ( + _build_viz_episode_dataloaders(cfg, datamodule, norm_stats) + if cfg.get("evaluator") is not None + else {} + ) + log.info(f"Instantiating model <{cfg.model._target_}>") model: LightningModule = ModelWrapper( config_tree=_build_model_config_tree(cfg), @@ -229,6 +317,7 @@ def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: eval_obj: Eval = hydra.utils.instantiate(cfg.evaluator) eval_obj.trainer = trainer eval_obj.model = model.model + eval_obj.viz_dataloaders = viz_episode_loaders model.evaluator = eval_obj log.info("Starting training!") trainer.fit( @@ -240,6 +329,7 @@ def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: elif mode == "eval": eval_obj.trainer = trainer eval_obj.model = model.model + eval_obj.viz_dataloaders = viz_episode_loaders model.evaluator = eval_obj if hasattr(eval_obj, "run"):