Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions egomimic/eval/eval_act.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ class ACTEvalVideo(EvalVideo):
and draws predicted/ground-truth trajectories on the visualization image.
"""

def compute_metrics_and_viz(self, batch):
def compute_metrics_and_viz(self, batch, do_viz=True):
algo = self.model
preds = algo.forward_eval(batch)
# ground truth normalized; unnormalize for direct comparison.
Expand All @@ -29,7 +29,9 @@ def compute_metrics_and_viz(self, batch):
preds[ac_key][:, -1].cpu(), batch[ac_key][:, -1].cpu()
)

ims = {algo.embodiment_id: self._visualize_preds(preds, batch)}
ims = {}
if do_viz:
ims = {algo.embodiment_id: self._visualize_preds(preds, batch)}
return metrics, ims

def _visualize_preds(self, predictions, batch):
Expand Down
7 changes: 4 additions & 3 deletions egomimic/eval/eval_hpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class HPTEvalVideo(EvalVideo):
and the viz video.
"""

def compute_metrics_and_viz(self, batch):
def compute_metrics_and_viz(self, batch, do_viz=True):
algo = self.model
preds = algo.forward_eval(batch)

Expand Down Expand Up @@ -191,8 +191,9 @@ def compute_metrics_and_viz(self, batch):
preds_for_viz = dict(preds)
preds_for_viz[main_pred_key] = pred_batch_viz[ac_key]

ims = self._visualize_preds(preds_for_viz, gt_batch_viz)
images_dict[embodiment_id] = ims
if do_viz:
ims = self._visualize_preds(preds_for_viz, gt_batch_viz)
images_dict[embodiment_id] = ims

if total_loss is not None and n_loss_embodiments > 0:
metrics["Valid/action_loss"] = total_loss / n_loss_embodiments
Expand Down
4 changes: 2 additions & 2 deletions egomimic/eval/eval_latent.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ def on_validation_start(self):
exist_ok=True,
)

def compute_metrics_and_viz(self, batch):
def compute_metrics_and_viz(self, batch, do_viz=True):
algo = self.model
metrics = {}
images_dict = {}
Expand Down Expand Up @@ -282,7 +282,7 @@ def compute_metrics_and_viz(self, batch):
unnorm_batch[ac_key][:, -1].cpu(),
)

if self.viz_func is not None:
if do_viz and self.viz_func is not None:
images_dict[embodiment_id] = self._visualize_preds(
unnorm_preds, unnorm_batch
)
Expand Down
7 changes: 4 additions & 3 deletions egomimic/eval/eval_pi.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class PIEvalVideo(EvalVideo):
and the viz video.
"""

def compute_metrics_and_viz(self, batch):
def compute_metrics_and_viz(self, batch, do_viz=True):
algo = self.model
preds = algo.forward_eval(batch)

Expand Down Expand Up @@ -78,8 +78,9 @@ def compute_metrics_and_viz(self, batch):
preds_for_viz = dict(preds)
preds_for_viz[pred_key] = pred_batch_viz[ac_key]

ims = self._visualize_preds(preds_for_viz, gt_batch_viz)
images_dict[embodiment_id] = ims
if do_viz:
ims = self._visualize_preds(preds_for_viz, gt_batch_viz)
images_dict[embodiment_id] = ims

if total_loss is not None and n_loss_embodiments > 0:
metrics["Valid/action_loss"] = total_loss / n_loss_embodiments
Expand Down
66 changes: 40 additions & 26 deletions egomimic/eval/eval_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,13 @@ def __init__(
limit_val_batches: int = 400,
viz_func: dict = None,
transform_lists: dict | None = None,
viz_every_n_epochs: int = 1,
):
super().__init__()
self.trainer = None
self.model = None
self.viz_func = viz_func
self.viz_every_n_epochs = viz_every_n_epochs
# Per-embodiment list[Transform] applied once during eval to project
# the model's wrist-frame actions back into cam (head) frame. Reused for
# both cam-frame MSE and the viz video so we don't transform twice.
Expand All @@ -44,8 +46,13 @@ def __init__(
def video_dir(self):
return os.path.join(self.root_dir(), "videos")

def _should_viz(self) -> bool:
if not self.viz_every_n_epochs or self.viz_every_n_epochs <= 0:
return False
return (self.trainer.current_epoch % self.viz_every_n_epochs) == 0

@abstractmethod
def compute_metrics_and_viz(self, batch):
def compute_metrics_and_viz(self, batch, do_viz=True):
"""
Run the model's eval forward and compute metrics and visualization frames.

Expand All @@ -59,13 +66,15 @@ def compute_metrics_and_viz(self, batch):
raise NotImplementedError

def on_validation_start(self):
if self.trainer.is_global_zero:
if self.trainer.is_global_zero and self._should_viz():
os.makedirs(
os.path.join(self.video_dir(), f"epoch_{self.trainer.current_epoch}"),
exist_ok=True,
)

def on_validation_end(self):
if not self._should_viz():
return
for key, buffer in self.val_image_buffer.items():
os.makedirs(
os.path.join(
Expand All @@ -89,7 +98,8 @@ def on_validation_end(self):
self.val_image_buffer[key] = []

def on_validation_step(self, batch, batch_idx, dataloader_idx=0):
metrics, images_dict = self.compute_metrics_and_viz(batch)
do_viz = self._should_viz()
metrics, images_dict = self.compute_metrics_and_viz(batch, do_viz=do_viz)

device = self.trainer.lightning_module.device
metrics = {
Expand All @@ -98,29 +108,33 @@ def on_validation_step(self, batch, batch_idx, dataloader_idx=0):
}

# images_dict maps each embodiment key to that embodiment's visualization frames
for key, images in images_dict.items():
os.makedirs(
os.path.join(
self.video_dir(),
f"epoch_{self.trainer.current_epoch}",
str(get_embodiment(key)),
),
exist_ok=True,
)
if key not in self.val_image_buffer or self.val_image_buffer[key] is None:
self.val_image_buffer[key] = []
self.val_counter[key] = 0
self.val_image_buffer[key].extend(torch.from_numpy(images))
if len(self.val_image_buffer[key]) >= 1000:
frames = torch.stack(self.val_image_buffer[key])
path = os.path.join(
self.video_dir(),
f"epoch_{self.trainer.current_epoch}",
str(get_embodiment(key)),
f"validation_video_{self.val_counter[key]}.mp4",
if do_viz:
for key, images in images_dict.items():
os.makedirs(
os.path.join(
self.video_dir(),
f"epoch_{self.trainer.current_epoch}",
str(get_embodiment(key)),
),
exist_ok=True,
)
tvio.write_video(path, frames, fps=30, video_codec="h264")
self.val_image_buffer[key].clear()
self.val_counter[key] += 1
if (
key not in self.val_image_buffer
or self.val_image_buffer[key] is None
):
self.val_image_buffer[key] = []
self.val_counter[key] = 0
self.val_image_buffer[key].extend(torch.from_numpy(images))
if len(self.val_image_buffer[key]) >= 1000:
frames = torch.stack(self.val_image_buffer[key])
path = os.path.join(
self.video_dir(),
f"epoch_{self.trainer.current_epoch}",
str(get_embodiment(key)),
f"validation_video_{self.val_counter[key]}.mp4",
)
tvio.write_video(path, frames, fps=30, video_codec="h264")
self.val_image_buffer[key].clear()
self.val_counter[key] += 1

self.trainer.lightning_module.log_dict(metrics, sync_dist=True)
2 changes: 1 addition & 1 deletion egomimic/hydra_configs/callbacks/checkpoints.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,4 @@ model_checkpoint:
filename: "epoch_{epoch}"
save_last: true
save_top_k: -1
every_n_epochs: 100
every_n_epochs: 200
2 changes: 2 additions & 0 deletions egomimic/hydra_configs/evaluator/eval_hpt.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,5 @@ defaults:
- _self_

_target_: egomimic.eval.eval_hpt.HPTEvalVideo

viz_every_n_epochs: 200
2 changes: 2 additions & 0 deletions egomimic/hydra_configs/evaluator/eval_pi.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ defaults:

_target_: egomimic.eval.eval_pi.PIEvalVideo

viz_every_n_epochs: 200

# Per-embodiment revert transform. Applied once during validation to project
# the model's wrist-frame action chunks back to cam (head) frame, then reused
# for both the cam-frame MSE and the viz video. Each value resolves to a
Expand Down
2 changes: 1 addition & 1 deletion egomimic/hydra_configs/trainer/ddp_pi.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,5 @@ accelerator: gpu
devices: ${eval:'${launch_params.gpus_per_node} * ${launch_params.nodes}'}
num_nodes: ${launch_params.nodes}
sync_batchnorm: True
check_val_every_n_epoch: 200
check_val_every_n_epoch: 10
num_sanity_val_steps: 0
Loading