Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions egomimic/hydra_configs/data/mecka_pi_fold_clothes_freeform.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
_target_: egomimic.pl_utils.pl_data_utils.MultiDataModuleWrapper

train_datasets:
mecka_bimanual:
_target_: egomimic.rldb.zarr.zarr_dataset_multi.MultiDataset._from_resolver
resolver:
_target_: egomimic.rldb.zarr.zarr_dataset_multi.S3EpisodeResolver
folder_path: ${paths.dataset_dir}
key_map:
_target_: egomimic.rldb.embodiment.human.Mecka.get_keymap
mode: cartesian_pi
annotation_key: annotations
transform_list:
_target_: egomimic.rldb.embodiment.human.Mecka.get_transform_list
mode: cartesian
filters:
_target_: egomimic.rldb.filters.DatasetFilter
filter_lambdas:
- "lambda row: row['lab'] == 'mecka' and row['task'] == 'folding_clothes'"
mode: train

valid_datasets:
mecka_bimanual:
_target_: ${data.train_datasets.mecka_bimanual._target_}
resolver: ${data.train_datasets.mecka_bimanual.resolver}
filters: ${data.train_datasets.mecka_bimanual.filters}
mode: valid

train_dataloader_params:
mecka_bimanual:
batch_size: 64
num_workers: 10
valid_dataloader_params:
mecka_bimanual:
batch_size: 64
num_workers: 10

# `+evaluator@train_viz_evaluator=train_viz_pi`.
train_viz_datasets:
mecka_bimanual:
_target_: ${data.train_datasets.mecka_bimanual._target_}
resolver: ${data.train_datasets.mecka_bimanual.resolver}
filters: ${data.train_datasets.mecka_bimanual.filters}
mode: train
train_viz_dataloader_params:
mecka_bimanual:
batch_size: 64
num_workers: 10
2 changes: 1 addition & 1 deletion egomimic/hydra_configs/trainer/ddp_pi.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,5 @@ accelerator: gpu
devices: ${eval:'${launch_params.gpus_per_node} * ${launch_params.nodes}'}
num_nodes: ${launch_params.nodes}
sync_batchnorm: True
check_val_every_n_epoch: 100
check_val_every_n_epoch: 150
num_sanity_val_steps: 0
51 changes: 51 additions & 0 deletions egomimic/rldb/zarr/test_multi_retry.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import random

import pytest
import torch

from egomimic.rldb.zarr.zarr_dataset_multi import MultiDataset

Expand Down Expand Up @@ -103,3 +104,53 @@ def test_retry_exhausts_when_everything_is_bad():

with pytest.raises(RuntimeError, match="Entire MultiDataset bad"):
mds[0]


def _cartesian_stats(low: float = -1.0, high: float = 1.0) -> dict:
q_low = torch.full((1, 12), low)
q_high = torch.full((1, 12), high)
return {
"quantile_0_01": q_low,
"quantile_99_99": q_high,
"quantile_1": q_low,
"quantile_99": q_high,
}


def _make_cartesian_mds(reject_outliers: bool) -> MultiDataset:
mds = MultiDataset(
datasets={"ep": _DummyLeaf("ep", 1)},
mode="total",
reject_outliers=reject_outliers,
)
mds.key_types = {9: {"actions_cartesian": "action_keys"}}
mds.zarr_keys = {9: {"actions_cartesian": "actions_cartesian"}}
mds.norm_stats = {9: {"actions_cartesian": _cartesian_stats()}}
return mds


def test_cartesian_rotation_columns_are_not_quantile_checked():
mds = _make_cartesian_mds(reject_outliers=True)
actions = torch.zeros(1, 12)
actions[..., list(MultiDataset.CARTESIAN_ACTION_YPR_INDICES)] = 10.0
data = {"embodiment": 9, "actions_cartesian": actions}

assert mds._check_bounds(data, _DummyLeaf("ep", 1), 0, "ep") is None


def test_reject_outliers_false_accepts_finite_xyz_quantile_violation():
mds = _make_cartesian_mds(reject_outliers=False)
actions = torch.zeros(1, 12)
actions[..., 0] = 2.0
data = {"embodiment": 9, "actions_cartesian": actions}

assert mds._check_bounds(data, _DummyLeaf("ep", 1), 0, "ep") is None


def test_reject_outliers_true_rejects_xyz_quantile_violation():
mds = _make_cartesian_mds(reject_outliers=True)
actions = torch.zeros(1, 12)
actions[..., 0] = 2.0
data = {"embodiment": 9, "actions_cartesian": actions}

assert "Bounds violation" in mds._check_bounds(data, _DummyLeaf("ep", 1), 0, "ep")
57 changes: 42 additions & 15 deletions egomimic/rldb/zarr/zarr_dataset_multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -751,6 +751,8 @@ class MultiDataset(torch.utils.data.Dataset):
"""

NORMALIZE_KEY_TYPES = ("proprio_keys", "action_keys")
CARTESIAN_ACTION_XYZ_INDICES = (0, 1, 2, 6, 7, 8)
CARTESIAN_ACTION_YPR_INDICES = (3, 4, 5, 9, 10, 11)

def __init__(
self,
Expand All @@ -762,6 +764,7 @@ def __init__(
valid_ratio: float = 0.2,
norm_mode: str = "zscore",
state: dict | None = None,
reject_outliers: bool = True,
**kwargs,
):
"""
Expand All @@ -772,6 +775,11 @@ def __init__(
valid_ratio: Train/valid split ratio.
norm_mode: One of "zscore", "minmax", "quantile".
state: If provided, populate stats fields from this dict (deploy mode).
reject_outliers: If true, reject finite values outside quantile bounds
for keys where quantile rejection is enabled. NaN/Inf rejection
always remains active when norm stats are present. For
``actions_cartesian`` bimanual samples, rotation columns are
excluded from finite quantile checks.
"""
super().__init__()

Expand All @@ -783,6 +791,7 @@ def __init__(
self.shapes: dict[int, dict[str, tuple]] = {}
self.norm_stats: dict[int, dict[str, dict[str, np.ndarray]]] = {}
self._norm_run_metadata: dict[str, float | int | None] | None = None
self.reject_outliers = bool(reject_outliers)

# ---- Dataset graph fields ----
self.datasets: dict = {}
Expand Down Expand Up @@ -858,6 +867,7 @@ def set_norm_stats_from(self, source: "MultiDataset") -> None:
self.shapes = source.shapes
self.embodiments = source.embodiments
self.norm_mode = source.norm_mode
self.reject_outliers = source.reject_outliers
# Each MultiDataset keeps its own warning-dedup state.
self._warned_violations = set()
for ds in self.datasets.values():
Expand All @@ -868,8 +878,8 @@ def _check_bounds(
self, data: dict, dataset, idx: int, dataset_name: str
) -> str | None:
"""Return a violation message if any tracked key in ``data`` has NaN/Inf
or values outside per-key quantile bounds. ``None`` means the sample
passes. Logs each (episode, key) violation once.
or rejected finite values. ``None`` means the sample passes. Logs each
(episode, key) violation once.
"""
embodiment_id = data.get("embodiment")
if embodiment_id is None:
Expand All @@ -881,6 +891,7 @@ def _check_bounds(
episode_name = self._episode_name_for_dataset(dataset, dataset_name)

for key_name, stats in per_emb_stats.items():
key_type = self.key_types.get(embodiment_id, {}).get(key_name)
zarr_key = self.zarr_keys.get(embodiment_id, {}).get(key_name)
if zarr_key is None or zarr_key not in data:
continue
Expand All @@ -892,6 +903,22 @@ def _check_bounds(
else:
continue

if torch.any(torch.isnan(arr)) or torch.any(torch.isinf(arr)):
prefix = f"NaN/Inf in {zarr_key} ep={episode_name} frame={idx}"
warn_key = f"naninf:{episode_name}:{zarr_key}"
if warn_key not in self._warned_violations:
self._warned_violations.add(warn_key)
logger.warning(prefix)
return prefix

is_cartesian_action = (
key_type == "action_keys"
and key_name == "actions_cartesian"
and arr.shape[-1] == 12
)
if not self.reject_outliers:
continue

q_low = stats.get(
"quantile_0_01", stats.get("quantile_0_1", stats["quantile_1"])
)
Expand All @@ -906,16 +933,16 @@ def _check_bounds(
except RuntimeError:
continue

if torch.any(torch.isnan(arr)) or torch.any(torch.isinf(arr)):
prefix = f"NaN/Inf in {zarr_key} ep={episode_name} frame={idx}"
warn_key = f"naninf:{episode_name}:{zarr_key}"
if warn_key not in self._warned_violations:
self._warned_violations.add(warn_key)
logger.warning(prefix)
return prefix
if is_cartesian_action:
xyz_idx = list(self.CARTESIAN_ACTION_XYZ_INDICES)
arr_for_quantiles = arr[..., xyz_idx]
q_low = q_low[..., xyz_idx]
q_high = q_high[..., xyz_idx]
else:
arr_for_quantiles = arr

below = arr < q_low
above = arr > q_high
below = arr_for_quantiles < q_low
above = arr_for_quantiles > q_high
if torch.any(below) or torch.any(above):
prefix = f"Bounds violation in {zarr_key} ep={episode_name} frame={idx}"
warn_key = f"bounds:{episode_name}:{zarr_key}"
Expand Down Expand Up @@ -1387,11 +1414,9 @@ def attach_normalize_transforms(
MultiDataset level in ``__getitem__``, not as per-leaf transforms.

Kept as a thin shim that calls ``set_norm_stats_from(self)`` on each
MultiDataset in ``datasets`` so existing callers keep working. The
``reject_outliers`` flag is no longer honored — bounds checking is
always on when stats are populated. To disable, clear ``norm_stats``.
MultiDataset in ``datasets`` so existing callers keep working.
"""
del reject_outliers # unused
self.reject_outliers = bool(reject_outliers)
graph = datasets if datasets is not None else self.datasets
for ds in graph.values():
if isinstance(ds, MultiDataset):
Expand Down Expand Up @@ -1425,6 +1450,7 @@ def to_state(self) -> dict:
"zarr_keys": copy.deepcopy(self.zarr_keys),
"shapes": copy.deepcopy(self.shapes),
"norm_stats": self._clone_norm_stats(self.norm_stats),
"reject_outliers": self.reject_outliers,
}

@classmethod
Expand All @@ -1442,6 +1468,7 @@ def _load_state(self, state: dict) -> None:
self.zarr_keys = copy.deepcopy(state.get("zarr_keys", {}))
self.shapes = copy.deepcopy(state.get("shapes", {}))
self.norm_stats = self._clone_norm_stats(state.get("norm_stats", {}))
self.reject_outliers = bool(state.get("reject_outliers", self.reject_outliers))
for emb in self.embodiments:
self.key_types.setdefault(emb, {})
self.zarr_keys.setdefault(emb, {})
Expand Down
2 changes: 2 additions & 0 deletions egomimic/trainHydra.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,9 +119,11 @@ def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:

# Stats-only MultiDataset (no graph of its own; explicitly populated from
# datamodule.train_datasets). MultiDataset now owns NormStats's role too.
reject_outliers = OmegaConf.select(cfg, "reject_outliers", default=True)
norm_stats = MultiDataset(
state={},
norm_mode=OmegaConf.select(cfg, "norm_stats.norm_mode", default="quantile"),
reject_outliers=reject_outliers,
)
norm_stats.populate_from_datasets(datamodule.train_datasets)

Expand Down
Loading