From c044ed8680d0cb680cae05889993ef759f9d23e2 Mon Sep 17 00:00:00 2001 From: ElmoPA Date: Fri, 8 May 2026 00:12:05 -0400 Subject: [PATCH] Pass norm stats to collate for proprio tokenization --- egomimic/pl_utils/pl_data_utils.py | 78 ++++++++++++++++++++++++++---- egomimic/trainHydra.py | 5 +- 2 files changed, 72 insertions(+), 11 deletions(-) diff --git a/egomimic/pl_utils/pl_data_utils.py b/egomimic/pl_utils/pl_data_utils.py index eac1f4fbf..d4112e74d 100644 --- a/egomimic/pl_utils/pl_data_utils.py +++ b/egomimic/pl_utils/pl_data_utils.py @@ -282,15 +282,20 @@ def build_tokenized_collate( proprio: bool = False, embodiment_label: bool = False, control_mode: dict[str, str] | None = None, + data_schematic=None, ): """Return a collate_fn closure that tokenizes the annotations field. Three orthogonal inclusion flags govern what gets spliced into the prompt: - ``proprio`` (bool): if True, append ``State: ``. The per-sample - proprio listed in ``proprio_keys`` is concatenated, clipped to - ``[-1, 1]``, and discretized into ``state_num_bins`` bins (pi0.5 style; - assumes upstream normalization). + proprio listed in ``proprio_keys`` is normalized via + ``data_schematic.normalize_data`` (when provided and stats are + populated), concatenated, clipped to ``[-1, 1]``, and discretized + into ``state_num_bins`` bins (pi0.5 style). The schematic's + ``norm_stats`` are read at *call* time, so an empty schematic at + construction is fine as long as ``infer_norm_from_dataset`` runs + before the first batch. - ``embodiment_label`` (bool): if True, append ``Embodiment: ``. - ``control_mode`` (dict | None): if non-null, append ``Control mode: ``. 
Keys are substrings matched against the (lowercased, @@ -340,22 +345,75 @@ def _control_mode_for(emb_name): def _discretize_sample_state(sample): if not proprio_keys: return None - parts = [] + + emb_id = sample.get("embodiment") + if isinstance(emb_id, torch.Tensor): + emb_id = int(emb_id.item()) + elif isinstance(emb_id, np.ndarray): + emb_id = int(emb_id.item()) + elif emb_id is not None: + emb_id = int(emb_id) + + # Sample dict is keyed by zarr_key (e.g. "observations.state.ee_pose"). + # data_schematic.normalize_data keys by key_name (e.g. "ee_pose"), so + # translate zarr_key -> key_name before calling normalize_data — this + # mirrors the model-side path in pi.py:process_batch_for_training. + raw = {} + # Map key_name -> zarr_key so we can preserve proprio_keys ordering when + # concatenating after normalization. + keyname_to_zarr = {} for k in proprio_keys: if k not in sample: continue + keyname = ( + data_schematic.zarr_key_to_keyname(k, emb_id) + if data_schematic is not None and emb_id is not None + else None + ) + if keyname is None: + # Not registered for this embodiment under any keyname; skip + # rather than feeding an un-normalizable key into normalize_data. + continue v = sample[k] if isinstance(v, torch.Tensor): - v = v.detach().cpu().numpy() + raw[keyname] = v.detach().to(torch.float32) else: - v = np.asarray(v) - v = np.asarray(v, dtype=np.float32) - # Use the most recent timestep if proprio carries a time axis. + raw[keyname] = torch.as_tensor(np.asarray(v), dtype=torch.float32) + keyname_to_zarr[keyname] = k + if not raw: + return None + + # Lazy lookup: norm_stats are populated in-place on the shared + # data_schematic by trainHydra.infer_norm_from_dataset before + # trainer.fit, so by the time the collate runs, stats must exist. + # Hard-fail rather than silently skipping — silent skip would emit + # bins computed on raw, un-normalized values. 
+ if data_schematic is None: + raise ValueError( + "build_tokenized_collate: proprio=True requires data_schematic, got None. " + "Pass data_schematic into MultiDataModuleWrapper." + ) + if getattr(data_schematic, "norm_stats", None) is None: + raise ValueError( + "data_schematic.norm_stats is not populated. " + "Call data_schematic.infer_norm_from_dataset(...) before the first batch." + ) + if emb_id is None: + raise ValueError( + "Sample is missing 'embodiment' id; cannot look up norm stats for proprio discretization." + ) + # data_schematic.normalize_data raises ValueError if stats are missing + # for a specific (embodiment, key) pair — see DataSchematic.normalize_data. + normed = data_schematic.normalize_data(raw, emb_id) + + # Iterate in raw insertion order (which mirrors proprio_keys ordering, + # restricted to keys that translated successfully and were present). + parts = [] + for keyname in raw: + v = normed[keyname].detach().cpu().numpy().astype(np.float32) while v.ndim > 1: v = v[-1] parts.append(v.reshape(-1)) - if not parts: - return None state = np.concatenate(parts, axis=-1) state = np.clip(state, -1.0, 1.0) bins = np.digitize(state, bins=state_bin_edges) - 1 diff --git a/egomimic/trainHydra.py b/egomimic/trainHydra.py index 57f864f5c..0dbf15cf7 100644 --- a/egomimic/trainHydra.py +++ b/egomimic/trainHydra.py @@ -100,7 +100,10 @@ def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: "MultiDataModuleWrapper" in cfg.data._target_ ), "cfg.data._target_ must be 'MultiDataModuleWrapper'" datamodule: LightningDataModule = hydra.utils.instantiate( - cfg.data, train_datasets=train_datasets, valid_datasets=valid_datasets + cfg.data, + train_datasets=train_datasets, + valid_datasets=valid_datasets, + data_schematic=data_schematic, ) # Stats-only MultiDataset (no graph of its own; explicitly populated from