From 635fcbac8c62c695ba0a08eca282163de74d2355 Mon Sep 17 00:00:00 2001
From: Ryan Co
Date: Mon, 1 Jun 2026 15:55:58 -0400
Subject: [PATCH] remove rotation bounds checks for norm stats

---
Review notes (ignored by `git am`; not part of the commit message):
  * Line breaks, indentation and blank context lines in this copy were
    reconstructed from a whitespace-collapsed original. Hunk sizes were
    checked against the @@ headers and the diffstat, but run
    `git apply --check` before relying on it.
  * Behaviour change in MultiDataset._check_bounds: the NaN/Inf test now runs
    before the quantile lookup and independently of `reject_outliers`; with
    `reject_outliers=False` the quantile comparison is skipped; for
    `actions_cartesian` action keys whose last dimension is 12 only columns
    (0, 1, 2, 6, 7, 8) are compared against the quantile bounds.
  * `reject_outliers` is threaded through __init__, set_norm_stats_from,
    attach_normalize_transforms, to_state/_load_state and trainHydra.train
    (read from cfg key `reject_outliers`, default True).
  * Also included: a new Hydra data config and
    check_val_every_n_epoch 100 -> 150, neither mentioned in the subject.

 .../data/mecka_pi_fold_clothes_freeform.yaml  | 48 ++++++++++++++++
 egomimic/hydra_configs/trainer/ddp_pi.yaml    |  2 +-
 egomimic/rldb/zarr/test_multi_retry.py        | 51 +++++++++++++++++
 egomimic/rldb/zarr/zarr_dataset_multi.py      | 57 ++++++++++++++-----
 egomimic/trainHydra.py                        |  2 +
 5 files changed, 144 insertions(+), 16 deletions(-)
 create mode 100644 egomimic/hydra_configs/data/mecka_pi_fold_clothes_freeform.yaml

diff --git a/egomimic/hydra_configs/data/mecka_pi_fold_clothes_freeform.yaml b/egomimic/hydra_configs/data/mecka_pi_fold_clothes_freeform.yaml
new file mode 100644
index 000000000..e34d0a8e5
--- /dev/null
+++ b/egomimic/hydra_configs/data/mecka_pi_fold_clothes_freeform.yaml
@@ -0,0 +1,48 @@
+_target_: egomimic.pl_utils.pl_data_utils.MultiDataModuleWrapper
+
+train_datasets:
+  mecka_bimanual:
+    _target_: egomimic.rldb.zarr.zarr_dataset_multi.MultiDataset._from_resolver
+    resolver:
+      _target_: egomimic.rldb.zarr.zarr_dataset_multi.S3EpisodeResolver
+      folder_path: ${paths.dataset_dir}
+      key_map:
+        _target_: egomimic.rldb.embodiment.human.Mecka.get_keymap
+        mode: cartesian_pi
+        annotation_key: annotations
+      transform_list:
+        _target_: egomimic.rldb.embodiment.human.Mecka.get_transform_list
+        mode: cartesian
+    filters:
+      _target_: egomimic.rldb.filters.DatasetFilter
+      filter_lambdas:
+        - "lambda row: row['lab'] == 'mecka' and row['task'] == 'folding_clothes'"
+    mode: train
+
+valid_datasets:
+  mecka_bimanual:
+    _target_: ${data.train_datasets.mecka_bimanual._target_}
+    resolver: ${data.train_datasets.mecka_bimanual.resolver}
+    filters: ${data.train_datasets.mecka_bimanual.filters}
+    mode: valid
+
+train_dataloader_params:
+  mecka_bimanual:
+    batch_size: 64
+    num_workers: 10
+valid_dataloader_params:
+  mecka_bimanual:
+    batch_size: 64
+    num_workers: 10
+
+# `+evaluator@train_viz_evaluator=train_viz_pi`.
+train_viz_datasets:
+  mecka_bimanual:
+    _target_: ${data.train_datasets.mecka_bimanual._target_}
+    resolver: ${data.train_datasets.mecka_bimanual.resolver}
+    filters: ${data.train_datasets.mecka_bimanual.filters}
+    mode: train
+train_viz_dataloader_params:
+  mecka_bimanual:
+    batch_size: 64
+    num_workers: 10
diff --git a/egomimic/hydra_configs/trainer/ddp_pi.yaml b/egomimic/hydra_configs/trainer/ddp_pi.yaml
index 6e786fc88..4fdba78d1 100644
--- a/egomimic/hydra_configs/trainer/ddp_pi.yaml
+++ b/egomimic/hydra_configs/trainer/ddp_pi.yaml
@@ -7,5 +7,5 @@ accelerator: gpu
 devices: ${eval:'${launch_params.gpus_per_node} * ${launch_params.nodes}'}
 num_nodes: ${launch_params.nodes}
 sync_batchnorm: True
-check_val_every_n_epoch: 100
+check_val_every_n_epoch: 150
 num_sanity_val_steps: 0
diff --git a/egomimic/rldb/zarr/test_multi_retry.py b/egomimic/rldb/zarr/test_multi_retry.py
index c9001b014..b3ecc7a9d 100644
--- a/egomimic/rldb/zarr/test_multi_retry.py
+++ b/egomimic/rldb/zarr/test_multi_retry.py
@@ -9,6 +9,7 @@
 import random
 
 import pytest
+import torch
 
 from egomimic.rldb.zarr.zarr_dataset_multi import MultiDataset
 
@@ -103,3 +104,53 @@ def test_retry_exhausts_when_everything_is_bad():
 
     with pytest.raises(RuntimeError, match="Entire MultiDataset bad"):
         mds[0]
+
+
+def _cartesian_stats(low: float = -1.0, high: float = 1.0) -> dict:
+    q_low = torch.full((1, 12), low)
+    q_high = torch.full((1, 12), high)
+    return {
+        "quantile_0_01": q_low,
+        "quantile_99_99": q_high,
+        "quantile_1": q_low,
+        "quantile_99": q_high,
+    }
+
+
+def _make_cartesian_mds(reject_outliers: bool) -> MultiDataset:
+    mds = MultiDataset(
+        datasets={"ep": _DummyLeaf("ep", 1)},
+        mode="total",
+        reject_outliers=reject_outliers,
+    )
+    mds.key_types = {9: {"actions_cartesian": "action_keys"}}
+    mds.zarr_keys = {9: {"actions_cartesian": "actions_cartesian"}}
+    mds.norm_stats = {9: {"actions_cartesian": _cartesian_stats()}}
+    return mds
+
+
+def test_cartesian_rotation_columns_are_not_quantile_checked():
+    mds = _make_cartesian_mds(reject_outliers=True)
+    actions = torch.zeros(1, 12)
+    actions[..., list(MultiDataset.CARTESIAN_ACTION_YPR_INDICES)] = 10.0
+    data = {"embodiment": 9, "actions_cartesian": actions}
+
+    assert mds._check_bounds(data, _DummyLeaf("ep", 1), 0, "ep") is None
+
+
+def test_reject_outliers_false_accepts_finite_xyz_quantile_violation():
+    mds = _make_cartesian_mds(reject_outliers=False)
+    actions = torch.zeros(1, 12)
+    actions[..., 0] = 2.0
+    data = {"embodiment": 9, "actions_cartesian": actions}
+
+    assert mds._check_bounds(data, _DummyLeaf("ep", 1), 0, "ep") is None
+
+
+def test_reject_outliers_true_rejects_xyz_quantile_violation():
+    mds = _make_cartesian_mds(reject_outliers=True)
+    actions = torch.zeros(1, 12)
+    actions[..., 0] = 2.0
+    data = {"embodiment": 9, "actions_cartesian": actions}
+
+    assert "Bounds violation" in mds._check_bounds(data, _DummyLeaf("ep", 1), 0, "ep")
diff --git a/egomimic/rldb/zarr/zarr_dataset_multi.py b/egomimic/rldb/zarr/zarr_dataset_multi.py
index 47eb2bce2..e7060cdc9 100644
--- a/egomimic/rldb/zarr/zarr_dataset_multi.py
+++ b/egomimic/rldb/zarr/zarr_dataset_multi.py
@@ -751,6 +751,8 @@ class MultiDataset(torch.utils.data.Dataset):
     """
 
     NORMALIZE_KEY_TYPES = ("proprio_keys", "action_keys")
+    CARTESIAN_ACTION_XYZ_INDICES = (0, 1, 2, 6, 7, 8)
+    CARTESIAN_ACTION_YPR_INDICES = (3, 4, 5, 9, 10, 11)
 
     def __init__(
         self,
@@ -762,6 +764,7 @@ def __init__(
         valid_ratio: float = 0.2,
         norm_mode: str = "zscore",
         state: dict | None = None,
+        reject_outliers: bool = True,
         **kwargs,
     ):
         """
@@ -772,6 +775,11 @@ def __init__(
             valid_ratio: Train/valid split ratio.
             norm_mode: One of "zscore", "minmax", "quantile".
             state: If provided, populate stats fields from this dict (deploy mode).
+            reject_outliers: If true, reject finite values outside quantile bounds
+                for keys where quantile rejection is enabled. NaN/Inf rejection
+                always remains active when norm stats are present. For
+                ``actions_cartesian`` bimanual samples, rotation columns are
+                excluded from finite quantile checks.
         """
         super().__init__()
 
@@ -783,6 +791,7 @@ def __init__(
         self.shapes: dict[int, dict[str, tuple]] = {}
         self.norm_stats: dict[int, dict[str, dict[str, np.ndarray]]] = {}
         self._norm_run_metadata: dict[str, float | int | None] | None = None
+        self.reject_outliers = bool(reject_outliers)
 
         # ---- Dataset graph fields ----
         self.datasets: dict = {}
@@ -858,6 +867,7 @@ def set_norm_stats_from(self, source: "MultiDataset") -> None:
         self.shapes = source.shapes
         self.embodiments = source.embodiments
         self.norm_mode = source.norm_mode
+        self.reject_outliers = source.reject_outliers
         # Each MultiDataset keeps its own warning-dedup state.
         self._warned_violations = set()
         for ds in self.datasets.values():
@@ -868,8 +878,8 @@ def _check_bounds(
         self, data: dict, dataset, idx: int, dataset_name: str
     ) -> str | None:
         """Return a violation message if any tracked key in ``data`` has NaN/Inf
-        or values outside per-key quantile bounds. ``None`` means the sample
-        passes. Logs each (episode, key) violation once.
+        or rejected finite values. ``None`` means the sample passes. Logs each
+        (episode, key) violation once.
         """
         embodiment_id = data.get("embodiment")
         if embodiment_id is None:
@@ -881,6 +891,7 @@ def _check_bounds(
 
         episode_name = self._episode_name_for_dataset(dataset, dataset_name)
         for key_name, stats in per_emb_stats.items():
+            key_type = self.key_types.get(embodiment_id, {}).get(key_name)
             zarr_key = self.zarr_keys.get(embodiment_id, {}).get(key_name)
             if zarr_key is None or zarr_key not in data:
                 continue
@@ -892,6 +903,22 @@ def _check_bounds(
             else:
                 continue
 
+            if torch.any(torch.isnan(arr)) or torch.any(torch.isinf(arr)):
+                prefix = f"NaN/Inf in {zarr_key} ep={episode_name} frame={idx}"
+                warn_key = f"naninf:{episode_name}:{zarr_key}"
+                if warn_key not in self._warned_violations:
+                    self._warned_violations.add(warn_key)
+                    logger.warning(prefix)
+                return prefix
+
+            is_cartesian_action = (
+                key_type == "action_keys"
+                and key_name == "actions_cartesian"
+                and arr.shape[-1] == 12
+            )
+            if not self.reject_outliers:
+                continue
+
             q_low = stats.get(
                 "quantile_0_01", stats.get("quantile_0_1", stats["quantile_1"])
             )
@@ -906,16 +933,16 @@ def _check_bounds(
             except RuntimeError:
                 continue
 
-            if torch.any(torch.isnan(arr)) or torch.any(torch.isinf(arr)):
-                prefix = f"NaN/Inf in {zarr_key} ep={episode_name} frame={idx}"
-                warn_key = f"naninf:{episode_name}:{zarr_key}"
-                if warn_key not in self._warned_violations:
-                    self._warned_violations.add(warn_key)
-                    logger.warning(prefix)
-                return prefix
+            if is_cartesian_action:
+                xyz_idx = list(self.CARTESIAN_ACTION_XYZ_INDICES)
+                arr_for_quantiles = arr[..., xyz_idx]
+                q_low = q_low[..., xyz_idx]
+                q_high = q_high[..., xyz_idx]
+            else:
+                arr_for_quantiles = arr
 
-            below = arr < q_low
-            above = arr > q_high
+            below = arr_for_quantiles < q_low
+            above = arr_for_quantiles > q_high
             if torch.any(below) or torch.any(above):
                 prefix = f"Bounds violation in {zarr_key} ep={episode_name} frame={idx}"
                 warn_key = f"bounds:{episode_name}:{zarr_key}"
@@ -1387,11 +1414,9 @@ def attach_normalize_transforms(
         MultiDataset level in ``__getitem__``, not as per-leaf transforms.
 
         Kept as a thin shim that calls ``set_norm_stats_from(self)`` on each
-        MultiDataset in ``datasets`` so existing callers keep working. The
-        ``reject_outliers`` flag is no longer honored — bounds checking is
-        always on when stats are populated. To disable, clear ``norm_stats``.
+        MultiDataset in ``datasets`` so existing callers keep working.
         """
-        del reject_outliers  # unused
+        self.reject_outliers = bool(reject_outliers)
         graph = datasets if datasets is not None else self.datasets
         for ds in graph.values():
             if isinstance(ds, MultiDataset):
@@ -1425,6 +1450,7 @@ def to_state(self) -> dict:
             "zarr_keys": copy.deepcopy(self.zarr_keys),
             "shapes": copy.deepcopy(self.shapes),
             "norm_stats": self._clone_norm_stats(self.norm_stats),
+            "reject_outliers": self.reject_outliers,
         }
 
     @classmethod
@@ -1442,6 +1468,7 @@ def _load_state(self, state: dict) -> None:
         self.zarr_keys = copy.deepcopy(state.get("zarr_keys", {}))
         self.shapes = copy.deepcopy(state.get("shapes", {}))
         self.norm_stats = self._clone_norm_stats(state.get("norm_stats", {}))
+        self.reject_outliers = bool(state.get("reject_outliers", self.reject_outliers))
         for emb in self.embodiments:
             self.key_types.setdefault(emb, {})
             self.zarr_keys.setdefault(emb, {})
diff --git a/egomimic/trainHydra.py b/egomimic/trainHydra.py
index c93419cd0..bf62b5a62 100644
--- a/egomimic/trainHydra.py
+++ b/egomimic/trainHydra.py
@@ -119,9 +119,11 @@ def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
 
     # Stats-only MultiDataset (no graph of its own; explicitly populated from
     # datamodule.train_datasets). MultiDataset now owns NormStats's role too.
+    reject_outliers = OmegaConf.select(cfg, "reject_outliers", default=True)
     norm_stats = MultiDataset(
         state={},
         norm_mode=OmegaConf.select(cfg, "norm_stats.norm_mode", default="quantile"),
+        reject_outliers=reject_outliers,
     )
     norm_stats.populate_from_datasets(datamodule.train_datasets)
 