Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 70 additions & 1 deletion egomimic/eval/eval_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,22 @@ def __init__(
viz_func: dict = None,
transform_lists: dict | None = None,
viz_every_n_epochs: int = 1,
viz_episodes: str | None = None,
):
super().__init__()
self.trainer = None
self.model = None
self.viz_func = viz_func
self.limit_val_batches = limit_val_batches
self.viz_every_n_epochs = viz_every_n_epochs
# Path to a JSON list of episode hashes to visualize. When set, viz is
# produced exclusively from these curated episodes (the inline
# per-val-batch viz is disabled); each listed episode is rendered for up
# to ``limit_val_batches`` batches. The per-episode dataloaders are built
# in trainHydra and assigned to ``viz_dataloaders``
# (embodiment_name -> {episode_hash: loader}).
self.viz_episodes = viz_episodes
self.viz_dataloaders = {}
# Per-embodiment list[Transform] applied once during eval to project
# the model's wrist-frame actions back into cam (head) frame. Reused for
# both cam-frame MSE and the viz video so we don't transform twice.
Expand All @@ -51,6 +61,59 @@ def _should_viz(self) -> bool:
return False
return (self.trainer.current_epoch % self.viz_every_n_epochs) == 0

def _use_viz_episodes(self) -> bool:
    """Report whether the curated-episode viz path is active.

    Returns:
        True only when ``viz_episodes`` is set to a truthy value and
        ``viz_dataloaders`` is non-empty; False otherwise.
    """
    if not self.viz_episodes:
        return False
    return bool(self.viz_dataloaders)

@staticmethod
def _to_device(batch, device):
    """Return a copy of ``batch`` with its tensor values moved to ``device``.

    Args:
        batch: Mapping of keys to values; only top-level values are inspected.
        device: Target passed to ``Tensor.to``.

    Returns:
        A new dict with the same keys. Values that are tensors are moved;
        everything else is passed through unchanged.
    """
    moved = {}
    for key, value in batch.items():
        moved[key] = value.to(device) if torch.is_tensor(value) else value
    return moved

def _write_episode_video(self, embodiment_id, episode_hash, frames):
    """Encode ``frames`` as ``<episode_hash>.mp4`` for the current epoch.

    The file lands under
    ``<video_dir>/epoch_<current_epoch>/<embodiment>/``, where the
    embodiment folder is the string form of ``get_embodiment(embodiment_id)``.
    The directory is created if it does not exist.

    Args:
        embodiment_id: Identifier resolved through ``get_embodiment``.
        episode_hash: Used as the video file stem.
        frames: Sequence of tensors that are stacked into a single tensor
            before writing (presumably one tensor per frame; confirm the
            layout expected by ``tvio.write_video``).
    """
    video_root = self.video_dir()
    epoch_folder = f"epoch_{self.trainer.current_epoch}"
    embodiment_folder = str(get_embodiment(embodiment_id))
    out_dir = os.path.join(video_root, epoch_folder, embodiment_folder)
    os.makedirs(out_dir, exist_ok=True)

    path = os.path.join(out_dir, f"{episode_hash}.mp4")
    clip = torch.stack(frames)
    # Fixed 30 fps, H.264-encoded output.
    tvio.write_video(path, clip, fps=30, video_codec="h264")

def _run_viz_episode_pass(self):
    """Write one video per curated episode; does nothing off global rank zero.

    Every loader in ``viz_dataloaders`` is consumed for at most
    ``limit_val_batches`` batches (no cap when that value is falsy). Each
    batch is moved to the lightning module's device, passed through
    ``process_batch_for_training`` and then ``compute_metrics_and_viz``.
    Only the returned images are kept; the metrics are dropped here. Per
    the original author's note, ``self.model`` is the raw algo, so running
    the forward on a single rank does not trigger DDP collectives.
    """
    trainer = self.trainer
    if not trainer.is_global_zero:
        return

    algo = self.model
    device = trainer.lightning_module.device
    max_batches = self.limit_val_batches

    def collect_frames(embodiment_name, loader):
        # Gather per-embodiment frame lists for a single episode loader.
        collected = {}
        for batch_idx, raw_batch in enumerate(loader):
            if max_batches and batch_idx >= max_batches:
                break
            on_device = self._to_device(raw_batch, device)
            processed = algo.process_batch_for_training(
                {embodiment_name: on_device}
            )
            _, images_dict = self.compute_metrics_and_viz(processed, do_viz=True)
            for emb_id, images in images_dict.items():
                bucket = collected.setdefault(emb_id, [])
                # Extending with a tensor appends its slices along dim 0.
                bucket.extend(torch.from_numpy(images))
        return collected

    with torch.no_grad():
        for embodiment_name, loaders in self.viz_dataloaders.items():
            for episode_hash, loader in loaders.items():
                frames_by_emb = collect_frames(embodiment_name, loader)
                for emb_id, frames in frames_by_emb.items():
                    if not frames:
                        continue
                    self._write_episode_video(emb_id, episode_hash, frames)

@abstractmethod
def compute_metrics_and_viz(self, batch, do_viz=True):
"""
Expand All @@ -75,6 +138,9 @@ def on_validation_start(self):
def on_validation_end(self):
if not self._should_viz():
return
if self._use_viz_episodes():
self._run_viz_episode_pass()
return
for key, buffer in self.val_image_buffer.items():
os.makedirs(
os.path.join(
Expand All @@ -98,7 +164,10 @@ def on_validation_end(self):
self.val_image_buffer[key] = []

def on_validation_step(self, batch, batch_idx, dataloader_idx=0):
do_viz = self._should_viz()
# When curated viz episodes are configured, viz is produced by a
# dedicated pass in on_validation_end; the inline per-val-batch viz is
# disabled here while metrics still log every validation.
do_viz = self._should_viz() and not self._use_viz_episodes()
metrics, images_dict = self.compute_metrics_and_viz(batch, do_viz=do_viz)

device = self.trainer.lightning_module.device
Expand Down
3 changes: 2 additions & 1 deletion egomimic/hydra_configs/evaluator/eval_hpt.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@ defaults:

_target_: egomimic.eval.eval_hpt.HPTEvalVideo

viz_every_n_epochs: 200
viz_every_n_epochs: 200
viz_episodes: null
1 change: 1 addition & 0 deletions egomimic/hydra_configs/evaluator/eval_pi.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ defaults:
_target_: egomimic.eval.eval_pi.PIEvalVideo

viz_every_n_epochs: 200
viz_episodes: null

# Per-embodiment revert transform. Applied once during validation to project
# the model's wrist-frame action chunks back to cam (head) frame, then reused
Expand Down
27 changes: 27 additions & 0 deletions egomimic/rldb/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,33 @@ def matches(self, row: Mapping[str, Any]) -> bool:
return True


class EpisodeHashFilter(DatasetFilter):
    """Filter that admits only rows whose ``episode_hash`` is in a given set.

    Intended for scoping a viz dataset to a curated hash list (e.g. a
    ``kept_hashes.json``) instead of building a long lambda-string filter.
    When ``base`` is supplied, a row must also pass that filter, so existing
    constraints from the source dataset config can be carried along.
    """

    def __init__(self, hashes, base: DatasetFilter | None = None) -> None:
        """Store the allowed hashes (as a set) and the optional base filter.

        Args:
            hashes: Iterable of episode hashes to keep.
            base: Optional filter that rows must additionally satisfy.
        """
        # No lambda filters are registered on the parent; matching is done
        # entirely in ``matches`` below.
        super().__init__(filter_lambdas=None)
        self.hashes = set(hashes)
        self.base = base

    def __repr__(self) -> str:
        """Summarize the filter by hash count and base filter."""
        return f"EpisodeHashFilter(n_hashes={len(self.hashes)}, base={self.base!r})"

    def matches(self, row: Mapping[str, Any]) -> bool:
        """Return True when ``row`` should be kept.

        A row is rejected if it is flagged ``is_deleted``, if its
        ``episode_hash`` is not in the curated set, or if the base filter
        (when present) rejects it.
        """
        deleted = row.get("is_deleted", False)
        if deleted or row.get("episode_hash") not in self.hashes:
            return False
        if self.base is None:
            return True
        return bool(self.base.matches(row))


class ScaleAnnotationDatasetFilter(DatasetFilter):
def __init__(
self, project_name: str, filter_lambdas: Sequence[str] | None = None
Expand Down
90 changes: 90 additions & 0 deletions egomimic/trainHydra.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import copy
import json
import os
import signal
from typing import Any, Dict, List, Optional, Tuple
Expand All @@ -11,10 +12,12 @@
from lightning.pytorch.plugins.environments import SLURMEnvironment
from omegaconf import DictConfig, OmegaConf, open_dict
from tabulate import tabulate
from torch.utils.data import DataLoader

import egomimic.utils.hydra_resolvers # noqa: F401 -- registers OmegaConf resolvers
from egomimic.eval.eval import Eval
from egomimic.pl_utils.pl_model import ModelWrapper
from egomimic.rldb.filters import DatasetFilter, EpisodeHashFilter
from egomimic.rldb.zarr.utils import set_global_seed
from egomimic.rldb.zarr.zarr_dataset_multi import MultiDataset
from egomimic.utils.aws.aws_data_utils import load_env
Expand Down Expand Up @@ -62,6 +65,84 @@ def _log_dataset_frame_counts(train_datasets: dict, valid_datasets: dict) -> Non
log.info("Dataset frame counts:\n" + table)


def _load_viz_episode_hashes(path):
    """Read and validate the curated-episode JSON file at ``path``.

    Returns:
        ``(all_hashes, hashes_by_name)``. For a flat JSON list the first item
        is the set of hashes and the second is ``None``; for a JSON object the
        first item is ``None`` and the second maps each key to its set of
        hashes.

    Raises:
        ValueError: If the top-level value is neither a list nor an object,
            or if an object value is not a list.
    """

    def as_hash_set(value, where):
        # Reject non-lists explicitly: ``set("abc")`` would otherwise turn a
        # stray string into a set of characters that never matches anything.
        if not isinstance(value, list):
            raise ValueError(
                f"viz_episodes: expected a list of episode hashes for {where} "
                f"in {path!r}, got {type(value).__name__}"
            )
        return set(value)

    # JSON is UTF-8 by specification; do not rely on the locale default.
    with open(path, encoding="utf-8") as f:
        data = json.load(f)

    if isinstance(data, dict):
        by_name = {
            name: as_hash_set(v, f"embodiment {name!r}") for name, v in data.items()
        }
        return None, by_name
    return as_hash_set(data, "the top-level value"), None


def _build_viz_episode_dataloaders(cfg, datamodule, norm_stats):
    """Build per-episode viz dataloaders from ``cfg.evaluator.viz_episodes``.

    ``viz_episodes`` is a path to a JSON file holding either a flat list of
    episode hashes (applied to every valid embodiment) or a dict
    ``embodiment_name -> list[hash]``. For each configured valid dataset, the
    matching episodes are resolved with ``mode="total"`` and one dataloader is
    built per episode; the evaluator caps each episode's viz pass at
    ``limit_val_batches``.

    Args:
        cfg: Hydra/OmegaConf config; reads ``evaluator.viz_episodes`` and
            ``data.valid_datasets``.
        datamodule: Supplies ``valid_dataloader_params`` and ``collate_fn``.
        norm_stats: Normalization source applied to every viz dataset.

    Returns:
        ``{embodiment_name: {episode_hash: DataLoader}}``; ``{}`` when
        ``viz_episodes`` is unset. Embodiments with no requested or no
        resolved episodes are omitted.

    Raises:
        ValueError: If the JSON file does not have the list/dict shape
            described above.
    """
    viz_episodes = OmegaConf.select(cfg, "evaluator.viz_episodes", default=None)
    if not viz_episodes:
        return {}

    all_hashes, hashes_by_name = _load_viz_episode_hashes(viz_episodes)

    viz_dataloaders = {}
    for name, vcfg in cfg.data.valid_datasets.items():
        if vcfg is None:
            continue
        # A flat list applies to every embodiment; a dict is looked up by name.
        if all_hashes is not None:
            hashes = all_hashes
        else:
            hashes = hashes_by_name.get(name, set())
        if not hashes:
            continue
        resolver = hydra.utils.instantiate(vcfg.resolver)
        # Scope to this embodiment by the SQL ``embodiment`` field (the valid
        # dataset key, matching what the data configs already filter on) so a
        # curated hash from another embodiment isn't loaded with this dataset's
        # keymap. Deliberately NOT the val dataset's own filter — curated viz
        # episodes need not belong to the val split.
        emb_filter = DatasetFilter(
            filter_lambdas=[f"lambda row: row.get('embodiment') == {name!r}"]
        )
        viz_filter = EpisodeHashFilter(hashes, base=emb_filter)
        try:
            viz_ds = MultiDataset._from_resolver(
                resolver, filters=viz_filter, mode="total"
            )
        except ValueError as e:
            # resolve() raises when the filter matches no episodes (e.g. a
            # flat list that doesn't include any of this embodiment's hashes).
            log.warning(
                f"viz_episodes: no curated episodes resolved for <{name}> "
                f"({len(hashes)} hashes requested): {e}"
            )
            continue
        viz_ds.set_norm_stats_from(norm_stats)

        # Reuse the validation loader settings, but force a deterministic
        # order: any configured ``shuffle`` is dropped and replaced by False.
        params = dict(datamodule.valid_dataloader_params.get(name, {}))
        params.pop("shuffle", None)
        loaders = {}
        for episode_hash, child in viz_ds.datasets.items():
            # Wrap each child dataset on its own so one loader == one episode.
            single = MultiDataset(
                datasets={episode_hash: child},
                mode="total",
                norm_mode=norm_stats.norm_mode,
            )
            single.set_norm_stats_from(norm_stats)
            loaders[episode_hash] = DataLoader(
                single,
                shuffle=False,
                collate_fn=datamodule.collate_fn,
                **params,
            )
        if loaders:
            viz_dataloaders[name] = loaders
            log.info(
                f"viz_episodes: built {len(loaders)} per-episode viz loaders for <{name}>"
            )
    return viz_dataloaders


@task_wrapper
def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
"""Trains the model. Can additionally evaluate on a testset, using best weights obtained during
Expand Down Expand Up @@ -149,6 +230,13 @@ def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
for ds in datamodule.valid_datasets.values():
ds.set_norm_stats_from(norm_stats)

# Curated per-episode viz loaders (empty unless evaluator.viz_episodes is set).
viz_episode_loaders = (
_build_viz_episode_dataloaders(cfg, datamodule, norm_stats)
if cfg.get("evaluator") is not None
else {}
)

log.info(f"Instantiating model <{cfg.model._target_}>")
model: LightningModule = ModelWrapper(
config_tree=_build_model_config_tree(cfg),
Expand Down Expand Up @@ -229,6 +317,7 @@ def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
eval_obj: Eval = hydra.utils.instantiate(cfg.evaluator)
eval_obj.trainer = trainer
eval_obj.model = model.model
eval_obj.viz_dataloaders = viz_episode_loaders
model.evaluator = eval_obj
log.info("Starting training!")
trainer.fit(
Expand All @@ -240,6 +329,7 @@ def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
elif mode == "eval":
eval_obj.trainer = trainer
eval_obj.model = model.model
eval_obj.viz_dataloaders = viz_episode_loaders
model.evaluator = eval_obj

if hasattr(eval_obj, "run"):
Expand Down
Loading