Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 68 additions & 10 deletions egomimic/pl_utils/pl_data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,15 +282,20 @@ def build_tokenized_collate(
proprio: bool = False,
embodiment_label: bool = False,
control_mode: dict[str, str] | None = None,
data_schematic=None,
):
"""Return a collate_fn closure that tokenizes the annotations field.

Three orthogonal inclusion flags govern what gets spliced into the prompt:

- ``proprio`` (bool): if True, append ``State: <bins>``. The per-sample
proprio listed in ``proprio_keys`` is concatenated, clipped to
``[-1, 1]``, and discretized into ``state_num_bins`` bins (pi0.5 style;
assumes upstream normalization).
proprio listed in ``proprio_keys`` is normalized via
``data_schematic.normalize_data`` (when provided and stats are
populated), concatenated, clipped to ``[-1, 1]``, and discretized
into ``state_num_bins`` bins (pi0.5 style). The schematic's
``norm_stats`` are read at *call* time, so an empty schematic at
construction is fine as long as ``infer_norm_from_dataset`` runs
before the first batch.
- ``embodiment_label`` (bool): if True, append ``Embodiment: <name>``.
- ``control_mode`` (dict | None): if non-null, append ``Control mode:
<descriptor>``. Keys are substrings matched against the (lowercased,
Expand Down Expand Up @@ -340,22 +345,75 @@ def _control_mode_for(emb_name):
def _discretize_sample_state(sample):
if not proprio_keys:
return None
parts = []

emb_id = sample.get("embodiment")
if isinstance(emb_id, torch.Tensor):
emb_id = int(emb_id.item())
elif isinstance(emb_id, np.ndarray):
emb_id = int(emb_id.item())
elif emb_id is not None:
emb_id = int(emb_id)

# Sample dict is keyed by zarr_key (e.g. "observations.state.ee_pose").
# data_schematic.normalize_data keys by key_name (e.g. "ee_pose"), so
# translate zarr_key -> key_name before calling normalize_data — this
# mirrors the model-side path in pi.py:process_batch_for_training.
raw = {}
# Map key_name -> zarr_key so we can preserve proprio_keys ordering when
# concatenating after normalization.
keyname_to_zarr = {}
for k in proprio_keys:
if k not in sample:
continue
keyname = (
data_schematic.zarr_key_to_keyname(k, emb_id)
if data_schematic is not None and emb_id is not None
else None
)
if keyname is None:
# Not registered for this embodiment under any keyname; skip
# rather than feeding an un-normalizable key into normalize_data.
continue
v = sample[k]
if isinstance(v, torch.Tensor):
v = v.detach().cpu().numpy()
raw[keyname] = v.detach().to(torch.float32)
else:
v = np.asarray(v)
v = np.asarray(v, dtype=np.float32)
# Use the most recent timestep if proprio carries a time axis.
raw[keyname] = torch.as_tensor(np.asarray(v), dtype=torch.float32)
keyname_to_zarr[keyname] = k
if not raw:
return None

# Lazy lookup: norm_stats are populated in-place on the shared
# data_schematic by trainHydra.infer_norm_from_dataset before
# trainer.fit, so by the time the collate runs, stats must exist.
# Hard-fail rather than silently skipping — silent skip would emit
# bins computed on raw, un-normalized values.
if data_schematic is None:
raise ValueError(
"build_tokenized_collate: proprio=True requires data_schematic, got None. "
"Pass data_schematic into MultiDataModuleWrapper."
)
if getattr(data_schematic, "norm_stats", None) is None:
raise ValueError(
"data_schematic.norm_stats is not populated. "
"Call data_schematic.infer_norm_from_dataset(...) before the first batch."
)
if emb_id is None:
raise ValueError(
"Sample is missing 'embodiment' id; cannot look up norm stats for proprio discretization."
)
# data_schematic.normalize_data raises ValueError if stats are missing
# for a specific (embodiment, key) pair — see DataSchematic.normalize_data.
normed = data_schematic.normalize_data(raw, emb_id)

# Iterate in raw insertion order (which mirrors proprio_keys ordering,
# restricted to keys that translated successfully and were present).
parts = []
for keyname in raw:
v = normed[keyname].detach().cpu().numpy().astype(np.float32)
while v.ndim > 1:
v = v[-1]
parts.append(v.reshape(-1))
if not parts:
return None
state = np.concatenate(parts, axis=-1)
state = np.clip(state, -1.0, 1.0)
bins = np.digitize(state, bins=state_bin_edges) - 1
Expand Down
5 changes: 4 additions & 1 deletion egomimic/trainHydra.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,10 @@ def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
"MultiDataModuleWrapper" in cfg.data._target_
), "cfg.data._target_ must be 'MultiDataModuleWrapper'"
datamodule: LightningDataModule = hydra.utils.instantiate(
cfg.data, train_datasets=train_datasets, valid_datasets=valid_datasets
cfg.data,
train_datasets=train_datasets,
valid_datasets=valid_datasets,
data_schematic=data_schematic,
)

# Stats-only MultiDataset (no graph of its own; explicitly populated from
Expand Down
Loading