diff --git a/egomimic/eval/latent_dataset.py b/egomimic/eval/latent_dataset.py index 161bbd09e..3d8edec8d 100644 --- a/egomimic/eval/latent_dataset.py +++ b/egomimic/eval/latent_dataset.py @@ -61,6 +61,11 @@ "_shuffle_random", "_shuffle_pairs", "_shuffle_custom", + "use_tokenizer", + "model_name", + "sampling_mode", + "annotation_key", + "default_prompt", } ) @@ -86,18 +91,38 @@ def build_dataset( swallowing) — typos in yaml fail loudly. """ if mode == "random": - lam = ( - "lambda row: row['task'] == " - f"{task!r} and row['robot_name'] == {embodiment!r}" - ) - filters = DatasetFilter(filter_lambdas=[lam]) - logger.info("[build_dataset] %s | random mode | filter=%s", embodiment, lam) - return MultiDataset._from_resolver( + if filters is not None: + logger.info( + "[build_dataset] %s | random mode | using passed filters", embodiment + ) + else: + lam = ( + "lambda row: row['task'] == " + f"{task!r} and row['robot_name'] == {embodiment!r}" + ) + filters = DatasetFilter(filter_lambdas=[lam]) + logger.info("[build_dataset] %s | random mode | filter=%s", embodiment, lam) + base = MultiDataset._from_resolver( resolver, filters=filters, mode="total", valid_ratio=valid_ratio, ) + if stride is not None and stride > 0: + logger.info( + "[build_dataset] %s | random+stride=%d wrapping", + embodiment, + stride, + ) + return EvenStrideDataset(base, stride=stride) + if frames_per_episode is not None: + logger.info( + "[build_dataset] %s | random+frames_per_episode=%d wrapping", + embodiment, + frames_per_episode, + ) + return EvenStrideDataset(base, frames_per_episode=frames_per_episode) + return base if mode in ("pairs", "custom"): if not hashes: