diff --git a/egomimic/algo/hpt.py b/egomimic/algo/hpt.py index 46f7150c2..8adfc901c 100644 --- a/egomimic/algo/hpt.py +++ b/egomimic/algo/hpt.py @@ -413,7 +413,11 @@ def resume_from_depth(self, block_outputs, depth): blocks = self.trunk["trunk"].blocks for blk in list(blocks)[depth:]: - cut_tokens = blk(cut_tokens, attn_mask=None) + cut_tokens = blk( + cut_tokens, + attn_mask=None, + key_padding_mask=getattr(self, "_ricl_key_padding_mask", None), + ) if self.trunk["trunk"].post_transformer_layer is not None: cut_tokens = self.trunk["trunk"].post_transformer_layer(cut_tokens) @@ -485,7 +489,13 @@ def forward_features(self, domain, data): trunk_tokens = self.preprocess_tokens(domain, stem_tokens) if not self.no_trunk: - trunk_tokens, block_outputs = self.trunk["trunk"](trunk_tokens) + # RICL subclasses stash a per-sample (B, L) key_padding_mask in + # stem_process to hide invalid retrieved-demo token spans; absent / + # None for plain HPT (then the trunk behaves exactly as before). + trunk_tokens, block_outputs = self.trunk["trunk"]( + trunk_tokens, + key_padding_mask=getattr(self, "_ricl_key_padding_mask", None), + ) proc_tokens = self.postprocess_tokens(trunk_tokens) return proc_tokens, block_outputs @@ -784,6 +794,10 @@ def load_pretrained(self, checkpoint_path): class HPT(Algo): """ """ + # Inner model class; RICL subclasses override with an HPTModel subclass that + # adds retrieved-demo stems + masking. Plain HPT is unaffected. + model_cls = HPTModel + def __init__( self, norm_stats, @@ -850,7 +864,7 @@ def __init__( self.is_6dof = kwargs.get("6dof", False) self.kinematics_solver = kwargs.get("kinematics_solver", None) - model = HPTModel(**trunk) + model = self.model_cls(**trunk) model.auxiliary_ac_keys = self.auxiliary_ac_keys self.multitask = kwargs.get("multitask", False) diff --git a/egomimic/algo/hpt_ricl.py b/egomimic/algo/hpt_ricl.py new file mode 100644 index 000000000..8a0940f47 --- /dev/null +++ b/egomimic/algo/hpt_ricl.py @@ -0,0 +1,244 @@ +"""HptRicl: HPT + retrieval-based in-context learning (RICL). + +Ports the RICL idea from :class:`egomimic.algo.pi_ricl.PIRicl` (pi0.5) onto the +stem/trunk/head HPT model. Unlike pi0.5 (a unified VLM where retrieved demos are +just extra images + discretized prompt text), HPT encodes every modality with a +*separate stem*, so each retrieved demonstration's ``(image, state, action-chunk)`` +is turned into tokens by **dedicated retrieved stems**: + +- ``ricl_image`` : shares the query ResNet backbone (generic vision; retrieval is + DINOv2-similarity based) but has its own cross-attention pooling head. +- ``ricl_state`` : a separate ``MLPPolicyStem`` (32-D shared-space input). +- ``ricl_action`` : a new :class:`~egomimic.models.hpt_nets.ActionChunkStem` + (HPT otherwise never encodes an action chunk as input). + +Fusion is **flat concatenation** (no trunk architecture surgery): the retrieved +token blocks are appended after the query stem tokens, tagged with learned +demo-index + modality embeddings, and invalid demos (``ricl_retrieved_mask``) are +both zeroed at the stem output and hidden from attention via a per-sample +``key_padding_mask`` threaded through the trunk. With no ``ricl_*`` keys present +(the k=0 zero-context floor), behaviour is identical to plain HPT. + +The retrieval data pipeline (DINOv2 kNN, bank normalization + 32-D conversion, +the 5 ``ricl_*`` collate keys) is model-agnostic and reused as-is from +``egomimic/ricl`` via ``RiclDataModuleWrapper``. +""" + +from __future__ import annotations + +import logging + +import torch +from overrides import override + +from egomimic.algo.hpt import HPT, HPTModel +from egomimic.rldb.embodiment.embodiment import get_embodiment_id + +logger = logging.getLogger(__name__) + +# Keys the RICL collate attaches per query sample (see egomimic/ricl + the P3 +# collate). Identical contract to egomimic.algo.pi_ricl.RICL_BATCH_KEYS. +RICL_BATCH_KEYS = ( + "ricl_retrieved_images", # (B, k, C, H, W) + "ricl_retrieved_state", # (B, k, Ds) normalized to the bank convention + "ricl_retrieved_action", # (B, k, Ha, Da) normalized, converted to shared space + "ricl_retrieved_mask", # (B, k) bool, valid neighbor (handles < k) + "ricl_retrieved_dist", # (B, k) float, kNN distances (unused in v1) +) + +# Order in which retrieved modalities are encoded / concatenated. The index into +# this tuple selects the per-modality learned embedding row. +_RICL_MODALITIES = ("ricl_image", "ricl_state", "ricl_action") + + +class HptRiclModel(HPTModel): + """HPTModel that also encodes retrieved in-context demonstrations. + + Constructed via ``HPT.model_cls`` with the same ``trunk`` kwargs as HPTModel + plus ``num_retrieved_observations`` (k) and ``ricl_image_encoder_key``. + """ + + def __init__( + self, + *args, + num_retrieved_observations: int = 4, + ricl_image_encoder_key: str = "front_img_1", + **kwargs, + ): + super().__init__(*args, **kwargs) + self.num_retrieved = int(num_retrieved_observations) + self.ricl_image_encoder_key = ricl_image_encoder_key + D = self.embed_dim + # Learned tags so the trunk can tell which demo / modality a token came + # from. Zero-init => early training stays close to plain-HPT behaviour. + self.ricl_demo_embed = torch.nn.Parameter(torch.zeros(self.num_retrieved, D)) + self.ricl_modality_embed = torch.nn.Parameter( + torch.zeros(len(_RICL_MODALITIES), D) + ) + # Stashed per-forward; read by HPTModel.forward_features (defaults to None + # there, so plain HPT is unaffected). + self._ricl_key_padding_mask = None + + # ------------------------------------------------------------------ + # Don't add the singleton instance axis to retrieved tensors: they carry a + # leading k dim and are handled wholesale in stem_process. (Base only special- + # cases keys containing "state"; "ricl_state" would match.) + # ------------------------------------------------------------------ + @override + def preprocess_states(self, domain, data): + for key in data: + if "state" in key and not key.startswith("ricl_"): + data[key] = data[key][:, :, None] + return data + + # ------------------------------------------------------------------ + # Encode query modalities (super), then the retrieved demos. + # ------------------------------------------------------------------ + @override + def stem_process(self, domain, data): + ricl = {} + for key in (*_RICL_MODALITIES, "ricl_retrieved_mask"): + if key in data: + ricl[key] = data.pop(key) + + # Query stems exactly as plain HPT (ricl keys removed, so the base loop's + # `if modality not in data: continue` skips the retrieved stems). + feats, feat_dict = super().stem_process(domain, data) + + self._ricl_key_padding_mask = None + if "ricl_image" in ricl: + ricl_tokens, kpm = self._process_retrieved(domain, ricl, feats) + feats.append(ricl_tokens) + feat_dict["ricl"] = ricl_tokens + self._ricl_key_padding_mask = kpm + + return feats, feat_dict + + # ------------------------------------------------------------------ + def _encode_ricl_image(self, domain, images): + """images: (B, k, C, H, W) -> (B*k, t_img, D) via shared ResNet + own head.""" + B, k = images.shape[0], images.shape[1] + x = images.reshape(B * k, *images.shape[2:]) # (B*k, C, H, W) + x = x.unsqueeze(1).unsqueeze(1) # (B*k, 1, 1, C, H, W) — ResNet input shape + feat = self.encoders[self.ricl_image_encoder_key](x) # (B*k, M, D) + stem = self.stems[f"{domain}_ricl_image"] + return stem.compute_latent(feat) # (B*k, t_img, D) + + def _process_retrieved(self, domain, ricl, query_feats): + """Encode the 3 retrieved modalities, tag + mask them, and build the + trunk key_padding_mask. Returns (ricl_tokens (B, sum_k_t, D), kpm (B, L)).""" + images = ricl["ricl_image"] # (B, k, C, H, W) + states = ricl["ricl_state"] # (B, k, Ds) + actions = ricl["ricl_action"] # (B, k, Ha, Da) + mask = ricl.get("ricl_retrieved_mask") # (B, k) bool, True = valid + + B, k = images.shape[0], images.shape[1] + assert k <= self.ricl_demo_embed.shape[0], ( + f"got k={k} retrieved demos but ricl_demo_embed sized for " + f"{self.ricl_demo_embed.shape[0]} (set trunk.num_retrieved_observations)" + ) + + per_modality_tokens = [ + self._encode_ricl_image(domain, images), # (B*k, t, D) + self.stems[f"{domain}_ricl_state"].compute_latent( + states.reshape(B * k, 1, states.shape[-1]) + ), + self.stems[f"{domain}_ricl_action"].compute_latent( + actions.reshape(B * k, actions.shape[2], actions.shape[3]) + ), + ] + + if mask is not None: + valid = mask.bool().view(B, k, 1, 1) # True = valid + blocks, mask_blocks = [], [] + for m_idx, tok in enumerate(per_modality_tokens): + t, D = tok.shape[1], tok.shape[-1] + tok = tok.view(B, k, t, D) + tok = tok + self.ricl_demo_embed[:k].view(1, k, 1, D) + tok = tok + self.ricl_modality_embed[m_idx].view(1, 1, 1, D) + if mask is not None: + tok = tok * valid.to(tok.dtype) # zero invalid demo tokens + # key_padding_mask convention: True = ignore. + mask_blocks.append( + (~mask.bool()).view(B, k, 1).expand(B, k, t).reshape(B, k * t) + ) + else: + mask_blocks.append( + torch.zeros(B, k * t, dtype=torch.bool, device=tok.device) + ) + blocks.append(tok.reshape(B, k * t, D)) + + ricl_tokens = torch.cat(blocks, dim=1) # (B, sum_k_t, D) + + # Prefix = learnable action queries (if any) + all query stem tokens, all + # valid (never ignored). Order matches preprocess_tokens: + # [action_tokens, *query_feats, ricl_tokens]. + n_query = sum(int(f.shape[1]) for f in query_feats) + n_prefix = n_query + ( + self.action_horizon if self.token_postprocessing == "action_token" else 0 + ) + prefix = torch.zeros(B, n_prefix, dtype=torch.bool, device=ricl_tokens.device) + kpm = torch.cat([prefix, *mask_blocks], dim=1) # (B, L) + return ricl_tokens, kpm + + +class HptRicl(HPT): + """HPT with flat-concatenated retrieved in-context demonstrations.""" + + model_cls = HptRiclModel + + def __init__(self, *args, num_retrieved_observations: int = 4, **kwargs): + super().__init__(*args, **kwargs) + self.num_retrieved_observations = int(num_retrieved_observations) + logger.info("HptRicl: k=%d retrieved demos", self.num_retrieved_observations) + + # ------------------------------------------------------------------ + # Carry ricl_* keys through process_batch_for_training (mirror PIRicl). + # ------------------------------------------------------------------ + @override + def process_batch_for_training(self, batch): + processed = super().process_batch_for_training(batch) + for embodiment_name, _batch in batch.items(): + emb_id = get_embodiment_id(embodiment_name) + if emb_id not in processed: + continue + # The base loop maps unknown keys (the ricl_* tensors) to keyname None + # via zarr_key_to_keyname; drop that garbage slot, then re-add cleanly. + processed[emb_id].pop(None, None) + for key in RICL_BATCH_KEYS: + if key in _batch: + val = _batch[key] + if isinstance(val, torch.Tensor): + val = val.to(self.device) + if val.is_floating_point(): + val = val.float() + processed[emb_id][key] = val + return processed + + # ------------------------------------------------------------------ + # Route retrieved tensors into the data dict under the new stem modality keys. + # ------------------------------------------------------------------ + @override + def _robomimic_to_hpt_data( + self, batch, cam_keys, proprio_keys, lang_keys, ac_key, aux_ac_keys=[] + ): + data = super()._robomimic_to_hpt_data( + batch, cam_keys, proprio_keys, lang_keys, ac_key, aux_ac_keys + ) + if "ricl_retrieved_images" not in batch: + return data # zero-context (k=0) -> identical to base HPT + + imgs = batch["ricl_retrieved_images"] # (B, k, C, H, W) + B, k = imgs.shape[0], imgs.shape[1] + flat = imgs.reshape(B * k, *imgs.shape[2:]) + # Normalize retrieved frames the same way as query frames (ImageNet stats + # for the pretrained ResNet); invalid (zero) demos are masked downstream. + if self.nets.training and self.train_image_augs is not None: + flat = self.train_image_augs(flat) + elif self.eval_image_augs is not None: + flat = self.eval_image_augs(flat) + data["ricl_image"] = flat.reshape(B, k, *flat.shape[1:]) + data["ricl_state"] = batch["ricl_retrieved_state"] + data["ricl_action"] = batch["ricl_retrieved_action"] + data["ricl_retrieved_mask"] = batch["ricl_retrieved_mask"] + return data diff --git a/egomimic/eval/hpt_ricl_eval.py b/egomimic/eval/hpt_ricl_eval.py new file mode 100644 index 000000000..e4091cee6 --- /dev/null +++ b/egomimic/eval/hpt_ricl_eval.py @@ -0,0 +1,172 @@ +"""HptRiclEval: compare retrieval-conditioned HPT vs the zero-context floor. + +The HPT analog of :class:`egomimic.eval.pi_ricl_eval.PIRiclEval`. For each +validation batch it runs the model on the *same* eva query frames twice: + - retrieval: the full batch (with ``ricl_*`` keys) -> HptRicl encodes k retrieved + in-context demos through the retrieved stems, + - floor: the same batch with ``ricl_*`` stripped -> HptRicl == plain HPT (k=0). +It reports sampled-action Cartesian MSE / L1 / gripper accuracy + BC loss for both +conditions and the deltas (``RICL/delta_*``, ``RICL/retrieval_helps``). Retrieval +"works" when loss / MSE drop vs the floor; for the eva->eva oracle this bounds the +ceiling. + +**Clean floor (strip before processing).** This evaluator receives the *raw* batch +(``wants_raw_batch = True``) and builds each condition itself. The floor strips +``ricl_*`` from the **raw** batch and *then* processes it, so +``process_batch_for_training`` carries no retrieved tensors and +``_robomimic_to_hpt_data`` adds no retrieved stems -> a genuine k=0 floor (stripping +after processing would leave the retrieved tokens in the model input). + +v1 is intentionally lean (retrieval vs floor). The model-agnostic random-demo +control (``compute_random``, via :func:`shuffle_ricl_keys`) and the cheap +paired-seed flow-loss proxy (``compute_flow_loss``) are wired but default OFF; +flip them on to match the fuller PIRiclEval ablation set. +""" + +from __future__ import annotations + +import torch + +from egomimic.eval.eval_hpt import HPTEvalVideo +from egomimic.ricl import metrics as M +from egomimic.rldb.embodiment.embodiment import get_embodiment + + +class HptRiclEval(HPTEvalVideo): + # Receive the raw (un-processed) batch so the floor is built by stripping + # ricl_* before process_batch_for_training (see module docstring). + wants_raw_batch = True + + def __init__( + self, + *args, + compute_floor: bool = True, + compute_flow_loss: bool = False, + compute_random: bool = False, + n_flow_samples: int = 4, + seed_base: int = 1234, + gripper_threshold: float = 0.0, + gripper_indices=(6, 13), + **kwargs, + ): + super().__init__(*args, **kwargs) + self.compute_floor = compute_floor + self.compute_flow_loss = compute_flow_loss + self.compute_random = compute_random + self.n_flow_samples = max(1, int(n_flow_samples)) + self.seed_base = seed_base + self.gripper_threshold = gripper_threshold + self.gripper_indices = tuple(gripper_indices) + self._batch_idx = 0 + + def on_validation_step(self, batch, batch_idx, dataloader_idx=0): + # Shared per-batch seed so the optional flow-loss conditions see identical + # noise/time (the delta then isolates the conditioning, not RNG). + self._batch_idx = batch_idx + super().on_validation_step(batch, batch_idx, dataloader_idx) + + # ------------------------------------------------------------------ + # One condition: sampled forward_eval + native-space scalar metrics. + # (HPT.forward_eval already unnormalizes its predictions; unnormalize the GT.) + # ------------------------------------------------------------------ + def _eval_condition(self, proc_batch, make_viz: bool): + algo = self.model + preds = algo.forward_eval(proc_batch) + metrics, viz = {}, {} + for embodiment_id, _batch in proc_batch.items(): + _batch = algo.norm_stats.unnormalize(_batch, embodiment_id) + name = get_embodiment(embodiment_id).lower() + ac_key = algo.ac_keys[embodiment_id] + pred_key = f"{name}_{ac_key}" + loss_key = f"{name}_loss" + + if loss_key in preds: + metrics[f"{name}_loss"] = float(preds[loss_key]) + if pred_key in preds: + p = preds[pred_key].cpu() + g = _batch[ac_key].cpu() + metrics[f"{name}_paired_mse"] = M.cartesian_mse(p, g) + metrics[f"{name}_final_mse"] = M.cartesian_mse(p[:, -1], g[:, -1]) + metrics[f"{name}_paired_l1"] = M.cartesian_l1(p, g) + ga = M.gripper_accuracy( + p, g, self.gripper_indices, self.gripper_threshold + ) + if ga == ga: # not NaN (i.e. the 14-D bimanual layout) + metrics[f"{name}_gripper_acc"] = ga + if make_viz and self.viz_func is not None: + # Native-frame viz (the base's cam-frame revert is omitted here, + # mirroring PIRiclEval — the scalar deltas are the headline). + viz[embodiment_id] = self._visualize_preds(preds, _batch) + return metrics, viz + + def _flow_loss(self, proc_batch, seed: int) -> dict: + """Paired-seed flow-matching (training) loss, averaged over n draws. Cheap + (no sampling) and, scored under shared seeds, isolates the conditioning.""" + algo = self.model + accum: dict[str, float] = {} + with torch.no_grad(): + for s in range(self.n_flow_samples): + torch.manual_seed(seed + s * 100003) + preds = algo.forward_training(proc_batch) + for k, v in preds.items(): + if k.endswith("_loss"): + nk = k.replace("_loss", "_flow_loss") + accum[nk] = accum.get(nk, 0.0) + float(v) + return {k: v / self.n_flow_samples for k, v in accum.items()} + + def compute_metrics_and_viz(self, raw_batch): + algo = self.model + seed = self.seed_base + self._batch_idx + + # retrieval (sampled): process the full raw batch -> retrieved demos encoded. + proc_ret = algo.process_batch_for_training(raw_batch) + ret_metrics, viz = self._eval_condition(proc_ret, make_viz=True) + out = {f"RICL/retrieval_{k}": v for k, v in ret_metrics.items()} + + floor_metrics = None + if self.compute_floor: + # strip ricl_* from the RAW batch, then process -> genuine k=0 floor. + proc_flr = algo.process_batch_for_training(M.strip_ricl_keys(raw_batch)) + floor_metrics, _ = self._eval_condition(proc_flr, make_viz=False) + out.update({f"RICL/floor_{k}": v for k, v in floor_metrics.items()}) + out.update( + { + f"RICL/{k}": v + for k, v in M.compare_to_floor(ret_metrics, floor_metrics).items() + } + ) + + # Optional cheap paired-seed flow-loss proxy (off by default). + if self.compute_flow_loss: + ret_flow = self._flow_loss(proc_ret, seed) + out.update({f"RICL/retrieval_{k}": v for k, v in ret_flow.items()}) + if self.compute_floor: + floor_flow = self._flow_loss( + algo.process_batch_for_training(M.strip_ricl_keys(raw_batch)), seed + ) + out.update({f"RICL/floor_{k}": v for k, v in floor_flow.items()}) + out.update( + { + f"RICL/flow_{k}": v + for k, v in M.compare_to_floor(ret_flow, floor_flow).items() + } + ) + + # Optional random-demo control (off by default): each query gets another + # query's k demos; retrieval < random asks whether *similarity* matters. + if self.compute_random: + shuffled = M.shuffle_ricl_keys(raw_batch, seed) + if shuffled is not None: + proc_rnd = algo.process_batch_for_training(shuffled) + rnd_metrics, _ = self._eval_condition(proc_rnd, make_viz=False) + out.update({f"RICL/random_{k}": v for k, v in rnd_metrics.items()}) + cmp_r = M.compare_to_floor(ret_metrics, rnd_metrics) + out["RICL/beats_random"] = float(cmp_r["retrieval_helps"]) + out["RICL/improvement_vs_random"] = cmp_r["mean_improvement"] + + # Primary scalar the trainer logs. + for k, v in ret_metrics.items(): + if k.endswith("_loss"): + out["Valid/action_loss"] = v + break + return out, viz diff --git a/egomimic/hydra_configs/data/cotrain_hpt_ricl_pickplace.yaml b/egomimic/hydra_configs/data/cotrain_hpt_ricl_pickplace.yaml new file mode 100644 index 000000000..aaf6aabe1 --- /dev/null +++ b/egomimic/hydra_configs/data/cotrain_hpt_ricl_pickplace.yaml @@ -0,0 +1,34 @@ +# HPT RICL training data (eva->eva oracle). Reuses the pi RICL pickplace recipe +# verbatim (RiclDataModuleWrapper + query train/valid segment splits + eva bank +# wiring + retrieval cache) and changes only what HPT needs: +# - query camera keymap cartesian_pi -> cartesian, so the query image batch key +# ends in `front_img_1` (HPT stem name) instead of pi's `base_0_rgb`. Proprio/ +# action keys are identical across modes. The bank loader stays cartesian_pi +# (base_0_rgb); retrieved frames are repackaged into model-agnostic ricl_* +# keys regardless of the bank's internal naming. +# - retrieved (state, action) widths -> 32-D shared space (the retrieved stems +# are separate/learned; keeps the model config identical for a later aria->eva +# run). The pi config truncates these to eva's 14-D for prompt discretization. +# +# Provide at run time: retrieval_cache_dir (pickplace train cache), and the eva +# norm stats via norm_stats.precomputed_norm_path (bank_norm_path mirrors it). +defaults: + - cotrain_pi_ricl_pickplace + - _self_ + +train_datasets: + eva_bimanual: + resolver: + key_map: + keymap_mode: cartesian + +# Feed retrieved encoders the full 32-D shared space (overrides the pi config's +# 14-D discretization width). ZarrBankFrameProvider's RobotBimanualCartesianEuler +# already emits 32-D; this keeps it untruncated. +state_dim: 32 +action_dim: 32 + +# eva->eva oracle: query and bank are the same embodiment, so the bank's norm +# stats == the query's precomputed stats. (For a cross-embodiment aria->eva run, +# set this to the BANK's own stats file instead.) +bank_norm_path: ${norm_stats.precomputed_norm_path} diff --git a/egomimic/hydra_configs/evaluator/eval_hpt_ricl.yaml b/egomimic/hydra_configs/evaluator/eval_hpt_ricl.yaml new file mode 100644 index 000000000..7f98fdc6a --- /dev/null +++ b/egomimic/hydra_configs/evaluator/eval_hpt_ricl.yaml @@ -0,0 +1,16 @@ +# HPT RICL evaluator: retrieval (k) vs zero-context floor (k=0) on the eva query +# set. Same cartesian viz as eval_hpt; metrics report the comparison +# (RICL/retrieval_*, RICL/floor_*, RICL/delta_*, RICL/retrieval_helps). +# v1 = retrieval vs floor; flip compute_random / compute_flow_loss on for the +# fuller ablation set. +defaults: + - viz@viz_func: cartesian + - _self_ + +_target_: egomimic.eval.hpt_ricl_eval.HptRiclEval + +compute_floor: true +compute_flow_loss: false +compute_random: false +gripper_threshold: 0.0 +gripper_indices: [6, 13] diff --git a/egomimic/hydra_configs/model/hpt_ricl_pickplace_qwen.yaml b/egomimic/hydra_configs/model/hpt_ricl_pickplace_qwen.yaml new file mode 100644 index 000000000..ae4256ad6 --- /dev/null +++ b/egomimic/hydra_configs/model/hpt_ricl_pickplace_qwen.yaml @@ -0,0 +1,206 @@ +# HptRicl (eva->eva oracle) = HPT + Qwen text stem + retrieval-based in-context +# learning. Forks hpt_bc_pickplace_qwen_pooled: same trunk/query stems/Qwen +# annotation stem + FMPolicy head, plus THREE dedicated retrieved-demo stems +# (ricl_image / ricl_state / ricl_action) whose tokens are flat-concatenated into +# the trunk (no trunk architecture surgery; k=0 == plain HPT). See +# egomimic/algo/hpt_ricl.py. +# +# Pair with: data=cotrain_hpt_ricl_pickplace, evaluator=eval_hpt_ricl. +# query + bank = eva_bimanual pick/place segment groups, 15-step chunks. +# Retrieved (state, action) are fed in the 32-D shared space (data sets +# state_dim/action_dim=32); query state/action stay 14-D (actions_cartesian). +_target_: egomimic.pl_utils.pl_model.ModelWrapper +robomimic_model: + _target_: egomimic.algo.hpt_ricl.HptRicl + + # k retrieved in-context demos per query (also set in trunk for the model's + # demo-index embedding table; keep the two in sync). + num_retrieved_observations: 4 + + camera_transforms: + eva_bimanual: + _target_: egomimic.utils.egomimicUtils.CameraTransforms + intrinsics_key: "base" + extrinsics_key: "x5Dec13_2" + + diffusion: true + 6dof: true + + ac_keys: + eva_bimanual: "actions_cartesian" + + annotation_key: "annotations" + annotation_sampling_mode: "random" + annotation_modality: "annotation" + default_prompt: "" + + trunk: + embed_dim: 256 + num_blocks: 16 + num_heads: 8 + token_postprocessing: "action_token" + observation_horizon: 1 + action_horizon: 64 # number of learnable action-query (cond) tokens + no_trunk: false + use_domain_embedding: true + drop_path: 0.1 + weight_init_style: "pytorch" + # RICL: size of the learned demo-index embedding table (== k above). + num_retrieved_observations: 4 + + multitask: false + pretrained: false + pretrained_checkpoint: "" + reverse_kl_samples: 8 + + domains: ["eva_bimanual"] + shared_obs_keys: ["front_img_1", "annotation"] + + shared_stem_specs: + front_img_1: + _target_: egomimic.models.hpt_nets.MLPPolicyStem + input_dim: 256 + output_dim: 256 + widths: [256] + specs: + random_horizon_masking: false + cross_attn: + crossattn_latent: 16 + crossattn_heads: 8 + crossattn_dim_head: 64 + crossattn_modality_dropout: 0.1 + modality_embed_dim: 256 + annotation: + _target_: egomimic.models.hpt_nets.QwenPooledEncoder + model_name: "Qwen/Qwen3-Embedding-0.6B" + max_length: 128 + freeze: true + dtype: "float16" + output_dim: 256 + specs: + random_horizon_masking: false + cross_attn: + crossattn_latent: 16 + crossattn_heads: 8 + crossattn_dim_head: 64 + crossattn_modality_dropout: 0.1 + modality_embed_dim: 256 + + stem_specs: + eva_bimanual: + state_ee_pose: # query proprio (14-D) + _target_: egomimic.models.hpt_nets.MLPPolicyStem + input_dim: 14 + output_dim: 256 + widths: [256] + specs: + random_horizon_masking: false + cross_attn: + crossattn_latent: 16 + crossattn_heads: 8 + crossattn_dim_head: 64 + crossattn_modality_dropout: 0.1 + modality_embed_dim: 256 + ricl_image: # retrieved-demo image pooling head + # NOTE: no encoder_specs entry -> HptRiclModel runs the SHARED front_img_1 + # ResNet on retrieved frames, then this separate cross-attn head pools. + _target_: egomimic.models.hpt_nets.MLPPolicyStem + input_dim: 256 + output_dim: 256 + widths: [256] + specs: + random_horizon_masking: false + cross_attn: + crossattn_latent: 16 + crossattn_heads: 8 + crossattn_dim_head: 64 + crossattn_modality_dropout: 0.1 + modality_embed_dim: 256 + ricl_state: # retrieved-demo state (32-D shared space) + _target_: egomimic.models.hpt_nets.MLPPolicyStem + input_dim: 32 + output_dim: 256 + widths: [256] + specs: + random_horizon_masking: false + cross_attn: + crossattn_latent: 16 + crossattn_heads: 8 + crossattn_dim_head: 64 + crossattn_modality_dropout: 0.1 + modality_embed_dim: 256 + ricl_action: # retrieved-demo action chunk (new stem) + _target_: egomimic.models.hpt_nets.ActionChunkStem + action_dim: 32 + output_dim: 256 + hidden_dim: 256 + num_layers: 2 + kernel_size: 3 + temporal_encoder: conv # 'transformer' is a drop-in alternative + specs: + random_horizon_masking: false + cross_attn: + crossattn_latent: 16 + crossattn_heads: 8 + crossattn_dim_head: 64 + crossattn_modality_dropout: 0.1 + modality_embed_dim: 256 + + head_specs: + eva_bimanual: + _target_: egomimic.models.fm_policy.FMPolicy + action_horizon: 15 # predict a 15-step chunk (matches data chunk_length) + num_inference_steps: 50 + pooling: null + time_dist: "beta" + infer_ac_dims: + eva_bimanual: 14 + model: + _target_: egomimic.models.denoising_nets.CrossTransformer + nblocks: 6 + cond_dim: 256 + hidden_dim: 128 + act_dim: 14 + act_seq: 15 # == FMPolicy.action_horizon / data chunk_length + n_heads: 4 + dropout: 0.1 + mlp_layers: 4 + mlp_ratio: 4 + + encoder_specs: + front_img_1: # shared by query + retrieved image stems + _target_: egomimic.models.hpt_nets.ResNet + output_dim: 256 + num_of_copy: 1 + + train_image_augs: + _target_: torchvision.transforms.Compose + transforms: + - _target_: torchvision.transforms.ColorJitter + brightness: 0.1 + contrast: 0.1 + saturation: 0.1 + hue: 0.05 + - _target_: torchvision.transforms.Normalize + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + eval_image_augs: + _target_: torchvision.transforms.Compose + transforms: + - _target_: torchvision.transforms.Normalize + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + +optimizer: + _target_: torch.optim.AdamW + _partial_: true + lr: 2e-4 + weight_decay: 0.0001 + +scheduler: + _target_: egomimic.utils.scheduler_utils.warmup_then_cosine + _partial_: true + warmup_epochs: 100 + total_epochs: 2000 + eta_min: 1.0e-5 +scheduler_interval: epoch diff --git a/egomimic/hydra_configs/train_zarr_hpt_ricl.yaml b/egomimic/hydra_configs/train_zarr_hpt_ricl.yaml new file mode 100644 index 000000000..7ebea637f --- /dev/null +++ b/egomimic/hydra_configs/train_zarr_hpt_ricl.yaml @@ -0,0 +1,48 @@ +# Top-level config: HPT + RICL (eva->eva oracle) on eva pick/place segments. +# Forks train_zarr_cartesian.yaml, swapping in the RICL model / data / evaluator. +# +# Run-time overrides required: +# data.retrieval_cache_dir= +# norm_stats.precomputed_norm_path= # == bank stats +# (data.bank_norm_path mirrors norm_stats.precomputed_norm_path for eva->eva.) +defaults: + - model: hpt_ricl_pickplace_qwen + - paths: default + - trainer: ddp + - debug: null + - logger: wandb + - data: cotrain_hpt_ricl_pickplace + - callbacks: checkpoints + - evaluator: eval_hpt_ricl + - override hydra/launcher: submitit + - _self_ + +name: test +description: test +ckpt_path: null +mode: train + +train_viz_evaluator: null + +hydra: + run: + dir: ./logs/${name}/${description}_${now:%Y-%m-%d_%H-%M-%S} + sweep: + dir: ./logs/${name}/${description}_${now:%Y-%m-%d_%H-%M-%S} + +launch_params: + gpus_per_node: 4 + nodes: 1 + +seed: 42 + +# Normalization and norm-stat cache. precomputed_norm_path MUST be the eva (bank) +# stats — RICL requires the query and bank stats to match so retrieved blocks land +# in the same normalized space as the query (see RiclDataModuleWrapper). +norm_stats: + norm_mode: quantile + sample_frac: 0.2 + num_workers: 6 + save_cache_dir: ${hydra:runtime.output_dir} + precomputed_norm_path: /storage/project/r-dxu345-0/rco3/EgoVerse/logs/norm_stats +reject_outliers: true diff --git a/egomimic/models/hpt_nets.py b/egomimic/models/hpt_nets.py index e4c50e426..288677811 100644 --- a/egomimic/models/hpt_nets.py +++ b/egomimic/models/hpt_nets.py @@ -13,6 +13,7 @@ from torchvision import transforms from transformers import T5Model, T5Tokenizer +from egomimic.models.denoising_nets import Conv1dBlock from egomimic.utils.egomimicUtils import get_sinusoid_encoding_table @@ -187,9 +188,9 @@ def __init__( ): super().__init__() - assert not isinstance(attn_target, nn.Module), ( - "attn_target should be a Callable. Otherwise attn_target is shared across blocks!" - ) + assert not isinstance( + attn_target, nn.Module + ), "attn_target should be a Callable. Otherwise attn_target is shared across blocks!" self.attn = attn_target() if drop_path > 0.0: self.drop_path = DropPath(drop_path) @@ -226,14 +227,21 @@ def __init__( requires_grad=True, ) - def forward(self, x: torch.Tensor, attn_mask: torch.Tensor): + def forward( + self, + x: torch.Tensor, + attn_mask: torch.Tensor = None, + key_padding_mask: torch.Tensor = None, + ): if self.layer_scale_type is None: - x = x + self.drop_path(self.attn(self.norm_1(x), attn_mask)) + x = x + self.drop_path( + self.attn(self.norm_1(x), attn_mask, key_padding_mask) + ) x = x + self.drop_path(self.mlp(self.norm_2(x))) else: x = ( x - + self.drop_path(self.attn(self.norm_1(x), attn_mask)) + + self.drop_path(self.attn(self.norm_1(x), attn_mask, key_padding_mask)) * self.layer_scale_gamma1 ) x = x + self.drop_path(self.mlp(self.norm_2(x))) * self.layer_scale_gamma2 @@ -244,8 +252,23 @@ def forward(self, x: torch.Tensor, attn_mask: torch.Tensor): class MultiheadAttention(nn.MultiheadAttention): - def forward(self, x: torch.Tensor, attn_mask: torch.Tensor): - return super().forward(x, x, x, need_weights=False, attn_mask=attn_mask)[0] + def forward( + self, + x: torch.Tensor, + attn_mask: torch.Tensor = None, + key_padding_mask: torch.Tensor = None, + ): + # key_padding_mask: (B, L) bool, True = ignore that key position (per-sample). + # x is (L, B, D) here (trunk runs batch_first=False via the b l d -> l b d + # pre-layer), so key_padding_mask shape (B, L) is correct. + return super().forward( + x, + x, + x, + need_weights=False, + attn_mask=attn_mask, + key_padding_mask=key_padding_mask, + )[0] class SimpleTransformer(nn.Module): @@ -308,6 +331,7 @@ def forward( self, tokens: torch.Tensor, attn_mask: torch.Tensor = None, + key_padding_mask: torch.Tensor = None, use_checkpoint: bool = False, checkpoint_every_n: int = 1, checkpoint_blk_ids: Optional[List[int]] = None, @@ -315,7 +339,8 @@ def forward( """ Inputs - tokens: data of shape N x L x D (or L x N x D depending on the attention implementation) - - attn: mask of shape L x L + - attn_mask: mask of shape L x L (shared across the batch) + - key_padding_mask: mask of shape N x L, True = ignore that key (per-sample) Output - x: data of shape N x L x D (or L x N x D depending on the attention implementation) @@ -334,10 +359,12 @@ def forward( for blk_id, blk in enumerate(self.blocks): if use_checkpoint and blk_id in checkpoint_blk_ids: tokens = checkpoint.checkpoint( - blk, tokens, attn_mask, use_reentrant=False + blk, tokens, attn_mask, key_padding_mask, use_reentrant=False ) else: - tokens = blk(tokens, attn_mask=attn_mask) + tokens = blk( + tokens, attn_mask=attn_mask, key_padding_mask=key_padding_mask + ) block_outputs.append(tokens) if self.post_transformer_layer: tokens = self.post_transformer_layer(tokens) @@ -658,6 +685,74 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return feat +class ActionChunkStem(PolicyStem): + """Encoder for a *retrieved* action chunk used in RICL. + + Plain HPT never ingests an action chunk as input (actions only appear as + learnable query tokens + head outputs), so this is a dedicated stem. It maps + one chunk ``(B, Ha, Da)`` to a feature sequence ``(B, Ha, output_dim)`` that + the inherited cross-attention pool (:meth:`PolicyStem.compute_latent`) turns + into a fixed ``crossattn_latent`` token block, exactly like every other stem. + + ``temporal_encoder='conv'`` (default) stacks Conv1dBlocks over the time axis + (``Ha``-agnostic, preserves local temporal structure, light). ``'transformer'`` + embeds each step + sinusoidal position and runs a couple of self-attn layers. + Avoid flattening the chunk to one vector: the per-step trajectory is exactly + the in-context signal RICL is meant to exploit. + """ + + def __init__( + self, + action_dim: int = 32, + output_dim: int = 256, + hidden_dim: int = 256, + num_layers: int = 2, + kernel_size: int = 3, + temporal_encoder: str = "conv", + n_heads: int = 8, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.temporal_encoder = temporal_encoder + if temporal_encoder == "conv": + chans = [action_dim] + [hidden_dim] * (num_layers - 1) + [output_dim] + self.convs = nn.ModuleList( + [ + Conv1dBlock(chans[i], chans[i + 1], kernel_size=kernel_size) + for i in range(len(chans) - 1) + ] + ) + elif temporal_encoder == "transformer": + self.in_proj = nn.Linear(action_dim, output_dim) + layer = nn.TransformerEncoderLayer( + d_model=output_dim, + nhead=n_heads, + dim_feedforward=output_dim * 4, + dropout=0.1, + batch_first=True, + activation="gelu", + ) + self.encoder = nn.TransformerEncoder(layer, num_layers=num_layers) + else: + raise ValueError( + f"ActionChunkStem: unknown temporal_encoder {temporal_encoder!r} " + "(expected 'conv' or 'transformer')" + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # x: (B, Ha, Da) -> (B, Ha, output_dim) + if self.temporal_encoder == "conv": + x = x.transpose(1, 2) # (B, Da, Ha) + for conv in self.convs: + x = conv(x) + return x.transpose(1, 2) # (B, Ha, output_dim) + # transformer + h = self.in_proj(x) # (B, Ha, output_dim) + pos = get_sinusoid_encoding_table(0, h.shape[1], h.shape[-1]).to(h) + h = h + pos # (1, Ha, D) broadcasts over batch + return self.encoder(h) + + def _qwen_last_token_pool( last_hidden_states: torch.Tensor, attention_mask: torch.Tensor ) -> torch.Tensor: @@ -670,7 +765,9 @@ def _qwen_last_token_pool( if left_padded: return last_hidden_states[:, -1] seq_lens = attention_mask.sum(dim=1) - 1 - batch_idx = torch.arange(last_hidden_states.size(0), device=last_hidden_states.device) + batch_idx = torch.arange( + last_hidden_states.size(0), device=last_hidden_states.device + ) return last_hidden_states[batch_idx, seq_lens] diff --git a/egomimic/ricl/HPT_RICL.md b/egomimic/ricl/HPT_RICL.md new file mode 100644 index 000000000..769014e6a --- /dev/null +++ b/egomimic/ricl/HPT_RICL.md @@ -0,0 +1,236 @@ +# HPT RICL — retrieval in-context learning on the Qwen HPT model + +Branch: `ryanco/hpt-icl`. Last updated: 2026-06-20. + +This document covers the architecture, the tests run so far, and the tests that +still need a cluster/data to run. It is the HPT analog of the pi0.5 RICL work +(`egomimic/ricl/CLAUDE.md`, `README.md`). + +--- + +## 1. Context — why HPT differs from pi0.5 + +RICL retrieves, for each query frame, its **k≈4 nearest demonstrations** (DINOv2 +kNN, precomputed offline) and conditions the policy on their +`(image, state, action-chunk)`. + +- **pi0.5** is a unified VLM, so RICL there is "no model surgery": retrieved + images are appended to the image dict and retrieved state/action are discretized + into the text prompt — the same encoders handle everything. +- **HPT** is a `stems → trunk → heads` model: every modality is a *separate stem* + that cross-attention-pools its input into a fixed set of latent tokens. Crucially + **HPT never encodes an action chunk as input** (actions only appear as learnable + query tokens + head outputs). + +So HPT RICL turns each retrieved demo's `(image, state, action-chunk)` into tokens +via **dedicated retrieved stems** and **flat-concatenates** them into the trunk. + +--- + +## 2. Architecture overview + +### 2.1 Data (reused unchanged from `egomimic/ricl`) + +The retrieval pipeline is model-agnostic and reused as-is: `RiclDataModuleWrapper` +(`pl_utils/pl_data_utils.py`), the collate + bank loader (`ricl/data.py`), and the +DINOv2 kNN cache (`ricl/retrieval.py`). The collate attaches 5 keys per query: + +| batch key | shape | notes | +|---|---|---| +| `ricl_retrieved_images` | `(B, k, C, H, W)` | CHW float in **[0,1]** (matches the ImageNet-Normalize augs) | +| `ricl_retrieved_state` | `(B, k, Ds=32)` | normalized in the bank convention | +| `ricl_retrieved_action` | `(B, k, Ha=15, Da=32)` | normalized, converted to the 32-D shared space | +| `ricl_retrieved_mask` | `(B, k)` bool | valid neighbor (handles `< k`) | +| `ricl_retrieved_dist` | `(B, k)` float | kNN distance (unused in v1) | + +### 2.2 Model (`egomimic/algo/hpt_ricl.py`) + +`HptRicl(HPT)` + `HptRiclModel(HPTModel)`. Three retrieved stems (`embed_dim=256`): + +- **`ricl_image`** — runs the **shared** query `front_img_1` ResNet backbone, then a + *separate* cross-attention pooling head → 16 tokens. +- **`ricl_state`** — a separate `MLPPolicyStem` (32-D input) → 16 tokens. +- **`ricl_action`** — a new `ActionChunkStem` (`egomimic/models/hpt_nets.py`): + temporal **Conv1D** over the chunk (`temporal_encoder='conv'`, swappable to + `'transformer'`) → cross-attn pool → 16 tokens. This is HPT's first + action-chunk-as-input encoder. + +**Fusion = flat concatenation (no trunk architecture surgery).** Retrieved token +blocks are appended after the query stem tokens, tagged with learned +demo-index + modality embeddings (zero-init), and invalid demos are (a) zeroed at +the stem output and (b) hidden from attention via a per-sample `key_padding_mask`. + +Trunk input sequence (each `[..]` = a 16-token block): + +``` +[action queries (64)] [q-img] [q-state] [q-lang] <- plain HPT (positions unchanged) +[d0-img][d0-state][d0-act] ... [d(k-1)-img][...][...] <- retrieved, appended last + + learned (demo-index, modality) embeddings on the retrieved spans + + key_padding_mask hides invalid-demo token spans (per-sample, True = ignore) +``` + +Because retrieved modalities are simply **absent when k=0** and **appended after** +the query blocks, query token positions and the global sinusoidal position embedding +are byte-identical to plain HPT, and the head always reads `trunk_tokens[:, :action_horizon]`. +**=> k=0 reduces to plain HPT exactly** (verified bit-exact). + +Token budget at k=4: `64 (action) + 48 (3 query stems × 16) + 192 (3 mod × 4 demos × 16) = 304`. + +### 2.3 Trunk masking (`egomimic/models/hpt_nets.py`) + +The trunk previously threaded only PyTorch's `attn_mask` (`L×L`, batch-shared). +Added an optional, `None`-defaulted **`key_padding_mask`** (`(B, L)`, per-sample) +through `MultiheadAttention.forward`, `BlockWithMasking.forward`, and +`SimpleTransformer.forward` (incl. the checkpoint path). `HPTModel.forward_features` +reads it via `getattr(self, "_ricl_key_padding_mask", None)` — absent/None for plain +HPT, so every existing call site is unchanged. + +### 2.4 Evaluator (`egomimic/eval/hpt_ricl_eval.py`) + +`HptRiclEval(HPTEvalVideo)` runs the model on the same frames twice — **retrieval** +(full batch) vs **floor** (`ricl_*` stripped from the *raw* batch → genuine k=0) — +and reports `RICL/retrieval_*`, `RICL/floor_*`, `RICL/delta_*`, `RICL/retrieval_helps`. +`wants_raw_batch = True` so the floor is built *before* `process_batch_for_training`. +Random-demo and paired-seed flow-loss controls are wired but **off by default** (v1). + +### 2.5 Key files + +| File | Role | +|---|---| +| `egomimic/algo/hpt_ricl.py` | `HptRicl` + `HptRiclModel` (3 overrides + stem encode/mask) | +| `egomimic/models/hpt_nets.py` | `ActionChunkStem`; `key_padding_mask` plumbing | +| `egomimic/algo/hpt.py` | `HPT.model_cls` hook; `forward_features`/`resume_from_depth` read the mask | +| `egomimic/eval/hpt_ricl_eval.py` | `HptRiclEval` (retrieval vs floor) | +| `egomimic/hydra_configs/model/hpt_ricl_pickplace_qwen.yaml` | model recipe | +| `egomimic/hydra_configs/data/cotrain_hpt_ricl_pickplace.yaml` | data (forks pi RICL pickplace) | +| `egomimic/hydra_configs/evaluator/eval_hpt_ricl.yaml` | evaluator recipe | +| `egomimic/hydra_configs/train_zarr_hpt_ricl.yaml` | top-level train config | + +### 2.6 Locked design decisions + +- Fusion = **flat concat** + learned demo/modality embeds. +- First run = **eva→eva oracle** (robot bank). aria→eva cross-embodiment is a later + data-config swap, not a code change. +- Retrieved image encoder **shares the query ResNet** (separate pooling head); + state/action encoders fully separate. +- Eval v1 = **retrieval vs floor** (random/flow-loss flags exist, off). +- Retrieved inputs in the **32-D shared space** (`data state_dim/action_dim=32`); + query stays 14-D — the retrieved stems are separate/learned, so the model config is + identical for a later aria→eva run. +- Query keymap is `cartesian` (→ `front_img_1`); the **bank** stays `cartesian_pi` + (`base_0_rgb`) and is repackaged into the model-agnostic `ricl_*` keys. +- HPT trunk `action_horizon` (# cond tokens, 64) ≠ head `act_seq` (prediction length, + 15) **by design** — they cross-attend; do not force-align. Data `chunk_length`=15. + +--- + +## 3. Tests run (local, CPU — all passing) + +No GPU/S3/cache needed; a fake `norm_stats` exercises the real algo + config path. + +### 3.1 Component / correctness + +| Test | Result | +|---|---| +| **k=0 ≡ plain HPT** (mask all-false vs no-ricl) | **bit-exact**, max diff `0.0` | +| Per-sample variable-k masking (partial mask) | sample with all demos invalid == no-ricl; 1 valid demo differs | +| Demos change outputs (mask all-true) | output differs from no-ricl | +| `ActionChunkStem` shapes + grads (conv **and** transformer) | `(B*k,15,32) → (B*k,16,256)`, grads flow | +| 3 retrieved stems instantiate from the **committed config** | each → `(4,16,256)` | +| Plain `HPTModel` regression (Step-1 mask edits) | default path unchanged, finite | +| `ruff` on all changed Python files | clean | + +### 3.2 Full real-pipeline path (config + `HptRicl` + fake `norm_stats`) + +| Test | Result | +|---|---| +| Construct `HptRicl` via real config | builds `HptRiclModel`, 3 ricl stems registered | +| `process_batch_for_training` carries `ricl_*`, cleans base `None`-key | ✅ | +| `forward_training` (k>0) finite loss + `backward` | grads reach `ricl_action` stem **and** `demo_embed` | +| k=0 floor via `strip_ricl_keys` | finite loss | +| `forward_eval` (sampled actions) | `(2,15,14)` finite predictions | +| **kpm length == trunk sequence length** | `304 == 64+48+192` (no off-by-one) | +| Full `HptRiclEval.compute_metrics_and_viz` | emits `RICL/retrieval_*`, `floor_*`, `delta_*`, `retrieval_helps`, `Valid/action_loss` | +| **Overfit a fixed batch** (80 AdamW steps) | loss `67.6 → 6.0` (**91% drop**) → the model genuinely learns | + +### 3.3 Config + +- Full top-level `train_zarr_hpt_ricl` composes (Hydra `--cfg job`); RICL fields + resolve correctly (`HptRicl` target, `num_retrieved_observations` at algo+trunk, + 3 stems, head horizon 15, query keymap `cartesian`, bank `cartesian_pi`, 32-D, + `bank_norm_path` interpolation). +- Plain qwen HPT config (`train_zarr_cartesian model=hpt_bc_pickplace_qwen_pooled`) + still composes — **no regression**. + +> Caveat: §3.2 metric values use random weights + random data, so `retrieval_helps` +> etc. are **meaningless** there — they validate the *machinery*, not learning. The +> overfit drop validates *learnability of the architecture*, not task performance. + +--- + +## 4. Tests that still need the cluster (GPU + S3 + a retrieval cache) + +The fake-`norm_stats` harness skips only the real S3 data loading + retrieval cache. +Run on a node (per `CLAUDE.md`; eval/smoke fits 48 GB, export `TORCH_COMPILE_DISABLE=1` +for short runs): + +```bash +salloc -A gts-dxu345-rl2 -N1 -q inferno -t 1:00:00 --mem=75G --gres=gpu:l40s:1 +source emimic/bin/activate +``` + +1. **1-step real `trainHydra`** — exercises the data pipeline → collate → loss: + ```bash + python egomimic/trainHydra.py --config-name train_zarr_hpt_ricl \ + data.retrieval_cache_dir= \ + norm_stats.precomputed_norm_path= \ + trainer=debug logger=debug + ``` + Pass: finite training loss; `front_img_1`/`state_ee_pose`/`actions_cartesian` + query keys line up; `ricl_*` flow through. + +2. **Eval retrieval vs floor on real frames** (the scientific check) — on the eva→eva + oracle expect `RICL/delta_*_mse ≤ 0` / `RICL/retrieval_helps = True`. + +3. **k-sweep** (`model.robomimic_model.num_retrieved_observations` ∈ {0,1,4}, and the + matching `trunk.num_retrieved_observations`): k=0 must match plain HPT; larger k + should change/improve predictions. + +4. **Real overfit** — drive one real episode's loss to ~0 (learnability on real data). + +5. **Full training run** + (optional) the eval ablations below. + +### Optional deeper tests (also cluster) +- `evaluator.compute_random=true` (random-demo control) and + `evaluator.compute_flow_loss=true` (paired-seed flow loss) — already wired. +- A/B the action encoder: + `model.robomimic_model.stem_specs.eva_bimanual.ricl_action.temporal_encoder=transformer`. +- Build a tiny smoke cache via `egomimic/ricl/scripts/build_ricl_smoke_cache.py` so + step 1 needs no large precomputed cache. +- aria→eva cross-embodiment: new data config (bank=aria, `HumanBimanualCartesianEuler`, + `bank_norm_path` → aria stats); model config unchanged (32-D retrieved stems). + +--- + +## 5. How to run / wire + +- **Norm stats invariant:** `norm_stats.precomputed_norm_path` must be the eva (bank) + stats; `data.bank_norm_path` interpolates from it (`${norm_stats.precomputed_norm_path}`) + for the eva→eva oracle. For a cross-embodiment run, set `bank_norm_path` to the + bank's own stats explicitly. +- **k is set in two places** (keep in sync): `robomimic_model.num_retrieved_observations` + (algo) and `robomimic_model.trunk.num_retrieved_observations` (sizes the demo-index + embedding table). +- The data config reuses the validated pi RICL pickplace splits + eva bank wiring + (`data/cotrain_pi_ricl_pickplace.yaml`); only the query camera keymap (→ `cartesian`) + and the retrieved widths (→ 32-D) change. + +--- + +## 6. Status summary + +- **Code + configs:** complete; `ruff` clean; plain HPT unregressed. +- **Local verification:** complete (correctness incl. bit-exact k=0 equivalence, + full algo path, learnability). +- **Remaining:** the cluster end-to-end (§4) — the real data pipeline, the + retrieval-vs-floor science on real frames, and the k-sweep.