From 6c28ef894af215bfaaf665aa3015c5645d91e53f Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Wed, 8 Oct 2025 14:27:52 +0200 Subject: [PATCH 01/12] chore(docs): add missing license headers (#2140) --- src/lerobot/motors/__init__.py | 16 ++++++++++++++++ src/lerobot/processor/policy_robot_bridge.py | 16 ++++++++++++++++ src/lerobot/robots/__init__.py | 16 ++++++++++++++++ tests/plugins/reachy2_sdk.py | 16 ++++++++++++++++ tests/policies/pi0_pi05/test_pi0.py | 14 ++++++++++++++ tests/policies/pi0_pi05/test_pi05.py | 14 ++++++++++++++ .../pi0_pi05/test_pi05_original_vs_lerobot.py | 16 ++++++++++++++++ .../pi0_pi05/test_pi0_original_vs_lerobot.py | 16 ++++++++++++++++ tests/processor/test_batch_conversion.py | 16 ++++++++++++++++ tests/processor/test_converters.py | 16 ++++++++++++++++ tests/processor/test_tokenizer_processor.py | 16 ++++++++++++++++ tests/utils/test_io_utils.py | 5 ++++- tests/utils/test_logging_utils.py | 5 ++++- tests/utils/test_random_utils.py | 5 ++++- tests/utils/test_train_utils.py | 5 ++++- tests/utils/test_visualization_utils.py | 16 ++++++++++++++++ 16 files changed, 204 insertions(+), 4 deletions(-) diff --git a/src/lerobot/motors/__init__.py b/src/lerobot/motors/__init__.py index dfbfbaee8fc..850ef33d74e 100644 --- a/src/lerobot/motors/__init__.py +++ b/src/lerobot/motors/__init__.py @@ -1 +1,17 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from .motors_bus import Motor, MotorCalibration, MotorNormMode, MotorsBus diff --git a/src/lerobot/processor/policy_robot_bridge.py b/src/lerobot/processor/policy_robot_bridge.py index 845ee065a64..25887d414ee 100644 --- a/src/lerobot/processor/policy_robot_bridge.py +++ b/src/lerobot/processor/policy_robot_bridge.py @@ -1,3 +1,19 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from dataclasses import asdict, dataclass from typing import Any diff --git a/src/lerobot/robots/__init__.py b/src/lerobot/robots/__init__.py index d8fd0de9311..1dba0f1b089 100644 --- a/src/lerobot/robots/__init__.py +++ b/src/lerobot/robots/__init__.py @@ -1,3 +1,19 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from .config import RobotConfig from .robot import Robot from .utils import make_robot_from_config diff --git a/tests/plugins/reachy2_sdk.py b/tests/plugins/reachy2_sdk.py index f56b59efbe4..457fcf0f9ba 100644 --- a/tests/plugins/reachy2_sdk.py +++ b/tests/plugins/reachy2_sdk.py @@ -1,3 +1,19 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import sys import types from unittest.mock import MagicMock diff --git a/tests/policies/pi0_pi05/test_pi0.py b/tests/policies/pi0_pi05/test_pi0.py index 65f64e6bc08..b580310eb46 100644 --- a/tests/policies/pi0_pi05/test_pi0.py +++ b/tests/policies/pi0_pi05/test_pi0.py @@ -1,5 +1,19 @@ #!/usr/bin/env python +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Test script to verify PI0 policy integration with LeRobot, only meant to be run locally!""" import os diff --git a/tests/policies/pi0_pi05/test_pi05.py b/tests/policies/pi0_pi05/test_pi05.py index 72828a02f79..964539446a5 100644 --- a/tests/policies/pi0_pi05/test_pi05.py +++ b/tests/policies/pi0_pi05/test_pi05.py @@ -1,5 +1,19 @@ #!/usr/bin/env python +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Test script to verify PI0.5 (pi05) support in PI0 policy, only meant to be run locally!""" import os diff --git a/tests/policies/pi0_pi05/test_pi05_original_vs_lerobot.py b/tests/policies/pi0_pi05/test_pi05_original_vs_lerobot.py index 7bea8948620..0d5244e1c36 100644 --- a/tests/policies/pi0_pi05/test_pi05_original_vs_lerobot.py +++ b/tests/policies/pi0_pi05/test_pi05_original_vs_lerobot.py @@ -1,3 +1,19 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Test script to verify PI0OpenPI policy integration with LeRobot vs the original implementation, only meant to be run locally!""" import os diff --git a/tests/policies/pi0_pi05/test_pi0_original_vs_lerobot.py b/tests/policies/pi0_pi05/test_pi0_original_vs_lerobot.py index d91f716f19f..41db2dceb2e 100644 --- a/tests/policies/pi0_pi05/test_pi0_original_vs_lerobot.py +++ b/tests/policies/pi0_pi05/test_pi0_original_vs_lerobot.py @@ -1,3 +1,19 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Test script to verify PI0 policy integration with LeRobot vs the original implementation, only meant to be run locally!""" import os diff --git a/tests/processor/test_batch_conversion.py b/tests/processor/test_batch_conversion.py index 88b87312856..477381618e6 100644 --- a/tests/processor/test_batch_conversion.py +++ b/tests/processor/test_batch_conversion.py @@ -1,3 +1,19 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import torch from lerobot.processor import DataProcessorPipeline, TransitionKey diff --git a/tests/processor/test_converters.py b/tests/processor/test_converters.py index bc58f7a61aa..47a6eea1823 100644 --- a/tests/processor/test_converters.py +++ b/tests/processor/test_converters.py @@ -1,3 +1,19 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import numpy as np import pytest import torch diff --git a/tests/processor/test_tokenizer_processor.py b/tests/processor/test_tokenizer_processor.py index b81710db1ae..d6f87f56796 100644 --- a/tests/processor/test_tokenizer_processor.py +++ b/tests/processor/test_tokenizer_processor.py @@ -1,3 +1,19 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """ Tests for the TokenizerProcessorStep class. """ diff --git a/tests/utils/test_io_utils.py b/tests/utils/test_io_utils.py index 9768a5ef9d9..0beea639d6a 100644 --- a/tests/utils/test_io_utils.py +++ b/tests/utils/test_io_utils.py @@ -1,4 +1,6 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,6 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import json from pathlib import Path from typing import Any diff --git a/tests/utils/test_logging_utils.py b/tests/utils/test_logging_utils.py index 927fdc14dbc..560ba570155 100644 --- a/tests/utils/test_logging_utils.py +++ b/tests/utils/test_logging_utils.py @@ -1,4 +1,6 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,6 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import pytest from lerobot.utils.logging_utils import AverageMeter, MetricsTracker diff --git a/tests/utils/test_random_utils.py b/tests/utils/test_random_utils.py index 5865361d0b5..e3a5d420f2c 100644 --- a/tests/utils/test_random_utils.py +++ b/tests/utils/test_random_utils.py @@ -1,4 +1,6 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,6 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import random import numpy as np diff --git a/tests/utils/test_train_utils.py b/tests/utils/test_train_utils.py index 0eeaf907cc4..892503e9772 100644 --- a/tests/utils/test_train_utils.py +++ b/tests/utils/test_train_utils.py @@ -1,4 +1,6 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,6 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + from pathlib import Path from unittest.mock import Mock, patch diff --git a/tests/utils/test_visualization_utils.py b/tests/utils/test_visualization_utils.py index 65a97c6a305..08a82757080 100644 --- a/tests/utils/test_visualization_utils.py +++ b/tests/utils/test_visualization_utils.py @@ -1,3 +1,19 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import importlib import sys from types import SimpleNamespace From 9a49e57c728f09d29876759f23b71e4d553b95c9 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Wed, 8 Oct 2025 20:06:56 +0200 Subject: [PATCH 02/12] refactor(datasets): add compress_level parameter to write_image() and set it to 1 (#2135) * refactor(datasets): add compress_level parameter to write_image() and set it to 1 * docs(dataset): add docs to write_image() --- src/lerobot/datasets/image_writer.py | 27 +++++++++++++++++++++++++-- 1 file changed, 25 insertions(+), 2 deletions(-) diff --git a/src/lerobot/datasets/image_writer.py b/src/lerobot/datasets/image_writer.py index 4a4e1ab0586..ee10df6e19c 100644 --- a/src/lerobot/datasets/image_writer.py +++ b/src/lerobot/datasets/image_writer.py @@ -68,7 +68,30 @@ def image_array_to_pil_image(image_array: np.ndarray, range_check: bool = True) return PIL.Image.fromarray(image_array) -def write_image(image: np.ndarray | PIL.Image.Image, fpath: Path): +def write_image(image: np.ndarray | PIL.Image.Image, fpath: Path, compress_level: int = 1): + """ + Saves a NumPy array or PIL Image to a file. + + This function handles both NumPy arrays and PIL Image objects, converting + the former to a PIL Image before saving. It includes error handling for + the save operation. + + Args: + image (np.ndarray | PIL.Image.Image): The image data to save. + fpath (Path): The destination file path for the image. + compress_level (int, optional): The compression level for the saved + image, as used by PIL.Image.save(). Defaults to 1. + Refer to: https://github.com/huggingface/lerobot/pull/2135 + for more details on the default value rationale. + + Raises: + TypeError: If the input 'image' is not a NumPy array or a + PIL.Image.Image object. + + Side Effects: + Prints an error message to the console if the image writing process + fails for any reason. + """ try: if isinstance(image, np.ndarray): img = image_array_to_pil_image(image) @@ -76,7 +99,7 @@ def write_image(image: np.ndarray | PIL.Image.Image, fpath: Path): img = image else: raise TypeError(f"Unsupported image type: {type(image)}") - img.save(fpath) + img.save(fpath, compress_level=compress_level) except Exception as e: print(f"Error writing image {fpath}: {e}") From 4ccf28437a785e453888de8e4b415dc9d35ac4e0 Mon Sep 17 00:00:00 2001 From: Pepijn <138571049+pkooij@users.noreply.github.com> Date: Wed, 8 Oct 2025 20:07:14 +0200 Subject: [PATCH 03/12] Add act documentation (#2139) * Add act documentation * remove citation as we link the paper * simplify docs * fix pre commit --- docs/source/_toctree.yml | 2 + docs/source/act.mdx | 92 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 94 insertions(+) create mode 100644 docs/source/act.mdx diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 36eaea165cc..3b6cccc9596 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -27,6 +27,8 @@ title: Porting Large Datasets title: "Datasets" - sections: + - local: act + title: ACT - local: smolvla title: SmolVLA - local: pi0 diff --git a/docs/source/act.mdx b/docs/source/act.mdx new file mode 100644 index 00000000000..e3294ca6945 --- /dev/null +++ b/docs/source/act.mdx @@ -0,0 +1,92 @@ +# ACT (Action Chunking with Transformers) + +ACT is a **lightweight and efficient policy for imitation learning**, especially well-suited for fine-grained manipulation tasks. It's the **first model we recommend when you're starting out** with LeRobot due to its fast training time, low computational requirements, and strong performance. + +
+ +
+ +_Watch this tutorial from the LeRobot team to learn how ACT works: [LeRobot ACT Tutorial](https://www.youtube.com/watch?v=ft73x0LfGpM)_ + +## Model Overview + +Action Chunking with Transformers (ACT) was introduced in the paper [Learning Fine-Grained Bimanual Manipulation with Low-Cost Hardware](https://arxiv.org/abs/2304.13705) by Zhao et al. The policy was designed to enable precise, contact-rich manipulation tasks using affordable hardware and minimal demonstration data. + +### Why ACT is Great for Beginners + +ACT stands out as an excellent starting point for several reasons: + +- **Fast Training**: Trains in a few hours on a single GPU +- **Lightweight**: Only ~80M parameters, making it efficient and easy to work with +- **Data Efficient**: Often achieves high success rates with just 50 demonstrations + +### Architecture + +ACT uses a transformer-based architecture with three main components: + +1. **Vision Backbone**: ResNet-18 processes images from multiple camera viewpoints +2. **Transformer Encoder**: Synthesizes information from camera features, joint positions, and a learned latent variable +3. **Transformer Decoder**: Generates coherent action sequences using cross-attention + +The policy takes as input: + +- Multiple RGB images (e.g., from wrist cameras, front/top cameras) +- Current robot joint positions +- A latent style variable `z` (learned during training, set to zero during inference) + +And outputs a chunk of `k` future action sequences. + +## Installation Requirements + +1. Install LeRobot by following our [Installation Guide](./installation). +2. ACT is included in the base LeRobot installation, so no additional dependencies are needed! + +## Training ACT + +ACT works seamlessly with the standard LeRobot training pipeline. Here's a complete example for training ACT on your dataset: + +```bash +lerobot-train \ + --dataset.repo_id=${HF_USER}/your_dataset \ + --policy.type=act \ + --output_dir=outputs/train/act_your_dataset \ + --job_name=act_your_dataset \ + --policy.device=cuda \ + --wandb.enable=true \ + --policy.repo_id=${HF_USER}/act_policy +``` + +### Training Tips + +1. **Start with defaults**: ACT's default hyperparameters work well for most tasks +2. **Training duration**: Expect a few hours for 100k training steps on a single GPU +3. **Batch size**: Start with batch size 8 and adjust based on your GPU memory + +### Train using Google Colab + +If your local computer doesn't have a powerful GPU, you can utilize Google Colab to train your model by following the [ACT training notebook](./notebooks#training-act). + +## Evaluating ACT + +Once training is complete, you can evaluate your ACT policy using the `lerobot-record` command with your trained policy. This will run inference and record evaluation episodes: + +```bash +lerobot-record \ + --robot.type=so100_follower \ + --robot.port=/dev/ttyACM0 \ + --robot.id=my_robot \ + --robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \ + --display_data=true \ + --dataset.repo_id=${HF_USER}/eval_act_your_dataset \ + --dataset.num_episodes=10 \ + --dataset.single_task="Your task description" \ + --policy.path=${HF_USER}/act_policy +``` From 829d2d1ad9bc0acc20fbf64f22027c615055385e Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Thu, 9 Oct 2025 15:20:07 +0200 Subject: [PATCH 04/12] fic(docs): local docs links (#2149) --- docs/source/integrate_hardware.mdx | 4 ++-- docs/source/introduction_processors.mdx | 6 +++--- docs/source/phone_teleop.mdx | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/docs/source/integrate_hardware.mdx b/docs/source/integrate_hardware.mdx index 7e7fe0bffc7..ed9dc8dd563 100644 --- a/docs/source/integrate_hardware.mdx +++ b/docs/source/integrate_hardware.mdx @@ -8,7 +8,7 @@ To that end, we provide the [`Robot`](https://github.com/huggingface/lerobot/blo - Your own robot which exposes a communication interface (e.g. serial, CAN, TCP) - A way to read sensor data and send motor commands programmatically, e.g. manufacturer's SDK or API, or your own protocol implementation. -- LeRobot installed in your environment. Follow our [Installation Guide](./installation.mdx). +- LeRobot installed in your environment. Follow our [Installation Guide](./installation). ## Choose your motors @@ -65,7 +65,7 @@ class MyCoolRobotConfig(RobotConfig): ``` -[Cameras tutorial](./cameras.mdx) to understand how to detect and add your camera. +[Cameras tutorial](./cameras) to understand how to detect and add your camera. Next, we'll create our actual robot class which inherits from `Robot`. This abstract class defines a contract you must follow for your robot to be usable with the rest of the LeRobot tools. diff --git a/docs/source/introduction_processors.mdx b/docs/source/introduction_processors.mdx index 308edbb3bee..6f376861541 100644 --- a/docs/source/introduction_processors.mdx +++ b/docs/source/introduction_processors.mdx @@ -297,9 +297,9 @@ LeRobot provides many registered processor steps. Here are the most commonly use ### Next Steps -- **[Implement Your Own Processor](implement_your_own_processor.mdx)** - Create custom processor steps -- **[Debug Your Pipeline](debug_processor_pipeline.mdx)** - Troubleshoot and optimize pipelines -- **[Processors for Robots and Teleoperators](processors_robots_teleop.mdx)** - Real-world integration patterns +- **[Implement Your Own Processor](./implement_your_own_processor)** - Create custom processor steps +- **[Debug Your Pipeline](./debug_processor_pipeline)** - Troubleshoot and optimize pipelines +- **[Processors for Robots and Teleoperators](./processors_robots_teleop)** - Real-world integration patterns ## Summary diff --git a/docs/source/phone_teleop.mdx b/docs/source/phone_teleop.mdx index 22159193cd5..76e3c367c38 100644 --- a/docs/source/phone_teleop.mdx +++ b/docs/source/phone_teleop.mdx @@ -79,7 +79,7 @@ After running the example: - Android: after starting the script, open the printed local URL on your phone, tap Start, then press and hold Move. - iOS: open HEBI Mobile I/O first; B1 enables motion. A3 controls the gripper. -Additionally you can customize mapping or safety limits by editing the processor steps shown in the examples. You can also remap inputs (e.g., use a different analog input) or adapt the pipeline to other robots (e.g., LeKiwi) by modifying the input and kinematics steps. More about this in the [Processors for Robots and Teleoperators](./processors_robots_teleop.mdx) guide. +Additionally you can customize mapping or safety limits by editing the processor steps shown in the examples. You can also remap inputs (e.g., use a different analog input) or adapt the pipeline to other robots (e.g., LeKiwi) by modifying the input and kinematics steps. More about this in the [Processors for Robots and Teleoperators](./processors_robots_teleop) guide. - Run this example to record a dataset, which saves absolute end effector observations and actions: From 656fc0f05956d5192f12b70fe4f0bbc25b17fc2e Mon Sep 17 00:00:00 2001 From: Pepijn <138571049+pkooij@users.noreply.github.com> Date: Fri, 10 Oct 2025 11:34:21 +0200 Subject: [PATCH 05/12] Remove validate_robot_cameras_for_policy (#2150) * Remove validate_robot_cameras_for_policy as with rename processor the image keys can be renamed an mapped * fix precommit --- src/lerobot/async_inference/configs.py | 5 ----- src/lerobot/async_inference/helpers.py | 9 --------- src/lerobot/async_inference/robot_client.py | 10 ---------- tests/async_inference/test_e2e.py | 1 - tests/async_inference/test_robot_client.py | 1 - 5 files changed, 26 deletions(-) diff --git a/src/lerobot/async_inference/configs.py b/src/lerobot/async_inference/configs.py index 24f889df17d..d1768a323e4 100644 --- a/src/lerobot/async_inference/configs.py +++ b/src/lerobot/async_inference/configs.py @@ -142,11 +142,6 @@ class RobotClientConfig: default=False, metadata={"help": "Visualize the action queue size"} ) - # Verification configuration - verify_robot_cameras: bool = field( - default=True, metadata={"help": "Verify that the robot cameras match the policy cameras"} - ) - @property def environment_dt(self) -> float: """Environment time step, in seconds""" diff --git a/src/lerobot/async_inference/helpers.py b/src/lerobot/async_inference/helpers.py index 54fad8c546a..f73cbc1dabe 100644 --- a/src/lerobot/async_inference/helpers.py +++ b/src/lerobot/async_inference/helpers.py @@ -62,15 +62,6 @@ def visualize_action_queue_size(action_queue_size: list[int]) -> None: plt.show() -def validate_robot_cameras_for_policy( - lerobot_observation_features: dict[str, dict], policy_image_features: dict[str, PolicyFeature] -) -> None: - image_keys = list(filter(is_image_key, lerobot_observation_features)) - assert set(image_keys) == set(policy_image_features.keys()), ( - f"Policy image features must match robot cameras! Received {list(policy_image_features.keys())} != {image_keys}" - ) - - def map_robot_keys_to_lerobot_features(robot: Robot) -> dict[str, dict]: return hw_to_dataset_features(robot.observation_features, OBS_STR, use_video=False) diff --git a/src/lerobot/async_inference/robot_client.py b/src/lerobot/async_inference/robot_client.py index 8c4425c6b5f..f9d70a64ece 100644 --- a/src/lerobot/async_inference/robot_client.py +++ b/src/lerobot/async_inference/robot_client.py @@ -48,7 +48,6 @@ from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401 from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401 -from lerobot.configs.policies import PreTrainedConfig from lerobot.robots import ( # noqa: F401 Robot, RobotConfig, @@ -76,7 +75,6 @@ TimedObservation, get_logger, map_robot_keys_to_lerobot_features, - validate_robot_cameras_for_policy, visualize_action_queue_size, ) @@ -98,14 +96,6 @@ def __init__(self, config: RobotClientConfig): lerobot_features = map_robot_keys_to_lerobot_features(self.robot) - if config.verify_robot_cameras: - # Load policy config for validation - policy_config = PreTrainedConfig.from_pretrained(config.pretrained_name_or_path) - policy_image_features = policy_config.image_features - - # The cameras specified for inference must match the one supported by the policy chosen - validate_robot_cameras_for_policy(lerobot_features, policy_image_features) - # Use environment variable if server_address is not provided in config self.server_address = config.server_address diff --git a/tests/async_inference/test_e2e.py b/tests/async_inference/test_e2e.py index ebaef2ef1a2..11941ce32e0 100644 --- a/tests/async_inference/test_e2e.py +++ b/tests/async_inference/test_e2e.py @@ -139,7 +139,6 @@ def _fake_send_policy_instructions(self, request, context): # noqa: N802 policy_type="test", pretrained_name_or_path="test", actions_per_chunk=20, - verify_robot_cameras=False, ) client = RobotClient(client_config) diff --git a/tests/async_inference/test_robot_client.py b/tests/async_inference/test_robot_client.py index dfdb8ce4200..5b138d91bc9 100644 --- a/tests/async_inference/test_robot_client.py +++ b/tests/async_inference/test_robot_client.py @@ -51,7 +51,6 @@ def robot_client(): policy_type="test", pretrained_name_or_path="test", actions_per_chunk=20, - verify_robot_cameras=False, ) client = RobotClient(test_config) From b8f7e401d42a17d1ac90355f39a1ee7171afb58f Mon Sep 17 00:00:00 2001 From: Michel Aractingi Date: Fri, 10 Oct 2025 12:32:07 +0200 Subject: [PATCH 06/12] Dataset tools (#2100) * feat(dataset-tools): add dataset utilities and example script - Introduced dataset tools for LeRobotDataset, including functions for deleting episodes, splitting datasets, adding/removing features, and merging datasets. - Added an example script demonstrating the usage of these utilities. - Implemented comprehensive tests for all new functionalities to ensure reliability and correctness. * style fixes * move example to dataset dir * missing lisence * fixes mostly path * clean comments * move tests to functions instead of class based * - fix video editting, decode, delete frames and rencode video - copy unchanged video and parquet files to avoid recreating the entire dataset * Fortify tooling tests * Fix type issue resulting from saving numpy arrays with shape 3,1,1 * added lerobot_edit_dataset * - revert changes in examples - remove hardcoded split names * update comment * fix comment add lerobot-edit-dataset shortcut * Apply suggestion from @Copilot Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Signed-off-by: Michel Aractingi * style nit after copilot review * fix: bug in dataset root when editing the dataset in place (without setting new_repo_id * Fix bug in aggregate.py when accumelating video timestamps; add tests to fortify aggregate videos * Added missing output repo id * migrate delete episode to using pyav instead of decoding, writing frames to disk and encoding again. Co-authored-by: Caroline Pascal * added modified suffix in case repo_id is not set in delete_episode * adding docs for dataset tools * bump av version and add back time_base assignment * linter * modified push_to_hub logic in lerobot_edit_dataset * fix(progress bar): fixing the progress bar issue in dataset tools * chore(concatenate): removing no longer needed concatenate_datasets usage * fix(file sizes forwarding): forwarding files and chunk sizes in metadata info when splitting and aggregating datasets * style fix * refactor(aggregate): Fix video indexing and timestamp bugs in dataset merging There were three critical bugs in aggregate.py that prevented correct dataset merging: 1. Video file indices: Changed from += to = assignment to correctly reference merged video files 2. Video timestamps: Implemented per-source-file offset tracking to maintain continuous timestamps when merging split datasets (was causing non-monotonic timestamp warnings) 3. File rotation offsets: Store timestamp offsets after rotation decision to prevent out-of-bounds frame access (was causing "Invalid frame index" errors with small file size limits) Changes: - Updated update_meta_data() to apply per-source-file timestamp offsets - Updated aggregate_videos() to track offsets correctly during file rotation - Added get_video_duration_in_s import for duration calculation * Improved docs for split dataset and added a check for the possible case that the split size results in zero episodes * chore(docs): update merge documentation details Signed-off-by: Steven Palma --------- Co-authored-by: CarolinePascal Co-authored-by: Jack Vial Co-authored-by: Steven Palma --- docs/source/_toctree.yml | 2 + docs/source/using_dataset_tools.mdx | 102 ++ examples/dataset/use_dataset_tools.py | 117 +++ pyproject.toml | 3 +- src/lerobot/datasets/aggregate.py | 82 +- src/lerobot/datasets/dataset_tools.py | 1004 +++++++++++++++++++ src/lerobot/datasets/lerobot_dataset.py | 14 +- src/lerobot/datasets/utils.py | 9 +- src/lerobot/datasets/video_utils.py | 3 + src/lerobot/scripts/lerobot_edit_dataset.py | 286 ++++++ src/lerobot/utils/utils.py | 20 + tests/datasets/test_aggregate.py | 90 ++ tests/datasets/test_dataset_tools.py | 891 ++++++++++++++++ 13 files changed, 2593 insertions(+), 30 deletions(-) create mode 100644 docs/source/using_dataset_tools.mdx create mode 100644 examples/dataset/use_dataset_tools.py create mode 100644 src/lerobot/datasets/dataset_tools.py create mode 100644 src/lerobot/scripts/lerobot_edit_dataset.py create mode 100644 tests/datasets/test_dataset_tools.py diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 3b6cccc9596..568bd638057 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -25,6 +25,8 @@ title: Using LeRobotDataset - local: porting_datasets_v3 title: Porting Large Datasets + - local: using_dataset_tools + title: Using the Dataset Tools title: "Datasets" - sections: - local: act diff --git a/docs/source/using_dataset_tools.mdx b/docs/source/using_dataset_tools.mdx new file mode 100644 index 00000000000..affca0ee5ca --- /dev/null +++ b/docs/source/using_dataset_tools.mdx @@ -0,0 +1,102 @@ +# Using Dataset Tools + +This guide covers the dataset tools utilities available in LeRobot for modifying and editing existing datasets. + +## Overview + +LeRobot provides several utilities for manipulating datasets: + +1. **Delete Episodes** - Remove specific episodes from a dataset +2. **Split Dataset** - Divide a dataset into multiple smaller datasets +3. **Merge Datasets** - Combine multiple datasets into one. The datasets must have identical features, and episodes are concatenated in the order specified in `repo_ids` +4. **Add Features** - Add new features to a dataset +5. **Remove Features** - Remove features from a dataset + +The core implementation is in `lerobot.datasets.dataset_tools`. +An example script detailing how to use the tools API is available in `examples/dataset/use_dataset_tools.py`. + +## Command-Line Tool: lerobot-edit-dataset + +`lerobot-edit-dataset` is a command-line script for editing datasets. It can be used to delete episodes, split datasets, merge datasets, add features, and remove features. + +Run `lerobot-edit-dataset --help` for more information on the configuration of each operation. + +### Usage Examples + +#### Delete Episodes + +Remove specific episodes from a dataset. This is useful for filtering out undesired data. + +```bash +# Delete episodes 0, 2, and 5 (modifies original dataset) +lerobot-edit-dataset \ + --repo_id lerobot/pusht \ + --operation.type delete_episodes \ + --operation.episode_indices "[0, 2, 5]" + +# Delete episodes and save to a new dataset (preserves original dataset) +lerobot-edit-dataset \ + --repo_id lerobot/pusht \ + --new_repo_id lerobot/pusht_after_deletion \ + --operation.type delete_episodes \ + --operation.episode_indices "[0, 2, 5]" +``` + +#### Split Dataset + +Divide a dataset into multiple subsets. + +```bash +# Split by fractions (e.g. 80% train, 20% test, 20% val) +lerobot-edit-dataset \ + --repo_id lerobot/pusht \ + --operation.type split \ + --operation.splits '{"train": 0.8, "test": 0.2, "val": 0.2}' + +# Split by specific episode indices +lerobot-edit-dataset \ + --repo_id lerobot/pusht \ + --operation.type split \ + --operation.splits '{"task1": [0, 1, 2, 3], "task2": [4, 5]}' +``` + +There are no constraints on the split names, they can be determined by the user. Resulting datasets are saved under the repo id with the split name appended, e.g. `lerobot/pusht_train`, `lerobot/pusht_task1`, `lerobot/pusht_task2`. + +#### Merge Datasets + +Combine multiple datasets into a single dataset. + +```bash +# Merge train and validation splits back into one dataset +lerobot-edit-dataset \ + --repo_id lerobot/pusht_merged \ + --operation.type merge \ + --operation.repo_ids "['lerobot/pusht_train', 'lerobot/pusht_val']" +``` + +#### Remove Features + +Remove features from a dataset. + +```bash +# Remove a camera feature +lerobot-edit-dataset \ + --repo_id lerobot/pusht \ + --operation.type remove_feature \ + --operation.feature_names "['observation.images.top']" +``` + +### Push to Hub + +Add the `--push_to_hub` flag to any command to automatically upload the resulting dataset to the Hugging Face Hub: + +```bash +lerobot-edit-dataset \ + --repo_id lerobot/pusht \ + --new_repo_id lerobot/pusht_after_deletion \ + --operation.type delete_episodes \ + --operation.episode_indices "[0, 2, 5]" \ + --push_to_hub +``` + +There is also a tool for adding features to a dataset that is not yet covered in `lerobot-edit-dataset`. diff --git a/examples/dataset/use_dataset_tools.py b/examples/dataset/use_dataset_tools.py new file mode 100644 index 00000000000..24425987239 --- /dev/null +++ b/examples/dataset/use_dataset_tools.py @@ -0,0 +1,117 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Example script demonstrating dataset tools utilities. + +This script shows how to: +1. Delete episodes from a dataset +2. Split a dataset into train/val sets +3. Add/remove features +4. Merge datasets + +Usage: + python examples/dataset/use_dataset_tools.py +""" + +import numpy as np + +from lerobot.datasets.dataset_tools import ( + add_feature, + delete_episodes, + merge_datasets, + remove_feature, + split_dataset, +) +from lerobot.datasets.lerobot_dataset import LeRobotDataset + + +def main(): + dataset = LeRobotDataset("lerobot/pusht") + + print(f"Original dataset: {dataset.meta.total_episodes} episodes, {dataset.meta.total_frames} frames") + print(f"Features: {list(dataset.meta.features.keys())}") + + print("\n1. Deleting episodes 0 and 2...") + filtered_dataset = delete_episodes(dataset, episode_indices=[0, 2], repo_id="lerobot/pusht_filtered") + print(f"Filtered dataset: {filtered_dataset.meta.total_episodes} episodes") + + print("\n2. Splitting dataset into train/val...") + splits = split_dataset( + dataset, + splits={"train": 0.8, "val": 0.2}, + ) + print(f"Train split: {splits['train'].meta.total_episodes} episodes") + print(f"Val split: {splits['val'].meta.total_episodes} episodes") + + print("\n3. Adding a reward feature...") + + reward_values = np.random.randn(dataset.meta.total_frames).astype(np.float32) + dataset_with_reward = add_feature( + dataset, + feature_name="reward", + feature_values=reward_values, + feature_info={ + "dtype": "float32", + "shape": (1,), + "names": None, + }, + repo_id="lerobot/pusht_with_reward", + ) + + def compute_success(row_dict, episode_index, frame_index): + episode_length = 10 + return float(frame_index >= episode_length - 10) + + dataset_with_success = add_feature( + dataset_with_reward, + feature_name="success", + feature_values=compute_success, + feature_info={ + "dtype": "float32", + "shape": (1,), + "names": None, + }, + repo_id="lerobot/pusht_with_reward_and_success", + ) + + print(f"New features: {list(dataset_with_success.meta.features.keys())}") + + print("\n4. Removing the success feature...") + dataset_cleaned = remove_feature( + dataset_with_success, feature_names="success", repo_id="lerobot/pusht_cleaned" + ) + print(f"Features after removal: {list(dataset_cleaned.meta.features.keys())}") + + print("\n5. Merging train and val splits back together...") + merged = merge_datasets([splits["train"], splits["val"]], output_repo_id="lerobot/pusht_merged") + print(f"Merged dataset: {merged.meta.total_episodes} episodes") + + print("\n6. Complex workflow example...") + + if len(dataset.meta.camera_keys) > 1: + camera_to_remove = dataset.meta.camera_keys[0] + print(f"Removing camera: {camera_to_remove}") + dataset_no_cam = remove_feature( + dataset, feature_names=camera_to_remove, repo_id="pusht_no_first_camera" + ) + print(f"Remaining cameras: {dataset_no_cam.meta.camera_keys}") + + print("\nDone! Check ~/.cache/huggingface/lerobot/ for the created datasets.") + + +if __name__ == "__main__": + main() diff --git a/pyproject.toml b/pyproject.toml index c67b481f09b..a70208cb2e5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -67,7 +67,7 @@ dependencies = [ "cmake>=3.29.0.1,<4.2.0", "einops>=0.8.0,<0.9.0", "opencv-python-headless>=4.9.0,<4.13.0", - "av>=14.2.0,<16.0.0", + "av>=15.0.0,<16.0.0", "jsonlines>=4.0.0,<5.0.0", "packaging>=24.2,<26.0", "pynput>=1.7.7,<1.9.0", @@ -175,6 +175,7 @@ lerobot-dataset-viz="lerobot.scripts.lerobot_dataset_viz:main" lerobot-info="lerobot.scripts.lerobot_info:main" lerobot-find-joint-limits="lerobot.scripts.lerobot_find_joint_limits:main" lerobot-imgtransform-viz="lerobot.scripts.lerobot_imgtransform_viz:main" +lerobot-edit-dataset="lerobot.scripts.lerobot_edit_dataset:main" # ---------------- Tool Configurations ---------------- [tool.setuptools.packages.find] diff --git a/src/lerobot/datasets/aggregate.py b/src/lerobot/datasets/aggregate.py index 803645f292e..e7ea59ed030 100644 --- a/src/lerobot/datasets/aggregate.py +++ b/src/lerobot/datasets/aggregate.py @@ -39,7 +39,7 @@ write_stats, write_tasks, ) -from lerobot.datasets.video_utils import concatenate_video_files +from lerobot.datasets.video_utils import concatenate_video_files, get_video_duration_in_s def validate_all_metadata(all_metadata: list[LeRobotDatasetMetadata]): @@ -130,10 +130,34 @@ def update_meta_data( df["data/chunk_index"] = df["data/chunk_index"] + data_idx["chunk"] df["data/file_index"] = df["data/file_index"] + data_idx["file"] for key, video_idx in videos_idx.items(): - df[f"videos/{key}/chunk_index"] = df[f"videos/{key}/chunk_index"] + video_idx["chunk"] - df[f"videos/{key}/file_index"] = df[f"videos/{key}/file_index"] + video_idx["file"] - df[f"videos/{key}/from_timestamp"] = df[f"videos/{key}/from_timestamp"] + video_idx["latest_duration"] - df[f"videos/{key}/to_timestamp"] = df[f"videos/{key}/to_timestamp"] + video_idx["latest_duration"] + # Store original video file indices before updating + orig_chunk_col = f"videos/{key}/chunk_index" + orig_file_col = f"videos/{key}/file_index" + df["_orig_chunk"] = df[orig_chunk_col].copy() + df["_orig_file"] = df[orig_file_col].copy() + + # Update chunk and file indices to point to destination + df[orig_chunk_col] = video_idx["chunk"] + df[orig_file_col] = video_idx["file"] + + # Apply per-source-file timestamp offsets + src_to_offset = video_idx.get("src_to_offset", {}) + if src_to_offset: + # Apply offset based on original source file + for idx in df.index: + src_key = (df.at[idx, "_orig_chunk"], df.at[idx, "_orig_file"]) + offset = src_to_offset.get(src_key, 0) + df.at[idx, f"videos/{key}/from_timestamp"] += offset + df.at[idx, f"videos/{key}/to_timestamp"] += offset + else: + # Fallback to simple offset (for backward compatibility) + df[f"videos/{key}/from_timestamp"] = ( + df[f"videos/{key}/from_timestamp"] + video_idx["latest_duration"] + ) + df[f"videos/{key}/to_timestamp"] = df[f"videos/{key}/to_timestamp"] + video_idx["latest_duration"] + + # Clean up temporary columns + df = df.drop(columns=["_orig_chunk", "_orig_file"]) df["dataset_from_index"] = df["dataset_from_index"] + dst_meta.info["total_frames"] df["dataset_to_index"] = df["dataset_to_index"] + dst_meta.info["total_frames"] @@ -193,6 +217,9 @@ def aggregate_datasets( robot_type=robot_type, features=features, root=aggr_root, + chunks_size=chunk_size, + data_files_size_in_mb=data_files_size_in_mb, + video_files_size_in_mb=video_files_size_in_mb, ) logging.info("Find all tasks") @@ -236,6 +263,11 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu Returns: dict: Updated videos_idx with current chunk and file indices. """ + for key in videos_idx: + videos_idx[key]["episode_duration"] = 0 + # Track offset for each source (chunk, file) pair + videos_idx[key]["src_to_offset"] = {} + for key, video_idx in videos_idx.items(): unique_chunk_file_pairs = { (chunk, file) @@ -249,6 +281,7 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu chunk_idx = video_idx["chunk"] file_idx = video_idx["file"] + current_offset = video_idx["latest_duration"] for src_chunk_idx, src_file_idx in unique_chunk_file_pairs: src_path = src_meta.root / DEFAULT_VIDEO_PATH.format( @@ -263,21 +296,24 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu file_index=file_idx, ) - # If a new file is created, we don't want to increment the latest_duration - update_latest_duration = False + src_duration = get_video_duration_in_s(src_path) if not dst_path.exists(): - # First write to this destination file + # Store offset before incrementing + videos_idx[key]["src_to_offset"][(src_chunk_idx, src_file_idx)] = current_offset dst_path.parent.mkdir(parents=True, exist_ok=True) shutil.copy(str(src_path), str(dst_path)) - continue # not accumulating further, already copied the file in place + videos_idx[key]["episode_duration"] += src_duration + current_offset += src_duration + continue - # Check file sizes before appending src_size = get_video_size_in_mb(src_path) dst_size = get_video_size_in_mb(dst_path) if dst_size + src_size >= video_files_size_in_mb: - # Rotate to a new chunk/file + # Rotate to a new file, this source becomes start of new destination + # So its offset should be 0 + videos_idx[key]["src_to_offset"][(src_chunk_idx, src_file_idx)] = 0 chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, chunk_size) dst_path = dst_meta.root / DEFAULT_VIDEO_PATH.format( video_key=key, @@ -286,25 +322,22 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu ) dst_path.parent.mkdir(parents=True, exist_ok=True) shutil.copy(str(src_path), str(dst_path)) + # Reset offset for next file + current_offset = src_duration else: - # Get the timestamps shift for this video - timestamps_shift_s = dst_meta.info["total_frames"] / dst_meta.info["fps"] - - # Append to existing video file + # Append to existing video file - use current accumulated offset + videos_idx[key]["src_to_offset"][(src_chunk_idx, src_file_idx)] = current_offset concatenate_video_files( [dst_path, src_path], dst_path, ) - # Update the latest_duration when appending (shifts timestamps!) - update_latest_duration = not update_latest_duration + current_offset += src_duration + + videos_idx[key]["episode_duration"] += src_duration - # Update the videos_idx with the final chunk and file indices for this key videos_idx[key]["chunk"] = chunk_idx videos_idx[key]["file"] = file_idx - if update_latest_duration: - videos_idx[key]["latest_duration"] += timestamps_shift_s - return videos_idx @@ -389,9 +422,6 @@ def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx): videos_idx, ) - for k in videos_idx: - videos_idx[k]["latest_duration"] += videos_idx[k]["episode_duration"] - meta_idx = append_or_create_parquet_file( df, src_path, @@ -403,6 +433,10 @@ def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx): aggr_root=dst_meta.root, ) + # Increment latest_duration by the total duration added from this source dataset + for k in videos_idx: + videos_idx[k]["latest_duration"] += videos_idx[k]["episode_duration"] + return meta_idx diff --git a/src/lerobot/datasets/dataset_tools.py b/src/lerobot/datasets/dataset_tools.py new file mode 100644 index 00000000000..fdeb24a729a --- /dev/null +++ b/src/lerobot/datasets/dataset_tools.py @@ -0,0 +1,1004 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Dataset tools utilities for LeRobotDataset. + +This module provides utilities for: +- Deleting episodes from datasets +- Splitting datasets into multiple smaller datasets +- Adding/removing features from datasets +- Merging datasets (wrapper around aggregate functionality) +""" + +import logging +import shutil +from collections.abc import Callable +from pathlib import Path + +import numpy as np +import pandas as pd +import torch +from tqdm import tqdm + +from lerobot.datasets.aggregate import aggregate_datasets +from lerobot.datasets.compute_stats import aggregate_stats +from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata +from lerobot.datasets.utils import ( + DEFAULT_CHUNK_SIZE, + DEFAULT_DATA_FILE_SIZE_IN_MB, + DEFAULT_DATA_PATH, + DEFAULT_EPISODES_PATH, + get_parquet_file_size_in_mb, + to_parquet_with_hf_images, + update_chunk_file_indices, + write_info, + write_stats, + write_tasks, +) +from lerobot.utils.constants import HF_LEROBOT_HOME + + +def _load_episode_with_stats(src_dataset: LeRobotDataset, episode_idx: int) -> dict: + """Load a single episode's metadata including stats from parquet file. + + Args: + src_dataset: Source dataset + episode_idx: Episode index to load + + Returns: + dict containing episode metadata and stats + """ + ep_meta = src_dataset.meta.episodes[episode_idx] + chunk_idx = ep_meta["meta/episodes/chunk_index"] + file_idx = ep_meta["meta/episodes/file_index"] + + parquet_path = src_dataset.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx) + df = pd.read_parquet(parquet_path) + + episode_row = df[df["episode_index"] == episode_idx].iloc[0] + + return episode_row.to_dict() + + +def delete_episodes( + dataset: LeRobotDataset, + episode_indices: list[int], + output_dir: str | Path | None = None, + repo_id: str | None = None, +) -> LeRobotDataset: + """Delete episodes from a LeRobotDataset and create a new dataset. + + Args: + dataset: The source LeRobotDataset. + episode_indices: List of episode indices to delete. + output_dir: Directory to save the new dataset. If None, uses default location. + repo_id: Repository ID for the new dataset. If None, appends "_modified" to original. + """ + if not episode_indices: + raise ValueError("No episodes to delete") + + valid_indices = set(range(dataset.meta.total_episodes)) + invalid = set(episode_indices) - valid_indices + if invalid: + raise ValueError(f"Invalid episode indices: {invalid}") + + logging.info(f"Deleting {len(episode_indices)} episodes from dataset") + + if repo_id is None: + repo_id = f"{dataset.repo_id}_modified" + output_dir = Path(output_dir) if output_dir is not None else HF_LEROBOT_HOME / repo_id + + episodes_to_keep = [i for i in range(dataset.meta.total_episodes) if i not in episode_indices] + if not episodes_to_keep: + raise ValueError("Cannot delete all episodes from dataset") + + new_meta = LeRobotDatasetMetadata.create( + repo_id=repo_id, + fps=dataset.meta.fps, + features=dataset.meta.features, + robot_type=dataset.meta.robot_type, + root=output_dir, + use_videos=len(dataset.meta.video_keys) > 0, + ) + + episode_mapping = {old_idx: new_idx for new_idx, old_idx in enumerate(episodes_to_keep)} + + video_metadata = None + if dataset.meta.video_keys: + video_metadata = _copy_and_reindex_videos(dataset, new_meta, episode_mapping) + + data_metadata = _copy_and_reindex_data(dataset, new_meta, episode_mapping) + + _copy_and_reindex_episodes_metadata(dataset, new_meta, episode_mapping, data_metadata, video_metadata) + + new_dataset = LeRobotDataset( + repo_id=repo_id, + root=output_dir, + image_transforms=dataset.image_transforms, + delta_timestamps=dataset.delta_timestamps, + tolerance_s=dataset.tolerance_s, + ) + + logging.info(f"Created new dataset with {len(episodes_to_keep)} episodes") + return new_dataset + + +def split_dataset( + dataset: LeRobotDataset, + splits: dict[str, float | list[int]], + output_dir: str | Path | None = None, +) -> dict[str, LeRobotDataset]: + """Split a LeRobotDataset into multiple smaller datasets. + + Args: + dataset: The source LeRobotDataset to split. + splits: Either a dict mapping split names to episode indices, or a dict mapping + split names to fractions (must sum to <= 1.0). + output_dir: Base directory for output datasets. If None, uses default location. + + Examples: + Split by specific episodes + splits = {"train": [0, 1, 2], "val": [3, 4]} + datasets = split_dataset(dataset, splits) + + Split by fractions + splits = {"train": 0.8, "val": 0.2} + datasets = split_dataset(dataset, splits) + """ + if not splits: + raise ValueError("No splits provided") + + if all(isinstance(v, float) for v in splits.values()): + splits = _fractions_to_episode_indices(dataset.meta.total_episodes, splits) + + all_episodes = set() + for split_name, episodes in splits.items(): + if not episodes: + raise ValueError(f"Split '{split_name}' has no episodes") + episode_set = set(episodes) + if episode_set & all_episodes: + raise ValueError("Episodes cannot appear in multiple splits") + all_episodes.update(episode_set) + + valid_indices = set(range(dataset.meta.total_episodes)) + invalid = all_episodes - valid_indices + if invalid: + raise ValueError(f"Invalid episode indices: {invalid}") + + if output_dir is not None: + output_dir = Path(output_dir) + + result_datasets = {} + + for split_name, episodes in splits.items(): + logging.info(f"Creating split '{split_name}' with {len(episodes)} episodes") + + split_repo_id = f"{dataset.repo_id}_{split_name}" + + split_output_dir = ( + output_dir / split_name if output_dir is not None else HF_LEROBOT_HOME / split_repo_id + ) + + episode_mapping = {old_idx: new_idx for new_idx, old_idx in enumerate(sorted(episodes))} + + new_meta = LeRobotDatasetMetadata.create( + repo_id=split_repo_id, + fps=dataset.meta.fps, + features=dataset.meta.features, + robot_type=dataset.meta.robot_type, + root=split_output_dir, + use_videos=len(dataset.meta.video_keys) > 0, + chunks_size=dataset.meta.chunks_size, + data_files_size_in_mb=dataset.meta.data_files_size_in_mb, + video_files_size_in_mb=dataset.meta.video_files_size_in_mb, + ) + + video_metadata = None + if dataset.meta.video_keys: + video_metadata = _copy_and_reindex_videos(dataset, new_meta, episode_mapping) + + data_metadata = _copy_and_reindex_data(dataset, new_meta, episode_mapping) + + _copy_and_reindex_episodes_metadata(dataset, new_meta, episode_mapping, data_metadata, video_metadata) + + new_dataset = LeRobotDataset( + repo_id=split_repo_id, + root=split_output_dir, + image_transforms=dataset.image_transforms, + delta_timestamps=dataset.delta_timestamps, + tolerance_s=dataset.tolerance_s, + ) + + result_datasets[split_name] = new_dataset + + return result_datasets + + +def merge_datasets( + datasets: list[LeRobotDataset], + output_repo_id: str, + output_dir: str | Path | None = None, +) -> LeRobotDataset: + """Merge multiple LeRobotDatasets into a single dataset. + + This is a wrapper around the aggregate_datasets functionality with a cleaner API. + + Args: + datasets: List of LeRobotDatasets to merge. + output_repo_id: Repository ID for the merged dataset. + output_dir: Directory to save the merged dataset. If None, uses default location. + """ + if not datasets: + raise ValueError("No datasets to merge") + + output_dir = Path(output_dir) if output_dir is not None else HF_LEROBOT_HOME / output_repo_id + + repo_ids = [ds.repo_id for ds in datasets] + roots = [ds.root for ds in datasets] + + aggregate_datasets( + repo_ids=repo_ids, + aggr_repo_id=output_repo_id, + roots=roots, + aggr_root=output_dir, + ) + + merged_dataset = LeRobotDataset( + repo_id=output_repo_id, + root=output_dir, + image_transforms=datasets[0].image_transforms, + delta_timestamps=datasets[0].delta_timestamps, + tolerance_s=datasets[0].tolerance_s, + ) + + return merged_dataset + + +def add_feature( + dataset: LeRobotDataset, + feature_name: str, + feature_values: np.ndarray | torch.Tensor | Callable, + feature_info: dict, + output_dir: str | Path | None = None, + repo_id: str | None = None, +) -> LeRobotDataset: + """Add a new feature to a LeRobotDataset. + + Args: + dataset: The source LeRobotDataset. + feature_name: Name of the new feature. + feature_values: Either: + - Array/tensor of shape (num_frames, ...) with values for each frame + - Callable that takes (frame_dict, episode_index, frame_index) and returns feature value + feature_info: Dictionary with feature metadata (dtype, shape, names). + output_dir: Directory to save the new dataset. If None, uses default location. + repo_id: Repository ID for the new dataset. If None, appends "_modified" to original. + """ + if feature_name in dataset.meta.features: + raise ValueError(f"Feature '{feature_name}' already exists in dataset") + + if repo_id is None: + repo_id = f"{dataset.repo_id}_modified" + output_dir = Path(output_dir) if output_dir is not None else HF_LEROBOT_HOME / repo_id + + required_keys = {"dtype", "shape"} + if not required_keys.issubset(feature_info.keys()): + raise ValueError(f"feature_info must contain keys: {required_keys}") + + new_features = dataset.meta.features.copy() + new_features[feature_name] = feature_info + + new_meta = LeRobotDatasetMetadata.create( + repo_id=repo_id, + fps=dataset.meta.fps, + features=new_features, + robot_type=dataset.meta.robot_type, + root=output_dir, + use_videos=len(dataset.meta.video_keys) > 0, + ) + + _copy_data_with_feature_changes( + dataset=dataset, + new_meta=new_meta, + add_features={feature_name: (feature_values, feature_info)}, + ) + + if dataset.meta.video_keys: + _copy_videos(dataset, new_meta) + + new_dataset = LeRobotDataset( + repo_id=repo_id, + root=output_dir, + image_transforms=dataset.image_transforms, + delta_timestamps=dataset.delta_timestamps, + tolerance_s=dataset.tolerance_s, + ) + + return new_dataset + + +def remove_feature( + dataset: LeRobotDataset, + feature_names: str | list[str], + output_dir: str | Path | None = None, + repo_id: str | None = None, +) -> LeRobotDataset: + """Remove features from a LeRobotDataset. + + Args: + dataset: The source LeRobotDataset. + feature_names: Name(s) of features to remove. Can be a single string or list. + output_dir: Directory to save the new dataset. If None, uses default location. + repo_id: Repository ID for the new dataset. If None, appends "_modified" to original. + + """ + if isinstance(feature_names, str): + feature_names = [feature_names] + + for name in feature_names: + if name not in dataset.meta.features: + raise ValueError(f"Feature '{name}' not found in dataset") + + required_features = {"timestamp", "frame_index", "episode_index", "index", "task_index"} + if any(name in required_features for name in feature_names): + raise ValueError(f"Cannot remove required features: {required_features}") + + if repo_id is None: + repo_id = f"{dataset.repo_id}_modified" + output_dir = Path(output_dir) if output_dir is not None else HF_LEROBOT_HOME / repo_id + + new_features = {k: v for k, v in dataset.meta.features.items() if k not in feature_names} + + video_keys_to_remove = [name for name in feature_names if name in dataset.meta.video_keys] + + remaining_video_keys = [k for k in dataset.meta.video_keys if k not in video_keys_to_remove] + + new_meta = LeRobotDatasetMetadata.create( + repo_id=repo_id, + fps=dataset.meta.fps, + features=new_features, + robot_type=dataset.meta.robot_type, + root=output_dir, + use_videos=len(remaining_video_keys) > 0, + ) + + _copy_data_with_feature_changes( + dataset=dataset, + new_meta=new_meta, + remove_features=feature_names, + ) + + if new_meta.video_keys: + _copy_videos(dataset, new_meta, exclude_keys=video_keys_to_remove) + + new_dataset = LeRobotDataset( + repo_id=repo_id, + root=output_dir, + image_transforms=dataset.image_transforms, + delta_timestamps=dataset.delta_timestamps, + tolerance_s=dataset.tolerance_s, + ) + + return new_dataset + + +def _fractions_to_episode_indices( + total_episodes: int, + splits: dict[str, float], +) -> dict[str, list[int]]: + """Convert split fractions to episode indices.""" + if sum(splits.values()) > 1.0: + raise ValueError("Split fractions must sum to <= 1.0") + + indices = list(range(total_episodes)) + result = {} + start_idx = 0 + + for split_name, fraction in splits.items(): + num_episodes = int(total_episodes * fraction) + if num_episodes == 0: + logging.warning(f"Split '{split_name}' has no episodes, skipping...") + continue + end_idx = start_idx + num_episodes + if split_name == list(splits.keys())[-1]: + end_idx = total_episodes + result[split_name] = indices[start_idx:end_idx] + start_idx = end_idx + + return result + + +def _copy_and_reindex_data( + src_dataset: LeRobotDataset, + dst_meta: LeRobotDatasetMetadata, + episode_mapping: dict[int, int], +) -> dict[int, dict]: + """Copy and filter data files, only modifying files with deleted episodes. + + Args: + src_dataset: Source dataset to copy from + dst_meta: Destination metadata object + episode_mapping: Mapping from old episode indices to new indices + + Returns: + dict mapping episode index to its data file metadata (chunk_index, file_index, etc.) + """ + file_to_episodes: dict[Path, set[int]] = {} + for old_idx in episode_mapping: + file_path = src_dataset.meta.get_data_file_path(old_idx) + if file_path not in file_to_episodes: + file_to_episodes[file_path] = set() + file_to_episodes[file_path].add(old_idx) + + global_index = 0 + episode_data_metadata: dict[int, dict] = {} + + if dst_meta.tasks is None: + all_task_indices = set() + for src_path in file_to_episodes: + df = pd.read_parquet(src_dataset.root / src_path) + mask = df["episode_index"].isin(list(episode_mapping.keys())) + task_series: pd.Series = df[mask]["task_index"] + all_task_indices.update(task_series.unique().tolist()) + tasks = [src_dataset.meta.tasks.iloc[idx].name for idx in all_task_indices] + dst_meta.save_episode_tasks(list(set(tasks))) + + task_mapping = {} + for old_task_idx in range(len(src_dataset.meta.tasks)): + task_name = src_dataset.meta.tasks.iloc[old_task_idx].name + new_task_idx = dst_meta.get_task_index(task_name) + if new_task_idx is not None: + task_mapping[old_task_idx] = new_task_idx + + for src_path in tqdm(sorted(file_to_episodes.keys()), desc="Processing data files"): + df = pd.read_parquet(src_dataset.root / src_path) + + all_episodes_in_file = set(df["episode_index"].unique()) + episodes_to_keep = file_to_episodes[src_path] + + if all_episodes_in_file == episodes_to_keep: + df["episode_index"] = df["episode_index"].replace(episode_mapping) + df["index"] = range(global_index, global_index + len(df)) + df["task_index"] = df["task_index"].replace(task_mapping) + + first_ep_old_idx = min(episodes_to_keep) + src_ep = src_dataset.meta.episodes[first_ep_old_idx] + chunk_idx = src_ep["data/chunk_index"] + file_idx = src_ep["data/file_index"] + else: + mask = df["episode_index"].isin(list(episode_mapping.keys())) + df = df[mask].copy().reset_index(drop=True) + + if len(df) == 0: + continue + + df["episode_index"] = df["episode_index"].replace(episode_mapping) + df["index"] = range(global_index, global_index + len(df)) + df["task_index"] = df["task_index"].replace(task_mapping) + + first_ep_old_idx = min(episodes_to_keep) + src_ep = src_dataset.meta.episodes[first_ep_old_idx] + chunk_idx = src_ep["data/chunk_index"] + file_idx = src_ep["data/file_index"] + + dst_path = dst_meta.root / DEFAULT_DATA_PATH.format(chunk_index=chunk_idx, file_index=file_idx) + dst_path.parent.mkdir(parents=True, exist_ok=True) + + if len(dst_meta.image_keys) > 0: + to_parquet_with_hf_images(df, dst_path) + else: + df.to_parquet(dst_path, index=False) + + for ep_old_idx in episodes_to_keep: + ep_new_idx = episode_mapping[ep_old_idx] + ep_df = df[df["episode_index"] == ep_new_idx] + episode_data_metadata[ep_new_idx] = { + "data/chunk_index": chunk_idx, + "data/file_index": file_idx, + "dataset_from_index": int(ep_df["index"].min()), + "dataset_to_index": int(ep_df["index"].max() + 1), + } + + global_index += len(df) + + return episode_data_metadata + + +def _keep_episodes_from_video_with_av( + input_path: Path, + output_path: Path, + episodes_to_keep: list[tuple[float, float]], + fps: float, + vcodec: str = "libsvtav1", + pix_fmt: str = "yuv420p", +) -> None: + """Keep only specified episodes from a video file using PyAV. + + This function decodes frames from specified time ranges and re-encodes them with + properly reset timestamps to ensure monotonic progression. + + Args: + input_path: Source video file path. + output_path: Destination video file path. + episodes_to_keep: List of (start_time, end_time) tuples for episodes to keep. + fps: Frame rate of the video. + vcodec: Video codec to use for encoding. + pix_fmt: Pixel format for output video. + """ + from fractions import Fraction + + import av + + if not episodes_to_keep: + raise ValueError("No episodes to keep") + + in_container = av.open(str(input_path)) + + # Check if video stream exists. + if not in_container.streams.video: + raise ValueError( + f"No video streams found in {input_path}. " + "The video file may be corrupted or empty. " + "Try re-downloading the dataset or checking the video file." + ) + + v_in = in_container.streams.video[0] + + out = av.open(str(output_path), mode="w") + + # Convert fps to Fraction for PyAV compatibility. + fps_fraction = Fraction(fps).limit_denominator(1000) + v_out = out.add_stream(vcodec, rate=fps_fraction) + + # PyAV type stubs don't distinguish video streams from audio/subtitle streams. + v_out.width = v_in.codec_context.width + v_out.height = v_in.codec_context.height + v_out.pix_fmt = pix_fmt + + # Set time_base to match the frame rate for proper timestamp handling. + v_out.time_base = Fraction(1, int(fps)) + + out.start_encoding() + + # Create set of (start, end) ranges for fast lookup. + # Convert to a sorted list for efficient checking. + time_ranges = sorted(episodes_to_keep) + + # Track frame index for setting PTS and current range being processed. + frame_count = 0 + range_idx = 0 + + # Read through entire video once and filter frames. + for packet in in_container.demux(v_in): + for frame in packet.decode(): + if frame is None: + continue + + # Get frame timestamp. + frame_time = float(frame.pts * frame.time_base) if frame.pts is not None else 0.0 + + # Check if frame is in any of our desired time ranges. + # Skip ranges that have already passed. + while range_idx < len(time_ranges) and frame_time >= time_ranges[range_idx][1]: + range_idx += 1 + + # If we've passed all ranges, stop processing. + if range_idx >= len(time_ranges): + break + + # Check if frame is in current range. + start_ts, end_ts = time_ranges[range_idx] + if frame_time < start_ts: + continue + + # Frame is in range - create a new frame with reset timestamps. + # We need to create a copy to avoid modifying the original. + new_frame = frame.reformat(width=v_out.width, height=v_out.height, format=v_out.pix_fmt) + new_frame.pts = frame_count + new_frame.time_base = Fraction(1, int(fps)) + + # Encode and mux the frame. + for pkt in v_out.encode(new_frame): + out.mux(pkt) + + frame_count += 1 + + # Flush encoder. + for pkt in v_out.encode(): + out.mux(pkt) + + out.close() + in_container.close() + + +def _copy_and_reindex_videos( + src_dataset: LeRobotDataset, + dst_meta: LeRobotDatasetMetadata, + episode_mapping: dict[int, int], + vcodec: str = "libsvtav1", + pix_fmt: str = "yuv420p", +) -> dict[int, dict]: + """Copy and filter video files, only re-encoding files with deleted episodes. + + For video files that only contain kept episodes, we copy them directly. + For files with mixed kept/deleted episodes, we use PyAV filters to efficiently + re-encode only the desired segments. + + Args: + src_dataset: Source dataset to copy from + dst_meta: Destination metadata object + episode_mapping: Mapping from old episode indices to new indices + + Returns: + dict mapping episode index to its video metadata (chunk_index, file_index, timestamps) + """ + + episodes_video_metadata: dict[int, dict] = {new_idx: {} for new_idx in episode_mapping.values()} + + for video_key in src_dataset.meta.video_keys: + logging.info(f"Processing videos for {video_key}") + + if dst_meta.video_path is None: + raise ValueError("Destination metadata has no video_path defined") + + file_to_episodes: dict[tuple[int, int], list[int]] = {} + for old_idx in episode_mapping: + src_ep = src_dataset.meta.episodes[old_idx] + chunk_idx = src_ep[f"videos/{video_key}/chunk_index"] + file_idx = src_ep[f"videos/{video_key}/file_index"] + file_key = (chunk_idx, file_idx) + if file_key not in file_to_episodes: + file_to_episodes[file_key] = [] + file_to_episodes[file_key].append(old_idx) + + for (src_chunk_idx, src_file_idx), episodes_in_file in tqdm( + sorted(file_to_episodes.items()), desc=f"Processing {video_key} video files" + ): + all_episodes_in_file = [ + ep_idx + for ep_idx in range(src_dataset.meta.total_episodes) + if src_dataset.meta.episodes[ep_idx].get(f"videos/{video_key}/chunk_index") == src_chunk_idx + and src_dataset.meta.episodes[ep_idx].get(f"videos/{video_key}/file_index") == src_file_idx + ] + + episodes_to_keep_set = set(episodes_in_file) + all_in_file_set = set(all_episodes_in_file) + + if all_in_file_set == episodes_to_keep_set: + assert src_dataset.meta.video_path is not None + src_video_path = src_dataset.root / src_dataset.meta.video_path.format( + video_key=video_key, chunk_index=src_chunk_idx, file_index=src_file_idx + ) + dst_video_path = dst_meta.root / dst_meta.video_path.format( + video_key=video_key, chunk_index=src_chunk_idx, file_index=src_file_idx + ) + dst_video_path.parent.mkdir(parents=True, exist_ok=True) + shutil.copy(src_video_path, dst_video_path) + + for old_idx in episodes_in_file: + new_idx = episode_mapping[old_idx] + src_ep = src_dataset.meta.episodes[old_idx] + episodes_video_metadata[new_idx][f"videos/{video_key}/chunk_index"] = src_chunk_idx + episodes_video_metadata[new_idx][f"videos/{video_key}/file_index"] = src_file_idx + episodes_video_metadata[new_idx][f"videos/{video_key}/from_timestamp"] = src_ep[ + f"videos/{video_key}/from_timestamp" + ] + episodes_video_metadata[new_idx][f"videos/{video_key}/to_timestamp"] = src_ep[ + f"videos/{video_key}/to_timestamp" + ] + else: + # Build list of time ranges to keep, in sorted order. + sorted_keep_episodes = sorted(episodes_in_file, key=lambda x: episode_mapping[x]) + episodes_to_keep_ranges: list[tuple[float, float]] = [] + + for old_idx in sorted_keep_episodes: + src_ep = src_dataset.meta.episodes[old_idx] + from_ts = src_ep[f"videos/{video_key}/from_timestamp"] + to_ts = src_ep[f"videos/{video_key}/to_timestamp"] + episodes_to_keep_ranges.append((from_ts, to_ts)) + + # Use PyAV filters to efficiently re-encode only the desired segments. + assert src_dataset.meta.video_path is not None + src_video_path = src_dataset.root / src_dataset.meta.video_path.format( + video_key=video_key, chunk_index=src_chunk_idx, file_index=src_file_idx + ) + dst_video_path = dst_meta.root / dst_meta.video_path.format( + video_key=video_key, chunk_index=src_chunk_idx, file_index=src_file_idx + ) + dst_video_path.parent.mkdir(parents=True, exist_ok=True) + + logging.info( + f"Re-encoding {video_key} (chunk {src_chunk_idx}, file {src_file_idx}) " + f"with {len(episodes_to_keep_ranges)} episodes" + ) + _keep_episodes_from_video_with_av( + src_video_path, + dst_video_path, + episodes_to_keep_ranges, + src_dataset.meta.fps, + vcodec, + pix_fmt, + ) + + cumulative_ts = 0.0 + for old_idx in sorted_keep_episodes: + new_idx = episode_mapping[old_idx] + src_ep = src_dataset.meta.episodes[old_idx] + ep_length = src_ep["length"] + ep_duration = ep_length / src_dataset.meta.fps + + episodes_video_metadata[new_idx][f"videos/{video_key}/chunk_index"] = src_chunk_idx + episodes_video_metadata[new_idx][f"videos/{video_key}/file_index"] = src_file_idx + episodes_video_metadata[new_idx][f"videos/{video_key}/from_timestamp"] = cumulative_ts + episodes_video_metadata[new_idx][f"videos/{video_key}/to_timestamp"] = ( + cumulative_ts + ep_duration + ) + + cumulative_ts += ep_duration + + return episodes_video_metadata + + +def _copy_and_reindex_episodes_metadata( + src_dataset: LeRobotDataset, + dst_meta: LeRobotDatasetMetadata, + episode_mapping: dict[int, int], + data_metadata: dict[int, dict], + video_metadata: dict[int, dict] | None = None, +) -> None: + """Copy and reindex episodes metadata using provided data and video metadata. + + Args: + src_dataset: Source dataset to copy from + dst_meta: Destination metadata object + episode_mapping: Mapping from old episode indices to new indices + data_metadata: Dict mapping new episode index to its data file metadata + video_metadata: Optional dict mapping new episode index to its video metadata + """ + from lerobot.datasets.utils import flatten_dict + + all_stats = [] + total_frames = 0 + + for old_idx, new_idx in tqdm( + sorted(episode_mapping.items(), key=lambda x: x[1]), desc="Processing episodes metadata" + ): + src_episode_full = _load_episode_with_stats(src_dataset, old_idx) + + src_episode = src_dataset.meta.episodes[old_idx] + + episode_meta = data_metadata[new_idx].copy() + + if video_metadata and new_idx in video_metadata: + episode_meta.update(video_metadata[new_idx]) + + # Extract episode statistics from parquet metadata. + # Note (maractingi): When pandas/pyarrow serializes numpy arrays with shape (3, 1, 1) to parquet, + # they are being deserialized as nested object arrays like: + # array([array([array([0.])]), array([array([0.])]), array([array([0.])])]) + # This happens particularly with image/video statistics. We need to detect and flatten + # these nested structures back to proper (3, 1, 1) arrays so aggregate_stats can process them. + episode_stats = {} + for key in src_episode_full: + if key.startswith("stats/"): + stat_key = key.replace("stats/", "") + parts = stat_key.split("/") + if len(parts) == 2: + feature_name, stat_name = parts + if feature_name not in episode_stats: + episode_stats[feature_name] = {} + + value = src_episode_full[key] + + if feature_name in src_dataset.meta.features: + feature_dtype = src_dataset.meta.features[feature_name]["dtype"] + if feature_dtype in ["image", "video"] and stat_name != "count": + if isinstance(value, np.ndarray) and value.dtype == object: + flat_values = [] + for item in value: + while isinstance(item, np.ndarray): + item = item.flatten()[0] + flat_values.append(item) + value = np.array(flat_values, dtype=np.float64).reshape(3, 1, 1) + elif isinstance(value, np.ndarray) and value.shape == (3,): + value = value.reshape(3, 1, 1) + + episode_stats[feature_name][stat_name] = value + + all_stats.append(episode_stats) + + episode_dict = { + "episode_index": new_idx, + "tasks": src_episode["tasks"], + "length": src_episode["length"], + } + episode_dict.update(episode_meta) + episode_dict.update(flatten_dict({"stats": episode_stats})) + dst_meta._save_episode_metadata(episode_dict) + + total_frames += src_episode["length"] + + dst_meta.info.update( + { + "total_episodes": len(episode_mapping), + "total_frames": total_frames, + "total_tasks": len(dst_meta.tasks) if dst_meta.tasks is not None else 0, + "splits": {"train": f"0:{len(episode_mapping)}"}, + } + ) + write_info(dst_meta.info, dst_meta.root) + + if not all_stats: + logging.warning("No statistics found to aggregate") + return + + logging.info(f"Aggregating statistics for {len(all_stats)} episodes") + aggregated_stats = aggregate_stats(all_stats) + filtered_stats = {k: v for k, v in aggregated_stats.items() if k in dst_meta.features} + write_stats(filtered_stats, dst_meta.root) + + +def _save_data_chunk( + df: pd.DataFrame, + meta: LeRobotDatasetMetadata, + chunk_idx: int = 0, + file_idx: int = 0, +) -> tuple[int, int, dict[int, dict]]: + """Save a data chunk and return updated indices and episode metadata. + + Returns: + tuple: (next_chunk_idx, next_file_idx, episode_metadata_dict) + where episode_metadata_dict maps episode_index to its data file metadata + """ + path = meta.root / DEFAULT_DATA_PATH.format(chunk_index=chunk_idx, file_index=file_idx) + path.parent.mkdir(parents=True, exist_ok=True) + + if len(meta.image_keys) > 0: + to_parquet_with_hf_images(df, path) + else: + df.to_parquet(path, index=False) + + episode_metadata = {} + for ep_idx in df["episode_index"].unique(): + ep_df = df[df["episode_index"] == ep_idx] + episode_metadata[ep_idx] = { + "data/chunk_index": chunk_idx, + "data/file_index": file_idx, + "dataset_from_index": int(ep_df["index"].min()), + "dataset_to_index": int(ep_df["index"].max() + 1), + } + + file_size = get_parquet_file_size_in_mb(path) + if file_size >= DEFAULT_DATA_FILE_SIZE_IN_MB * 0.9: + chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, DEFAULT_CHUNK_SIZE) + + return chunk_idx, file_idx, episode_metadata + + +def _copy_data_with_feature_changes( + dataset: LeRobotDataset, + new_meta: LeRobotDatasetMetadata, + add_features: dict[str, tuple] | None = None, + remove_features: list[str] | None = None, +) -> None: + """Copy data while adding or removing features.""" + file_paths = set() + for ep_idx in range(dataset.meta.total_episodes): + file_paths.add(dataset.meta.get_data_file_path(ep_idx)) + + frame_idx = 0 + + for src_path in tqdm(sorted(file_paths), desc="Processing data files"): + df = pd.read_parquet(dataset.root / src_path).reset_index(drop=True) + + if remove_features: + df = df.drop(columns=remove_features, errors="ignore") + + if add_features: + for feature_name, (values, _) in add_features.items(): + if callable(values): + feature_values = [] + for _, row in df.iterrows(): + ep_idx = row["episode_index"] + frame_in_ep = row["frame_index"] + value = values(row.to_dict(), ep_idx, frame_in_ep) + if isinstance(value, np.ndarray) and value.size == 1: + value = value.item() + feature_values.append(value) + df[feature_name] = feature_values + else: + end_idx = frame_idx + len(df) + feature_slice = values[frame_idx:end_idx] + if len(feature_slice.shape) > 1 and feature_slice.shape[1] == 1: + df[feature_name] = feature_slice.flatten() + else: + df[feature_name] = feature_slice + frame_idx = end_idx + + _save_data_chunk(df, new_meta) + + _copy_episodes_metadata_and_stats(dataset, new_meta) + + +def _copy_videos( + src_dataset: LeRobotDataset, + dst_meta: LeRobotDatasetMetadata, + exclude_keys: list[str] | None = None, +) -> None: + """Copy video files, optionally excluding certain keys.""" + if exclude_keys is None: + exclude_keys = [] + + for video_key in src_dataset.meta.video_keys: + if video_key in exclude_keys: + continue + + video_files = set() + for ep_idx in range(len(src_dataset.meta.episodes)): + try: + video_files.add(src_dataset.meta.get_video_file_path(ep_idx, video_key)) + except KeyError: + continue + + for src_path in tqdm(sorted(video_files), desc=f"Copying {video_key} videos"): + dst_path = dst_meta.root / src_path + dst_path.parent.mkdir(parents=True, exist_ok=True) + shutil.copy(src_dataset.root / src_path, dst_path) + + +def _copy_episodes_metadata_and_stats( + src_dataset: LeRobotDataset, + dst_meta: LeRobotDatasetMetadata, +) -> None: + """Copy episodes metadata and recalculate stats.""" + if src_dataset.meta.tasks is not None: + write_tasks(src_dataset.meta.tasks, dst_meta.root) + dst_meta.tasks = src_dataset.meta.tasks.copy() + + episodes_dir = src_dataset.root / "meta/episodes" + dst_episodes_dir = dst_meta.root / "meta/episodes" + if episodes_dir.exists(): + shutil.copytree(episodes_dir, dst_episodes_dir, dirs_exist_ok=True) + + dst_meta.info.update( + { + "total_episodes": src_dataset.meta.total_episodes, + "total_frames": src_dataset.meta.total_frames, + "total_tasks": src_dataset.meta.total_tasks, + "splits": src_dataset.meta.info.get("splits", {"train": f"0:{src_dataset.meta.total_episodes}"}), + } + ) + + if dst_meta.video_keys and src_dataset.meta.video_keys: + for key in dst_meta.video_keys: + if key in src_dataset.meta.features: + dst_meta.info["features"][key]["info"] = src_dataset.meta.info["features"][key].get( + "info", {} + ) + + write_info(dst_meta.info, dst_meta.root) + + if set(dst_meta.features.keys()) != set(src_dataset.meta.features.keys()): + logging.info("Recalculating dataset statistics...") + if src_dataset.meta.stats: + new_stats = {} + for key in dst_meta.features: + if key in src_dataset.meta.stats: + new_stats[key] = src_dataset.meta.stats[key] + write_stats(new_stats, dst_meta.root) + else: + if src_dataset.meta.stats: + write_stats(src_dataset.meta.stats, dst_meta.root) diff --git a/src/lerobot/datasets/lerobot_dataset.py b/src/lerobot/datasets/lerobot_dataset.py index b661b21b038..229d376413a 100644 --- a/src/lerobot/datasets/lerobot_dataset.py +++ b/src/lerobot/datasets/lerobot_dataset.py @@ -438,6 +438,9 @@ def create( robot_type: str | None = None, root: str | Path | None = None, use_videos: bool = True, + chunks_size: int | None = None, + data_files_size_in_mb: int | None = None, + video_files_size_in_mb: int | None = None, ) -> "LeRobotDatasetMetadata": """Creates metadata for a LeRobotDataset.""" obj = cls.__new__(cls) @@ -452,7 +455,16 @@ def create( obj.tasks = None obj.episodes = None obj.stats = None - obj.info = create_empty_dataset_info(CODEBASE_VERSION, fps, features, use_videos, robot_type) + obj.info = create_empty_dataset_info( + CODEBASE_VERSION, + fps, + features, + use_videos, + robot_type, + chunks_size, + data_files_size_in_mb, + video_files_size_in_mb, + ) if len(obj.video_keys) > 0 and not use_videos: raise ValueError() write_json(obj.info, obj.root / INFO_PATH) diff --git a/src/lerobot/datasets/utils.py b/src/lerobot/datasets/utils.py index a2f2850141a..422a7010a6e 100644 --- a/src/lerobot/datasets/utils.py +++ b/src/lerobot/datasets/utils.py @@ -30,7 +30,7 @@ import pandas as pd import pyarrow.parquet as pq import torch -from datasets import Dataset, concatenate_datasets +from datasets import Dataset from datasets.table import embed_table_storage from huggingface_hub import DatasetCard, DatasetCardData, HfApi from huggingface_hub.errors import RevisionNotFoundError @@ -44,7 +44,7 @@ ForwardCompatibilityError, ) from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_STR -from lerobot.utils.utils import is_valid_numpy_dtype_string +from lerobot.utils.utils import SuppressProgressBars, is_valid_numpy_dtype_string DEFAULT_CHUNK_SIZE = 1000 # Max number of files per chunk DEFAULT_DATA_FILE_SIZE_IN_MB = 100 # Max size per file @@ -123,8 +123,9 @@ def load_nested_dataset(pq_dir: Path, features: datasets.Features | None = None) raise FileNotFoundError(f"Provided directory does not contain any parquet file: {pq_dir}") # TODO(rcadene): set num_proc to accelerate conversion to pyarrow - datasets = [Dataset.from_parquet(str(path), features=features) for path in paths] - return concatenate_datasets(datasets) + with SuppressProgressBars(): + datasets = Dataset.from_parquet([str(path) for path in paths], features=features) + return datasets def get_parquet_num_frames(parquet_path: str | Path) -> int: diff --git a/src/lerobot/datasets/video_utils.py b/src/lerobot/datasets/video_utils.py index 1d4f07c769a..620ba863aca 100644 --- a/src/lerobot/datasets/video_utils.py +++ b/src/lerobot/datasets/video_utils.py @@ -452,6 +452,9 @@ def concatenate_video_files( template=input_stream, opaque=True ) + # set the time base to the input stream time base (missing in the codec context) + stream_map[input_stream.index].time_base = input_stream.time_base + # Demux + remux packets (no re-encode) for packet in input_container.demux(): # Skip packets from un-mapped streams diff --git a/src/lerobot/scripts/lerobot_edit_dataset.py b/src/lerobot/scripts/lerobot_edit_dataset.py new file mode 100644 index 00000000000..83ba027bcce --- /dev/null +++ b/src/lerobot/scripts/lerobot_edit_dataset.py @@ -0,0 +1,286 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Edit LeRobot datasets using various transformation tools. + +This script allows you to delete episodes, split datasets, merge datasets, +and remove features. When new_repo_id is specified, creates a new dataset. + +Usage Examples: + +Delete episodes 0, 2, and 5 from a dataset: + python -m lerobot.scripts.lerobot_edit_dataset \ + --repo_id lerobot/pusht \ + --operation.type delete_episodes \ + --operation.episode_indices "[0, 2, 5]" + +Delete episodes and save to a new dataset: + python -m lerobot.scripts.lerobot_edit_dataset \ + --repo_id lerobot/pusht \ + --new_repo_id lerobot/pusht_filtered \ + --operation.type delete_episodes \ + --operation.episode_indices "[0, 2, 5]" + +Split dataset by fractions: + python -m lerobot.scripts.lerobot_edit_dataset \ + --repo_id lerobot/pusht \ + --operation.type split \ + --operation.splits '{"train": 0.8, "val": 0.2}' + +Split dataset by episode indices: + python -m lerobot.scripts.lerobot_edit_dataset \ + --repo_id lerobot/pusht \ + --operation.type split \ + --operation.splits '{"train": [0, 1, 2, 3], "val": [4, 5]}' + +Split into more than two splits: + python -m lerobot.scripts.lerobot_edit_dataset \ + --repo_id lerobot/pusht \ + --operation.type split \ + --operation.splits '{"train": 0.6, "val": 0.2, "test": 0.2}' + +Merge multiple datasets: + python -m lerobot.scripts.lerobot_edit_dataset \ + --repo_id lerobot/pusht_merged \ + --operation.type merge \ + --operation.repo_ids "['lerobot/pusht_train', 'lerobot/pusht_val']" + +Remove camera feature: + python -m lerobot.scripts.lerobot_edit_dataset \ + --repo_id lerobot/pusht \ + --operation.type remove_feature \ + --operation.feature_names "['observation.images.top']" + +Using JSON config file: + python -m lerobot.scripts.lerobot_edit_dataset \ + --config_path path/to/edit_config.json +""" + +import logging +import shutil +from dataclasses import dataclass +from pathlib import Path + +from lerobot.configs import parser +from lerobot.datasets.dataset_tools import ( + delete_episodes, + merge_datasets, + remove_feature, + split_dataset, +) +from lerobot.datasets.lerobot_dataset import LeRobotDataset +from lerobot.utils.constants import HF_LEROBOT_HOME +from lerobot.utils.utils import init_logging + + +@dataclass +class DeleteEpisodesConfig: + type: str = "delete_episodes" + episode_indices: list[int] | None = None + + +@dataclass +class SplitConfig: + type: str = "split" + splits: dict[str, float | list[int]] | None = None + + +@dataclass +class MergeConfig: + type: str = "merge" + repo_ids: list[str] | None = None + + +@dataclass +class RemoveFeatureConfig: + type: str = "remove_feature" + feature_names: list[str] | None = None + + +@dataclass +class EditDatasetConfig: + repo_id: str + operation: DeleteEpisodesConfig | SplitConfig | MergeConfig | RemoveFeatureConfig + root: str | None = None + new_repo_id: str | None = None + push_to_hub: bool = False + + +def get_output_path(repo_id: str, new_repo_id: str | None, root: Path | None) -> tuple[str, Path]: + if new_repo_id: + output_repo_id = new_repo_id + output_dir = root / new_repo_id if root else HF_LEROBOT_HOME / new_repo_id + else: + output_repo_id = repo_id + dataset_path = root / repo_id if root else HF_LEROBOT_HOME / repo_id + old_path = Path(str(dataset_path) + "_old") + + if dataset_path.exists(): + if old_path.exists(): + shutil.rmtree(old_path) + shutil.move(str(dataset_path), str(old_path)) + + output_dir = dataset_path + + return output_repo_id, output_dir + + +def handle_delete_episodes(cfg: EditDatasetConfig) -> None: + if not isinstance(cfg.operation, DeleteEpisodesConfig): + raise ValueError("Operation config must be DeleteEpisodesConfig") + + if not cfg.operation.episode_indices: + raise ValueError("episode_indices must be specified for delete_episodes operation") + + dataset = LeRobotDataset(cfg.repo_id, root=cfg.root) + output_repo_id, output_dir = get_output_path( + cfg.repo_id, cfg.new_repo_id, Path(cfg.root) if cfg.root else None + ) + + if cfg.new_repo_id is None: + dataset.root = Path(str(dataset.root) + "_old") + + logging.info(f"Deleting episodes {cfg.operation.episode_indices} from {cfg.repo_id}") + new_dataset = delete_episodes( + dataset, + episode_indices=cfg.operation.episode_indices, + output_dir=output_dir, + repo_id=output_repo_id, + ) + + logging.info(f"Dataset saved to {output_dir}") + logging.info(f"Episodes: {new_dataset.meta.total_episodes}, Frames: {new_dataset.meta.total_frames}") + + if cfg.push_to_hub: + logging.info(f"Pushing to hub as {output_repo_id}") + LeRobotDataset(output_repo_id, root=output_dir).push_to_hub() + + +def handle_split(cfg: EditDatasetConfig) -> None: + if not isinstance(cfg.operation, SplitConfig): + raise ValueError("Operation config must be SplitConfig") + + if not cfg.operation.splits: + raise ValueError( + "splits dict must be specified with split names as keys and fractions/episode lists as values" + ) + + dataset = LeRobotDataset(cfg.repo_id, root=cfg.root) + + logging.info(f"Splitting dataset {cfg.repo_id} with splits: {cfg.operation.splits}") + split_datasets = split_dataset(dataset, splits=cfg.operation.splits) + + for split_name, split_ds in split_datasets.items(): + split_repo_id = f"{cfg.repo_id}_{split_name}" + logging.info( + f"{split_name}: {split_ds.meta.total_episodes} episodes, {split_ds.meta.total_frames} frames" + ) + + if cfg.push_to_hub: + logging.info(f"Pushing {split_name} split to hub as {split_repo_id}") + LeRobotDataset(split_ds.repo_id, root=split_ds.root).push_to_hub() + + +def handle_merge(cfg: EditDatasetConfig) -> None: + if not isinstance(cfg.operation, MergeConfig): + raise ValueError("Operation config must be MergeConfig") + + if not cfg.operation.repo_ids: + raise ValueError("repo_ids must be specified for merge operation") + + if not cfg.repo_id: + raise ValueError("repo_id must be specified as the output repository for merged dataset") + + logging.info(f"Loading {len(cfg.operation.repo_ids)} datasets to merge") + datasets = [LeRobotDataset(repo_id, root=cfg.root) for repo_id in cfg.operation.repo_ids] + + output_dir = Path(cfg.root) / cfg.repo_id if cfg.root else HF_LEROBOT_HOME / cfg.repo_id + + logging.info(f"Merging datasets into {cfg.repo_id}") + merged_dataset = merge_datasets( + datasets, + output_repo_id=cfg.repo_id, + output_dir=output_dir, + ) + + logging.info(f"Merged dataset saved to {output_dir}") + logging.info( + f"Episodes: {merged_dataset.meta.total_episodes}, Frames: {merged_dataset.meta.total_frames}" + ) + + if cfg.push_to_hub: + logging.info(f"Pushing to hub as {cfg.repo_id}") + LeRobotDataset(merged_dataset.repo_id, root=output_dir).push_to_hub() + + +def handle_remove_feature(cfg: EditDatasetConfig) -> None: + if not isinstance(cfg.operation, RemoveFeatureConfig): + raise ValueError("Operation config must be RemoveFeatureConfig") + + if not cfg.operation.feature_names: + raise ValueError("feature_names must be specified for remove_feature operation") + + dataset = LeRobotDataset(cfg.repo_id, root=cfg.root) + output_repo_id, output_dir = get_output_path( + cfg.repo_id, cfg.new_repo_id, Path(cfg.root) if cfg.root else None + ) + + if cfg.new_repo_id is None: + dataset.root = Path(str(dataset.root) + "_old") + + logging.info(f"Removing features {cfg.operation.feature_names} from {cfg.repo_id}") + new_dataset = remove_feature( + dataset, + feature_names=cfg.operation.feature_names, + output_dir=output_dir, + repo_id=output_repo_id, + ) + + logging.info(f"Dataset saved to {output_dir}") + logging.info(f"Remaining features: {list(new_dataset.meta.features.keys())}") + + if cfg.push_to_hub: + logging.info(f"Pushing to hub as {output_repo_id}") + LeRobotDataset(output_repo_id, root=output_dir).push_to_hub() + + +@parser.wrap() +def edit_dataset(cfg: EditDatasetConfig) -> None: + operation_type = cfg.operation.type + + if operation_type == "delete_episodes": + handle_delete_episodes(cfg) + elif operation_type == "split": + handle_split(cfg) + elif operation_type == "merge": + handle_merge(cfg) + elif operation_type == "remove_feature": + handle_remove_feature(cfg) + else: + raise ValueError( + f"Unknown operation type: {operation_type}\n" + f"Available operations: delete_episodes, split, merge, remove_feature" + ) + + +def main() -> None: + init_logging() + edit_dataset() + + +if __name__ == "__main__": + main() diff --git a/src/lerobot/utils/utils.py b/src/lerobot/utils/utils.py index 8777d5a9db8..dfcd4a6b105 100644 --- a/src/lerobot/utils/utils.py +++ b/src/lerobot/utils/utils.py @@ -27,6 +27,7 @@ import numpy as np import torch +from datasets.utils.logging import disable_progress_bar, enable_progress_bar def inside_slurm(): @@ -247,6 +248,25 @@ def get_elapsed_time_in_days_hours_minutes_seconds(elapsed_time_s: float): return days, hours, minutes, seconds +class SuppressProgressBars: + """ + Context manager to suppress progress bars. + + Example + -------- + ```python + with SuppressProgressBars(): + # Code that would normally show progress bars + ``` + """ + + def __enter__(self): + disable_progress_bar() + + def __exit__(self, exc_type, exc_val, exc_tb): + enable_progress_bar() + + class TimerManager: """ Lightweight utility to measure elapsed time. diff --git a/tests/datasets/test_aggregate.py b/tests/datasets/test_aggregate.py index 4f316f80eb2..b710a3a4be2 100644 --- a/tests/datasets/test_aggregate.py +++ b/tests/datasets/test_aggregate.py @@ -181,6 +181,54 @@ def assert_dataset_iteration_works(aggr_ds): pass +def assert_video_timestamps_within_bounds(aggr_ds): + """Test that all video timestamps are within valid bounds for their respective video files. + + This catches bugs where timestamps point to frames beyond the actual video length, + which would cause "Invalid frame index" errors during data loading. + """ + try: + from torchcodec.decoders import VideoDecoder + except ImportError: + return + + for ep_idx in range(aggr_ds.num_episodes): + ep = aggr_ds.meta.episodes[ep_idx] + + for vid_key in aggr_ds.meta.video_keys: + from_ts = ep[f"videos/{vid_key}/from_timestamp"] + to_ts = ep[f"videos/{vid_key}/to_timestamp"] + video_path = aggr_ds.root / aggr_ds.meta.get_video_file_path(ep_idx, vid_key) + + if not video_path.exists(): + continue + + from_frame_idx = round(from_ts * aggr_ds.fps) + to_frame_idx = round(to_ts * aggr_ds.fps) + + try: + decoder = VideoDecoder(str(video_path)) + num_frames = len(decoder) + + # Verify timestamps don't exceed video bounds + assert from_frame_idx >= 0, ( + f"Episode {ep_idx}, {vid_key}: from_frame_idx ({from_frame_idx}) < 0" + ) + assert from_frame_idx < num_frames, ( + f"Episode {ep_idx}, {vid_key}: from_frame_idx ({from_frame_idx}) >= video frames ({num_frames})" + ) + assert to_frame_idx <= num_frames, ( + f"Episode {ep_idx}, {vid_key}: to_frame_idx ({to_frame_idx}) > video frames ({num_frames})" + ) + assert from_frame_idx < to_frame_idx, ( + f"Episode {ep_idx}, {vid_key}: from_frame_idx ({from_frame_idx}) >= to_frame_idx ({to_frame_idx})" + ) + except Exception as e: + raise AssertionError( + f"Failed to verify timestamps for episode {ep_idx}, {vid_key}: {e}" + ) from e + + def test_aggregate_datasets(tmp_path, lerobot_dataset_factory): """Test basic aggregation functionality with standard parameters.""" ds_0_num_frames = 400 @@ -227,6 +275,7 @@ def test_aggregate_datasets(tmp_path, lerobot_dataset_factory): assert_metadata_consistency(aggr_ds, ds_0, ds_1) assert_episode_indices_updated_correctly(aggr_ds, ds_0, ds_1) assert_video_frames_integrity(aggr_ds, ds_0, ds_1) + assert_video_timestamps_within_bounds(aggr_ds) assert_dataset_iteration_works(aggr_ds) @@ -277,6 +326,7 @@ def test_aggregate_with_low_threshold(tmp_path, lerobot_dataset_factory): assert_metadata_consistency(aggr_ds, ds_0, ds_1) assert_episode_indices_updated_correctly(aggr_ds, ds_0, ds_1) assert_video_frames_integrity(aggr_ds, ds_0, ds_1) + assert_video_timestamps_within_bounds(aggr_ds) assert_dataset_iteration_works(aggr_ds) # Check that multiple files were actually created due to small size limits @@ -290,3 +340,43 @@ def test_aggregate_with_low_threshold(tmp_path, lerobot_dataset_factory): if video_dir.exists(): video_files = list(video_dir.rglob("*.mp4")) assert len(video_files) > 1, "Small file size limits should create multiple video files" + + +def test_video_timestamps_regression(tmp_path, lerobot_dataset_factory): + """Regression test for video timestamp bug when merging datasets. + + This test specifically checks that video timestamps are correctly calculated + and accumulated when merging multiple datasets. + """ + datasets = [] + for i in range(3): + ds = lerobot_dataset_factory( + root=tmp_path / f"regression_{i}", + repo_id=f"{DUMMY_REPO_ID}_regression_{i}", + total_episodes=2, + total_frames=100, + ) + datasets.append(ds) + + aggregate_datasets( + repo_ids=[ds.repo_id for ds in datasets], + roots=[ds.root for ds in datasets], + aggr_repo_id=f"{DUMMY_REPO_ID}_regression_aggr", + aggr_root=tmp_path / "regression_aggr", + ) + + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + mock_snapshot_download.return_value = str(tmp_path / "regression_aggr") + aggr_ds = LeRobotDataset(f"{DUMMY_REPO_ID}_regression_aggr", root=tmp_path / "regression_aggr") + + assert_video_timestamps_within_bounds(aggr_ds) + + for i in range(len(aggr_ds)): + item = aggr_ds[i] + for key in aggr_ds.meta.video_keys: + assert key in item, f"Video key {key} missing from item {i}" + assert item[key].shape[0] == 3, f"Expected 3 channels for video key {key}" diff --git a/tests/datasets/test_dataset_tools.py b/tests/datasets/test_dataset_tools.py new file mode 100644 index 00000000000..fe117b35b80 --- /dev/null +++ b/tests/datasets/test_dataset_tools.py @@ -0,0 +1,891 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for dataset tools utilities.""" + +from unittest.mock import patch + +import numpy as np +import pytest +import torch + +from lerobot.datasets.dataset_tools import ( + add_feature, + delete_episodes, + merge_datasets, + remove_feature, + split_dataset, +) + + +@pytest.fixture +def sample_dataset(tmp_path, empty_lerobot_dataset_factory): + """Create a sample dataset for testing.""" + features = { + "action": {"dtype": "float32", "shape": (6,), "names": None}, + "observation.state": {"dtype": "float32", "shape": (4,), "names": None}, + "observation.images.top": {"dtype": "image", "shape": (224, 224, 3), "names": None}, + } + + dataset = empty_lerobot_dataset_factory( + root=tmp_path / "test_dataset", + features=features, + ) + + for ep_idx in range(5): + for _ in range(10): + frame = { + "action": np.random.randn(6).astype(np.float32), + "observation.state": np.random.randn(4).astype(np.float32), + "observation.images.top": np.random.randint(0, 255, size=(224, 224, 3), dtype=np.uint8), + "task": f"task_{ep_idx % 2}", + } + dataset.add_frame(frame) + dataset.save_episode() + + return dataset + + +def test_delete_single_episode(sample_dataset, tmp_path): + """Test deleting a single episode.""" + output_dir = tmp_path / "filtered" + + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + mock_snapshot_download.return_value = str(output_dir) + + new_dataset = delete_episodes( + sample_dataset, + episode_indices=[2], + output_dir=output_dir, + ) + + assert new_dataset.meta.total_episodes == 4 + assert new_dataset.meta.total_frames == 40 + + episode_indices = {int(idx.item()) for idx in new_dataset.hf_dataset["episode_index"]} + assert episode_indices == {0, 1, 2, 3} + + assert len(new_dataset) == 40 + + +def test_delete_multiple_episodes(sample_dataset, tmp_path): + """Test deleting multiple episodes.""" + output_dir = tmp_path / "filtered" + + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + mock_snapshot_download.return_value = str(output_dir) + + new_dataset = delete_episodes( + sample_dataset, + episode_indices=[1, 3], + output_dir=output_dir, + ) + + assert new_dataset.meta.total_episodes == 3 + assert new_dataset.meta.total_frames == 30 + + episode_indices = {int(idx.item()) for idx in new_dataset.hf_dataset["episode_index"]} + assert episode_indices == {0, 1, 2} + + +def test_delete_invalid_episodes(sample_dataset, tmp_path): + """Test error handling for invalid episode indices.""" + with pytest.raises(ValueError, match="Invalid episode indices"): + delete_episodes( + sample_dataset, + episode_indices=[10, 20], + output_dir=tmp_path / "filtered", + ) + + +def test_delete_all_episodes(sample_dataset, tmp_path): + """Test error when trying to delete all episodes.""" + with pytest.raises(ValueError, match="Cannot delete all episodes"): + delete_episodes( + sample_dataset, + episode_indices=list(range(5)), + output_dir=tmp_path / "filtered", + ) + + +def test_delete_empty_list(sample_dataset, tmp_path): + """Test error when no episodes specified.""" + with pytest.raises(ValueError, match="No episodes to delete"): + delete_episodes( + sample_dataset, + episode_indices=[], + output_dir=tmp_path / "filtered", + ) + + +def test_split_by_episodes(sample_dataset, tmp_path): + """Test splitting dataset by specific episode indices.""" + splits = { + "train": [0, 1, 2], + "val": [3, 4], + } + + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + + def mock_snapshot(repo_id, **kwargs): + if "train" in repo_id: + return str(tmp_path / f"{sample_dataset.repo_id}_train") + elif "val" in repo_id: + return str(tmp_path / f"{sample_dataset.repo_id}_val") + return str(kwargs.get("local_dir", tmp_path)) + + mock_snapshot_download.side_effect = mock_snapshot + + result = split_dataset( + sample_dataset, + splits=splits, + output_dir=tmp_path, + ) + + assert set(result.keys()) == {"train", "val"} + + assert result["train"].meta.total_episodes == 3 + assert result["train"].meta.total_frames == 30 + + assert result["val"].meta.total_episodes == 2 + assert result["val"].meta.total_frames == 20 + + train_episodes = {int(idx.item()) for idx in result["train"].hf_dataset["episode_index"]} + assert train_episodes == {0, 1, 2} + + val_episodes = {int(idx.item()) for idx in result["val"].hf_dataset["episode_index"]} + assert val_episodes == {0, 1} + + +def test_split_by_fractions(sample_dataset, tmp_path): + """Test splitting dataset by fractions.""" + splits = { + "train": 0.6, + "val": 0.4, + } + + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + + def mock_snapshot(repo_id, **kwargs): + for split_name in splits: + if split_name in repo_id: + return str(tmp_path / f"{sample_dataset.repo_id}_{split_name}") + return str(kwargs.get("local_dir", tmp_path)) + + mock_snapshot_download.side_effect = mock_snapshot + + result = split_dataset( + sample_dataset, + splits=splits, + output_dir=tmp_path, + ) + + assert result["train"].meta.total_episodes == 3 + assert result["val"].meta.total_episodes == 2 + + +def test_split_overlapping_episodes(sample_dataset, tmp_path): + """Test error when episodes appear in multiple splits.""" + splits = { + "train": [0, 1, 2], + "val": [2, 3, 4], + } + + with pytest.raises(ValueError, match="Episodes cannot appear in multiple splits"): + split_dataset(sample_dataset, splits=splits, output_dir=tmp_path) + + +def test_split_invalid_fractions(sample_dataset, tmp_path): + """Test error when fractions sum to more than 1.""" + splits = { + "train": 0.7, + "val": 0.5, + } + + with pytest.raises(ValueError, match="Split fractions must sum to <= 1.0"): + split_dataset(sample_dataset, splits=splits, output_dir=tmp_path) + + +def test_split_empty(sample_dataset, tmp_path): + """Test error with empty splits.""" + with pytest.raises(ValueError, match="No splits provided"): + split_dataset(sample_dataset, splits={}, output_dir=tmp_path) + + +def test_merge_two_datasets(sample_dataset, tmp_path, empty_lerobot_dataset_factory): + """Test merging two datasets.""" + features = { + "action": {"dtype": "float32", "shape": (6,), "names": None}, + "observation.state": {"dtype": "float32", "shape": (4,), "names": None}, + "observation.images.top": {"dtype": "image", "shape": (224, 224, 3), "names": None}, + } + + dataset2 = empty_lerobot_dataset_factory( + root=tmp_path / "test_dataset2", + features=features, + ) + + for ep_idx in range(3): + for _ in range(10): + frame = { + "action": np.random.randn(6).astype(np.float32), + "observation.state": np.random.randn(4).astype(np.float32), + "observation.images.top": np.random.randint(0, 255, size=(224, 224, 3), dtype=np.uint8), + "task": f"task_{ep_idx % 2}", + } + dataset2.add_frame(frame) + dataset2.save_episode() + + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + mock_snapshot_download.return_value = str(tmp_path / "merged_dataset") + + merged = merge_datasets( + [sample_dataset, dataset2], + output_repo_id="merged_dataset", + output_dir=tmp_path / "merged_dataset", + ) + + assert merged.meta.total_episodes == 8 # 5 + 3 + assert merged.meta.total_frames == 80 # 50 + 30 + + episode_indices = sorted({int(idx.item()) for idx in merged.hf_dataset["episode_index"]}) + assert episode_indices == list(range(8)) + + +def test_merge_empty_list(tmp_path): + """Test error when merging empty list.""" + with pytest.raises(ValueError, match="No datasets to merge"): + merge_datasets([], output_repo_id="merged", output_dir=tmp_path) + + +def test_add_feature_with_values(sample_dataset, tmp_path): + """Test adding a feature with pre-computed values.""" + num_frames = sample_dataset.meta.total_frames + reward_values = np.random.randn(num_frames, 1).astype(np.float32) + + feature_info = { + "dtype": "float32", + "shape": (1,), + "names": None, + } + + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + mock_snapshot_download.return_value = str(tmp_path / "with_reward") + + new_dataset = add_feature( + sample_dataset, + feature_name="reward", + feature_values=reward_values, + feature_info=feature_info, + output_dir=tmp_path / "with_reward", + ) + + assert "reward" in new_dataset.meta.features + assert new_dataset.meta.features["reward"] == feature_info + + assert len(new_dataset) == num_frames + sample_item = new_dataset[0] + assert "reward" in sample_item + assert isinstance(sample_item["reward"], torch.Tensor) + + +def test_add_feature_with_callable(sample_dataset, tmp_path): + """Test adding a feature with a callable.""" + + def compute_reward(frame_dict, episode_idx, frame_idx): + return float(episode_idx * 10 + frame_idx) + + feature_info = { + "dtype": "float32", + "shape": (1,), + "names": None, + } + + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + mock_snapshot_download.return_value = str(tmp_path / "with_reward") + + new_dataset = add_feature( + sample_dataset, + feature_name="reward", + feature_values=compute_reward, + feature_info=feature_info, + output_dir=tmp_path / "with_reward", + ) + + assert "reward" in new_dataset.meta.features + + items = [new_dataset[i] for i in range(10)] + first_episode_items = [item for item in items if item["episode_index"] == 0] + assert len(first_episode_items) == 10 + + first_frame = first_episode_items[0] + assert first_frame["frame_index"] == 0 + assert float(first_frame["reward"]) == 0.0 + + +def test_add_existing_feature(sample_dataset, tmp_path): + """Test error when adding an existing feature.""" + feature_info = {"dtype": "float32", "shape": (1,)} + + with pytest.raises(ValueError, match="Feature 'action' already exists"): + add_feature( + sample_dataset, + feature_name="action", + feature_values=np.zeros(50), + feature_info=feature_info, + output_dir=tmp_path / "modified", + ) + + +def test_add_feature_invalid_info(sample_dataset, tmp_path): + """Test error with invalid feature info.""" + with pytest.raises(ValueError, match="feature_info must contain keys"): + add_feature( + sample_dataset, + feature_name="reward", + feature_values=np.zeros(50), + feature_info={"dtype": "float32"}, + output_dir=tmp_path / "modified", + ) + + +def test_remove_single_feature(sample_dataset, tmp_path): + """Test removing a single feature.""" + feature_info = {"dtype": "float32", "shape": (1,), "names": None} + + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + mock_snapshot_download.side_effect = lambda repo_id, **kwargs: str(kwargs.get("local_dir", tmp_path)) + + dataset_with_reward = add_feature( + sample_dataset, + feature_name="reward", + feature_values=np.random.randn(50, 1).astype(np.float32), + feature_info=feature_info, + output_dir=tmp_path / "with_reward", + ) + + dataset_without_reward = remove_feature( + dataset_with_reward, + feature_names="reward", + output_dir=tmp_path / "without_reward", + ) + + assert "reward" not in dataset_without_reward.meta.features + + sample_item = dataset_without_reward[0] + assert "reward" not in sample_item + + +def test_remove_multiple_features(sample_dataset, tmp_path): + """Test removing multiple features at once.""" + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + mock_snapshot_download.side_effect = lambda repo_id, **kwargs: str(kwargs.get("local_dir", tmp_path)) + + dataset = sample_dataset + for feature_name in ["reward", "success"]: + feature_info = {"dtype": "float32", "shape": (1,), "names": None} + dataset = add_feature( + dataset, + feature_name=feature_name, + feature_values=np.random.randn(dataset.meta.total_frames, 1).astype(np.float32), + feature_info=feature_info, + output_dir=tmp_path / f"with_{feature_name}", + ) + + dataset_clean = remove_feature( + dataset, + feature_names=["reward", "success"], + output_dir=tmp_path / "clean", + ) + + assert "reward" not in dataset_clean.meta.features + assert "success" not in dataset_clean.meta.features + + +def test_remove_nonexistent_feature(sample_dataset, tmp_path): + """Test error when removing non-existent feature.""" + with pytest.raises(ValueError, match="Feature 'nonexistent' not found"): + remove_feature( + sample_dataset, + feature_names="nonexistent", + output_dir=tmp_path / "modified", + ) + + +def test_remove_required_feature(sample_dataset, tmp_path): + """Test error when trying to remove required features.""" + with pytest.raises(ValueError, match="Cannot remove required features"): + remove_feature( + sample_dataset, + feature_names="timestamp", + output_dir=tmp_path / "modified", + ) + + +def test_remove_camera_feature(sample_dataset, tmp_path): + """Test removing a camera feature.""" + camera_keys = sample_dataset.meta.camera_keys + if not camera_keys: + pytest.skip("No camera keys in dataset") + + camera_to_remove = camera_keys[0] + + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + mock_snapshot_download.return_value = str(tmp_path / "without_camera") + + dataset_without_camera = remove_feature( + sample_dataset, + feature_names=camera_to_remove, + output_dir=tmp_path / "without_camera", + ) + + assert camera_to_remove not in dataset_without_camera.meta.features + assert camera_to_remove not in dataset_without_camera.meta.camera_keys + + sample_item = dataset_without_camera[0] + assert camera_to_remove not in sample_item + + +def test_complex_workflow_integration(sample_dataset, tmp_path): + """Test a complex workflow combining multiple operations.""" + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + mock_snapshot_download.side_effect = lambda repo_id, **kwargs: str(kwargs.get("local_dir", tmp_path)) + + dataset = add_feature( + sample_dataset, + feature_name="reward", + feature_values=np.random.randn(50, 1).astype(np.float32), + feature_info={"dtype": "float32", "shape": (1,), "names": None}, + output_dir=tmp_path / "step1", + ) + + dataset = delete_episodes( + dataset, + episode_indices=[2], + output_dir=tmp_path / "step2", + ) + + splits = split_dataset( + dataset, + splits={"train": 0.75, "val": 0.25}, + output_dir=tmp_path / "step3", + ) + + merged = merge_datasets( + list(splits.values()), + output_repo_id="final_dataset", + output_dir=tmp_path / "step4", + ) + + assert merged.meta.total_episodes == 4 + assert merged.meta.total_frames == 40 + assert "reward" in merged.meta.features + + assert len(merged) == 40 + sample_item = merged[0] + assert "reward" in sample_item + + +def test_delete_episodes_preserves_stats(sample_dataset, tmp_path): + """Test that deleting episodes preserves statistics correctly.""" + output_dir = tmp_path / "filtered" + + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + mock_snapshot_download.return_value = str(output_dir) + + new_dataset = delete_episodes( + sample_dataset, + episode_indices=[2], + output_dir=output_dir, + ) + + assert new_dataset.meta.stats is not None + for feature in ["action", "observation.state"]: + assert feature in new_dataset.meta.stats + assert "mean" in new_dataset.meta.stats[feature] + assert "std" in new_dataset.meta.stats[feature] + + +def test_delete_episodes_preserves_tasks(sample_dataset, tmp_path): + """Test that tasks are preserved correctly after deletion.""" + output_dir = tmp_path / "filtered" + + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + mock_snapshot_download.return_value = str(output_dir) + + new_dataset = delete_episodes( + sample_dataset, + episode_indices=[0], + output_dir=output_dir, + ) + + assert new_dataset.meta.tasks is not None + assert len(new_dataset.meta.tasks) == 2 + + tasks_in_dataset = {str(item["task"]) for item in new_dataset} + assert len(tasks_in_dataset) > 0 + + +def test_split_three_ways(sample_dataset, tmp_path): + """Test splitting dataset into three splits.""" + splits = { + "train": 0.6, + "val": 0.2, + "test": 0.2, + } + + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + + def mock_snapshot(repo_id, **kwargs): + for split_name in splits: + if split_name in repo_id: + return str(tmp_path / f"{sample_dataset.repo_id}_{split_name}") + return str(kwargs.get("local_dir", tmp_path)) + + mock_snapshot_download.side_effect = mock_snapshot + + result = split_dataset( + sample_dataset, + splits=splits, + output_dir=tmp_path, + ) + + assert set(result.keys()) == {"train", "val", "test"} + assert result["train"].meta.total_episodes == 3 + assert result["val"].meta.total_episodes == 1 + assert result["test"].meta.total_episodes == 1 + + total_frames = sum(ds.meta.total_frames for ds in result.values()) + assert total_frames == sample_dataset.meta.total_frames + + +def test_split_preserves_stats(sample_dataset, tmp_path): + """Test that statistics are preserved when splitting.""" + splits = {"train": [0, 1, 2], "val": [3, 4]} + + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + + def mock_snapshot(repo_id, **kwargs): + for split_name in splits: + if split_name in repo_id: + return str(tmp_path / f"{sample_dataset.repo_id}_{split_name}") + return str(kwargs.get("local_dir", tmp_path)) + + mock_snapshot_download.side_effect = mock_snapshot + + result = split_dataset( + sample_dataset, + splits=splits, + output_dir=tmp_path, + ) + + for split_ds in result.values(): + assert split_ds.meta.stats is not None + for feature in ["action", "observation.state"]: + assert feature in split_ds.meta.stats + assert "mean" in split_ds.meta.stats[feature] + assert "std" in split_ds.meta.stats[feature] + + +def test_merge_three_datasets(sample_dataset, tmp_path, empty_lerobot_dataset_factory): + """Test merging three datasets.""" + features = { + "action": {"dtype": "float32", "shape": (6,), "names": None}, + "observation.state": {"dtype": "float32", "shape": (4,), "names": None}, + "observation.images.top": {"dtype": "image", "shape": (224, 224, 3), "names": None}, + } + + datasets = [sample_dataset] + + for i in range(2): + dataset = empty_lerobot_dataset_factory( + root=tmp_path / f"test_dataset{i + 2}", + features=features, + ) + + for ep_idx in range(2): + for _ in range(10): + frame = { + "action": np.random.randn(6).astype(np.float32), + "observation.state": np.random.randn(4).astype(np.float32), + "observation.images.top": np.random.randint(0, 255, size=(224, 224, 3), dtype=np.uint8), + "task": f"task_{ep_idx}", + } + dataset.add_frame(frame) + dataset.save_episode() + + datasets.append(dataset) + + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + mock_snapshot_download.return_value = str(tmp_path / "merged_dataset") + + merged = merge_datasets( + datasets, + output_repo_id="merged_dataset", + output_dir=tmp_path / "merged_dataset", + ) + + assert merged.meta.total_episodes == 9 + assert merged.meta.total_frames == 90 + + +def test_merge_preserves_stats(sample_dataset, tmp_path, empty_lerobot_dataset_factory): + """Test that statistics are computed for merged datasets.""" + features = { + "action": {"dtype": "float32", "shape": (6,), "names": None}, + "observation.state": {"dtype": "float32", "shape": (4,), "names": None}, + "observation.images.top": {"dtype": "image", "shape": (224, 224, 3), "names": None}, + } + + dataset2 = empty_lerobot_dataset_factory( + root=tmp_path / "test_dataset2", + features=features, + ) + + for ep_idx in range(3): + for _ in range(10): + frame = { + "action": np.random.randn(6).astype(np.float32), + "observation.state": np.random.randn(4).astype(np.float32), + "observation.images.top": np.random.randint(0, 255, size=(224, 224, 3), dtype=np.uint8), + "task": f"task_{ep_idx % 2}", + } + dataset2.add_frame(frame) + dataset2.save_episode() + + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + mock_snapshot_download.return_value = str(tmp_path / "merged_dataset") + + merged = merge_datasets( + [sample_dataset, dataset2], + output_repo_id="merged_dataset", + output_dir=tmp_path / "merged_dataset", + ) + + assert merged.meta.stats is not None + for feature in ["action", "observation.state"]: + assert feature in merged.meta.stats + assert "mean" in merged.meta.stats[feature] + assert "std" in merged.meta.stats[feature] + + +def test_add_feature_preserves_existing_stats(sample_dataset, tmp_path): + """Test that adding a feature preserves existing stats.""" + num_frames = sample_dataset.meta.total_frames + reward_values = np.random.randn(num_frames, 1).astype(np.float32) + + feature_info = { + "dtype": "float32", + "shape": (1,), + "names": None, + } + + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + mock_snapshot_download.return_value = str(tmp_path / "with_reward") + + new_dataset = add_feature( + sample_dataset, + feature_name="reward", + feature_values=reward_values, + feature_info=feature_info, + output_dir=tmp_path / "with_reward", + ) + + assert new_dataset.meta.stats is not None + for feature in ["action", "observation.state"]: + assert feature in new_dataset.meta.stats + assert "mean" in new_dataset.meta.stats[feature] + assert "std" in new_dataset.meta.stats[feature] + + +def test_remove_feature_updates_stats(sample_dataset, tmp_path): + """Test that removing a feature removes it from stats.""" + feature_info = {"dtype": "float32", "shape": (1,), "names": None} + + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + mock_snapshot_download.side_effect = lambda repo_id, **kwargs: str(kwargs.get("local_dir", tmp_path)) + + dataset_with_reward = add_feature( + sample_dataset, + feature_name="reward", + feature_values=np.random.randn(50, 1).astype(np.float32), + feature_info=feature_info, + output_dir=tmp_path / "with_reward", + ) + + dataset_without_reward = remove_feature( + dataset_with_reward, + feature_names="reward", + output_dir=tmp_path / "without_reward", + ) + + if dataset_without_reward.meta.stats: + assert "reward" not in dataset_without_reward.meta.stats + + +def test_delete_consecutive_episodes(sample_dataset, tmp_path): + """Test deleting consecutive episodes.""" + output_dir = tmp_path / "filtered" + + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + mock_snapshot_download.return_value = str(output_dir) + + new_dataset = delete_episodes( + sample_dataset, + episode_indices=[1, 2, 3], + output_dir=output_dir, + ) + + assert new_dataset.meta.total_episodes == 2 + assert new_dataset.meta.total_frames == 20 + + episode_indices = sorted({int(idx.item()) for idx in new_dataset.hf_dataset["episode_index"]}) + assert episode_indices == [0, 1] + + +def test_delete_first_and_last_episodes(sample_dataset, tmp_path): + """Test deleting first and last episodes.""" + output_dir = tmp_path / "filtered" + + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + mock_snapshot_download.return_value = str(output_dir) + + new_dataset = delete_episodes( + sample_dataset, + episode_indices=[0, 4], + output_dir=output_dir, + ) + + assert new_dataset.meta.total_episodes == 3 + assert new_dataset.meta.total_frames == 30 + + episode_indices = sorted({int(idx.item()) for idx in new_dataset.hf_dataset["episode_index"]}) + assert episode_indices == [0, 1, 2] + + +def test_split_all_episodes_assigned(sample_dataset, tmp_path): + """Test that all episodes can be explicitly assigned to splits.""" + splits = { + "split1": [0, 1], + "split2": [2, 3], + "split3": [4], + } + + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + + def mock_snapshot(repo_id, **kwargs): + for split_name in splits: + if split_name in repo_id: + return str(tmp_path / f"{sample_dataset.repo_id}_{split_name}") + return str(kwargs.get("local_dir", tmp_path)) + + mock_snapshot_download.side_effect = mock_snapshot + + result = split_dataset( + sample_dataset, + splits=splits, + output_dir=tmp_path, + ) + + total_episodes = sum(ds.meta.total_episodes for ds in result.values()) + assert total_episodes == sample_dataset.meta.total_episodes From 0699b46d87ded6e2394f4144dd9c92b2a5e4f1b8 Mon Sep 17 00:00:00 2001 From: Jade Choghari Date: Fri, 10 Oct 2025 20:41:37 +0200 Subject: [PATCH 07/12] refactor(envs): add custom-observation-size (#2167) --- src/lerobot/envs/configs.py | 26 +++++++++++++++++++------- 1 file changed, 19 insertions(+), 7 deletions(-) diff --git a/src/lerobot/envs/configs.py b/src/lerobot/envs/configs.py index 0daaaf9fd1f..7a979b8645d 100644 --- a/src/lerobot/envs/configs.py +++ b/src/lerobot/envs/configs.py @@ -50,6 +50,8 @@ class AlohaEnv(EnvConfig): fps: int = 50 episode_length: int = 400 obs_type: str = "pixels_agent_pos" + observation_height: int = 480 + observation_width: int = 640 render_mode: str = "rgb_array" features: dict[str, PolicyFeature] = field( default_factory=lambda: { @@ -67,10 +69,14 @@ class AlohaEnv(EnvConfig): def __post_init__(self): if self.obs_type == "pixels": - self.features["top"] = PolicyFeature(type=FeatureType.VISUAL, shape=(480, 640, 3)) + self.features["top"] = PolicyFeature( + type=FeatureType.VISUAL, shape=(self.observation_height, self.observation_width, 3) + ) elif self.obs_type == "pixels_agent_pos": self.features["agent_pos"] = PolicyFeature(type=FeatureType.STATE, shape=(14,)) - self.features["pixels/top"] = PolicyFeature(type=FeatureType.VISUAL, shape=(480, 640, 3)) + self.features["pixels/top"] = PolicyFeature( + type=FeatureType.VISUAL, shape=(self.observation_height, self.observation_width, 3) + ) @property def gym_kwargs(self) -> dict: @@ -91,6 +97,8 @@ class PushtEnv(EnvConfig): render_mode: str = "rgb_array" visualization_width: int = 384 visualization_height: int = 384 + observation_height: int = 384 + observation_width: int = 384 features: dict[str, PolicyFeature] = field( default_factory=lambda: { ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(2,)), @@ -108,7 +116,9 @@ class PushtEnv(EnvConfig): def __post_init__(self): if self.obs_type == "pixels_agent_pos": - self.features["pixels"] = PolicyFeature(type=FeatureType.VISUAL, shape=(384, 384, 3)) + self.features["pixels"] = PolicyFeature( + type=FeatureType.VISUAL, shape=(self.observation_height, self.observation_width, 3) + ) elif self.obs_type == "environment_state_agent_pos": self.features["environment_state"] = PolicyFeature(type=FeatureType.ENV, shape=(16,)) @@ -255,6 +265,8 @@ class LiberoEnv(EnvConfig): camera_name: str = "agentview_image,robot0_eye_in_hand_image" init_states: bool = True camera_name_mapping: dict[str, str] | None = None + observation_height: int = 360 + observation_width: int = 360 features: dict[str, PolicyFeature] = field( default_factory=lambda: { ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(7,)), @@ -272,18 +284,18 @@ class LiberoEnv(EnvConfig): def __post_init__(self): if self.obs_type == "pixels": self.features["pixels/agentview_image"] = PolicyFeature( - type=FeatureType.VISUAL, shape=(360, 360, 3) + type=FeatureType.VISUAL, shape=(self.observation_height, self.observation_width, 3) ) self.features["pixels/robot0_eye_in_hand_image"] = PolicyFeature( - type=FeatureType.VISUAL, shape=(360, 360, 3) + type=FeatureType.VISUAL, shape=(self.observation_height, self.observation_width, 3) ) elif self.obs_type == "pixels_agent_pos": self.features["agent_pos"] = PolicyFeature(type=FeatureType.STATE, shape=(8,)) self.features["pixels/agentview_image"] = PolicyFeature( - type=FeatureType.VISUAL, shape=(360, 360, 3) + type=FeatureType.VISUAL, shape=(self.observation_height, self.observation_width, 3) ) self.features["pixels/robot0_eye_in_hand_image"] = PolicyFeature( - type=FeatureType.VISUAL, shape=(360, 360, 3) + type=FeatureType.VISUAL, shape=(self.observation_height, self.observation_width, 3) ) else: raise ValueError(f"Unsupported obs_type: {self.obs_type}") From 25f60c301b6201b0eeb7bff2787c299f79a0dc40 Mon Sep 17 00:00:00 2001 From: Juan Pizarro Date: Sat, 11 Oct 2025 00:15:42 +0200 Subject: [PATCH 08/12] use TeleopEvents.RERECORD_EPISODE in gym_manipulator (#2165) Co-authored-by: Michel Aractingi --- src/lerobot/rl/gym_manipulator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lerobot/rl/gym_manipulator.py b/src/lerobot/rl/gym_manipulator.py index ad36f1b3641..f9c9d0d7a71 100644 --- a/src/lerobot/rl/gym_manipulator.py +++ b/src/lerobot/rl/gym_manipulator.py @@ -696,7 +696,7 @@ def control_loop( episode_idx += 1 if dataset is not None: - if transition[TransitionKey.INFO].get("rerecord_episode", False): + if transition[TransitionKey.INFO].get(TeleopEvents.RERECORD_EPISODE, False): logging.info(f"Re-recording episode {episode_idx}") dataset.clear_episode_buffer() episode_idx -= 1 From f2ff370459a9027319a8ab405fbe0d7c019a327e Mon Sep 17 00:00:00 2001 From: Michel Aractingi Date: Sat, 11 Oct 2025 11:01:30 +0200 Subject: [PATCH 09/12] Incremental parquet writing (#1903) * incremental parquet writing * add .finalise() and a backup __del__ for stopping writers * fix missing import * precommit fixes added back the use of embed images * added lazy loading for hf_Dataset to avoid frequently reloading the dataset during recording * fix bug in video timestamps * Added proper closing of parquet file before reading * Added rigorous testing to validate the consistency of the meta data after creation of a new dataset * fix bug in episode index during clear_episode_buffer * fix(empty concat): check for empty paths list before data files concatenation * fix(v3.0 message): updating v3.0 backward compatibility message. * added fixes for the resume logic * answering co-pilot review * reverting some changes and style nits * removed unused functions * fix chunk_id and file_id when resuming * - fix parquet loading when resuming - add test to verify the parquet file integrity when resuming so that data files are now overwritten * added general function get_file_size_in_mb and removed the one for video * fix table size value when resuming * Remove unnecessary reloading of the parquet file when resuming record. Write to a new parquet file when resuming record * added back reading parquet file for image datasets only * - respond to Qlhoest comments - Use pyarrows `from_pydict` function - Add buffer for episode metadata to write to the parquet file in batches to improve efficiency - Remove the use of `to_parquet_with_hf_images` * fix(dataset_tools) with the new logic using proper finalize bug in finding the latest path of the metdata that was pointing to the data files added check for the metadata size in the case the metadatabuffer was not written yet * nit in flush_metadata_buffer * fix(lerobot_dataset) return the right dataset len when a subset of the dataset is requested --------- Co-authored-by: Harsimrat Sandhawalia --- src/lerobot/datasets/aggregate.py | 8 +- src/lerobot/datasets/dataset_tools.py | 11 + src/lerobot/datasets/lerobot_dataset.py | 339 +++++++++++++++++------- src/lerobot/datasets/utils.py | 18 +- src/lerobot/datasets/video_utils.py | 3 + src/lerobot/rl/buffer.py | 1 + tests/datasets/test_dataset_tools.py | 4 + tests/datasets/test_datasets.py | 143 ++++++++++ 8 files changed, 419 insertions(+), 108 deletions(-) diff --git a/src/lerobot/datasets/aggregate.py b/src/lerobot/datasets/aggregate.py index e7ea59ed030..870c9571e83 100644 --- a/src/lerobot/datasets/aggregate.py +++ b/src/lerobot/datasets/aggregate.py @@ -31,8 +31,8 @@ DEFAULT_EPISODES_PATH, DEFAULT_VIDEO_FILE_SIZE_IN_MB, DEFAULT_VIDEO_PATH, + get_file_size_in_mb, get_parquet_file_size_in_mb, - get_video_size_in_mb, to_parquet_with_hf_images, update_chunk_file_indices, write_info, @@ -217,6 +217,7 @@ def aggregate_datasets( robot_type=robot_type, features=features, root=aggr_root, + use_videos=len(video_keys) > 0, chunks_size=chunk_size, data_files_size_in_mb=data_files_size_in_mb, video_files_size_in_mb=video_files_size_in_mb, @@ -307,8 +308,9 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu current_offset += src_duration continue - src_size = get_video_size_in_mb(src_path) - dst_size = get_video_size_in_mb(dst_path) + # Check file sizes before appending + src_size = get_file_size_in_mb(src_path) + dst_size = get_file_size_in_mb(dst_path) if dst_size + src_size >= video_files_size_in_mb: # Rotate to a new file, this source becomes start of new destination diff --git a/src/lerobot/datasets/dataset_tools.py b/src/lerobot/datasets/dataset_tools.py index fdeb24a729a..8ebc4a59dee 100644 --- a/src/lerobot/datasets/dataset_tools.py +++ b/src/lerobot/datasets/dataset_tools.py @@ -42,6 +42,7 @@ DEFAULT_DATA_PATH, DEFAULT_EPISODES_PATH, get_parquet_file_size_in_mb, + load_episodes, to_parquet_with_hf_images, update_chunk_file_indices, write_info, @@ -436,6 +437,9 @@ def _copy_and_reindex_data( Returns: dict mapping episode index to its data file metadata (chunk_index, file_index, etc.) """ + if src_dataset.meta.episodes is None: + src_dataset.meta.episodes = load_episodes(src_dataset.meta.root) + file_to_episodes: dict[Path, set[int]] = {} for old_idx in episode_mapping: file_path = src_dataset.meta.get_data_file_path(old_idx) @@ -645,6 +649,8 @@ def _copy_and_reindex_videos( Returns: dict mapping episode index to its video metadata (chunk_index, file_index, timestamps) """ + if src_dataset.meta.episodes is None: + src_dataset.meta.episodes = load_episodes(src_dataset.meta.root) episodes_video_metadata: dict[int, dict] = {new_idx: {} for new_idx in episode_mapping.values()} @@ -770,6 +776,9 @@ def _copy_and_reindex_episodes_metadata( """ from lerobot.datasets.utils import flatten_dict + if src_dataset.meta.episodes is None: + src_dataset.meta.episodes = load_episodes(src_dataset.meta.root) + all_stats = [] total_frames = 0 @@ -831,6 +840,8 @@ def _copy_and_reindex_episodes_metadata( total_frames += src_episode["length"] + dst_meta._close_writer() + dst_meta.info.update( { "total_episodes": len(episode_mapping), diff --git a/src/lerobot/datasets/lerobot_dataset.py b/src/lerobot/datasets/lerobot_dataset.py index 229d376413a..ae142c1e8ff 100644 --- a/src/lerobot/datasets/lerobot_dataset.py +++ b/src/lerobot/datasets/lerobot_dataset.py @@ -14,7 +14,6 @@ # See the License for the specific language governing permissions and # limitations under the License. import contextlib -import gc import logging import shutil import tempfile @@ -26,6 +25,8 @@ import packaging.version import pandas as pd import PIL.Image +import pyarrow as pa +import pyarrow.parquet as pq import torch import torch.utils from huggingface_hub import HfApi, snapshot_download @@ -46,13 +47,9 @@ embed_images, flatten_dict, get_delta_indices, - get_hf_dataset_cache_dir, - get_hf_dataset_size_in_mb, + get_file_size_in_mb, get_hf_features_from_features, - get_parquet_file_size_in_mb, - get_parquet_num_frames, get_safe_version, - get_video_size_in_mb, hf_transform_to_torch, is_valid_version, load_episodes, @@ -60,7 +57,6 @@ load_nested_dataset, load_stats, load_tasks, - to_parquet_with_hf_images, update_chunk_file_indices, validate_episode_buffer, validate_frame, @@ -90,10 +86,15 @@ def __init__( root: str | Path | None = None, revision: str | None = None, force_cache_sync: bool = False, + metadata_buffer_size: int = 10, ): self.repo_id = repo_id self.revision = revision if revision else CODEBASE_VERSION self.root = Path(root) if root is not None else HF_LEROBOT_HOME / repo_id + self.writer = None + self.latest_episode = None + self.metadata_buffer: list[dict] = [] + self.metadata_buffer_size = metadata_buffer_size try: if force_cache_sync: @@ -107,6 +108,54 @@ def __init__( self.pull_from_repo(allow_patterns="meta/") self.load_metadata() + def _flush_metadata_buffer(self) -> None: + """Write all buffered episode metadata to parquet file.""" + if not hasattr(self, "metadata_buffer") or len(self.metadata_buffer) == 0: + return + + combined_dict = {} + for episode_dict in self.metadata_buffer: + for key, value in episode_dict.items(): + if key not in combined_dict: + combined_dict[key] = [] + # Extract value and serialize numpy arrays + # because PyArrow's from_pydict function doesn't support numpy arrays + val = value[0] if isinstance(value, list) else value + combined_dict[key].append(val.tolist() if isinstance(val, np.ndarray) else val) + + first_ep = self.metadata_buffer[0] + chunk_idx = first_ep["meta/episodes/chunk_index"][0] + file_idx = first_ep["meta/episodes/file_index"][0] + + table = pa.Table.from_pydict(combined_dict) + + if not self.writer: + path = Path(self.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx)) + path.parent.mkdir(parents=True, exist_ok=True) + self.writer = pq.ParquetWriter( + path, schema=table.schema, compression="snappy", use_dictionary=True + ) + + self.writer.write_table(table) + + self.latest_episode = self.metadata_buffer[-1] + self.metadata_buffer.clear() + + def _close_writer(self) -> None: + """Close and cleanup the parquet writer if it exists.""" + self._flush_metadata_buffer() + + writer = getattr(self, "writer", None) + if writer is not None: + writer.close() + self.writer = None + + def __del__(self): + """ + Trust the user to call .finalize() but as an added safety check call the parquet writer to stop when calling the destructor + """ + self._close_writer() + def load_metadata(self): self.info = load_info(self.root) check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION) @@ -138,6 +187,12 @@ def _version(self) -> packaging.version.Version: return packaging.version.parse(self.info["codebase_version"]) def get_data_file_path(self, ep_index: int) -> Path: + if self.episodes is None: + self.episodes = load_episodes(self.root) + if ep_index >= len(self.episodes): + raise IndexError( + f"Episode index {ep_index} out of range. Episodes: {len(self.episodes) if self.episodes else 0}" + ) ep = self.episodes[ep_index] chunk_idx = ep["data/chunk_index"] file_idx = ep["data/file_index"] @@ -145,6 +200,12 @@ def get_data_file_path(self, ep_index: int) -> Path: return Path(fpath) def get_video_file_path(self, ep_index: int, vid_key: str) -> Path: + if self.episodes is None: + self.episodes = load_episodes(self.root) + if ep_index >= len(self.episodes): + raise IndexError( + f"Episode index {ep_index} out of range. Episodes: {len(self.episodes) if self.episodes else 0}" + ) ep = self.episodes[ep_index] chunk_idx = ep[f"videos/{vid_key}/chunk_index"] file_idx = ep[f"videos/{vid_key}/file_index"] @@ -260,72 +321,75 @@ def save_episode_tasks(self, tasks: list[str]): write_tasks(self.tasks, self.root) def _save_episode_metadata(self, episode_dict: dict) -> None: - """Save episode metadata to a parquet file and update the Hugging Face dataset of episodes metadata. + """Buffer episode metadata and write to parquet in batches for efficiency. - This function processes episodes metadata from a dictionary, converts it into a Hugging Face dataset, - and saves it as a parquet file. It handles both the creation of new parquet files and the - updating of existing ones based on size constraints. After saving the metadata, it reloads - the Hugging Face dataset to ensure it is up-to-date. + This function accumulates episode metadata in a buffer and flushes it when the buffer + reaches the configured size. This reduces I/O overhead by writing multiple episodes + at once instead of one row at a time. Notes: We both need to update parquet files and HF dataset: - `pandas` loads parquet file in RAM - `datasets` relies on a memory mapping from pyarrow (no RAM). It either converts parquet files to a pyarrow cache on disk, or loads directly from pyarrow cache. """ - # Convert buffer into HF Dataset + # Convert to list format for each value episode_dict = {key: [value] for key, value in episode_dict.items()} - ep_dataset = datasets.Dataset.from_dict(episode_dict) - ep_size_in_mb = get_hf_dataset_size_in_mb(ep_dataset) - df = pd.DataFrame(ep_dataset) num_frames = episode_dict["length"][0] - if self.episodes is None: + if self.latest_episode is None: # Initialize indices and frame count for a new dataset made of the first episode data chunk_idx, file_idx = 0, 0 - df["meta/episodes/chunk_index"] = [chunk_idx] - df["meta/episodes/file_index"] = [file_idx] - df["dataset_from_index"] = [0] - df["dataset_to_index"] = [num_frames] - else: - # Retrieve information from the latest parquet file - latest_ep = self.episodes[-1] - chunk_idx = latest_ep["meta/episodes/chunk_index"] - file_idx = latest_ep["meta/episodes/file_index"] + if self.episodes is not None and len(self.episodes) > 0: + # It means we are resuming recording, so we need to load the latest episode + # Update the indices to avoid overwriting the latest episode + chunk_idx = self.episodes[-1]["meta/episodes/chunk_index"] + file_idx = self.episodes[-1]["meta/episodes/file_index"] + latest_num_frames = self.episodes[-1]["dataset_to_index"] + episode_dict["dataset_from_index"] = [latest_num_frames] + episode_dict["dataset_to_index"] = [latest_num_frames + num_frames] + + # When resuming, move to the next file + chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.chunks_size) + else: + episode_dict["dataset_from_index"] = [0] + episode_dict["dataset_to_index"] = [num_frames] - latest_path = self.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx) - latest_size_in_mb = get_parquet_file_size_in_mb(latest_path) + episode_dict["meta/episodes/chunk_index"] = [chunk_idx] + episode_dict["meta/episodes/file_index"] = [file_idx] + else: + chunk_idx = self.latest_episode["meta/episodes/chunk_index"][0] + file_idx = self.latest_episode["meta/episodes/file_index"][0] - if latest_size_in_mb + ep_size_in_mb >= self.data_files_size_in_mb: - # Size limit is reached, prepare new parquet file - chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.chunks_size) + latest_path = ( + self.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx) + if self.writer is None + else self.writer.where + ) - # Update the existing pandas dataframe with new row - df["meta/episodes/chunk_index"] = [chunk_idx] - df["meta/episodes/file_index"] = [file_idx] - df["dataset_from_index"] = [latest_ep["dataset_to_index"]] - df["dataset_to_index"] = [latest_ep["dataset_to_index"] + num_frames] + if Path(latest_path).exists(): + latest_size_in_mb = get_file_size_in_mb(Path(latest_path)) + latest_num_frames = self.latest_episode["episode_index"][0] - if latest_size_in_mb + ep_size_in_mb < self.data_files_size_in_mb: - # Size limit wasnt reached, concatenate latest dataframe with new one - latest_df = pd.read_parquet(latest_path) - df = pd.concat([latest_df, df], ignore_index=True) + av_size_per_frame = latest_size_in_mb / latest_num_frames if latest_num_frames > 0 else 0.0 - # Memort optimization - del latest_df - gc.collect() + if latest_size_in_mb + av_size_per_frame * num_frames >= self.data_files_size_in_mb: + # Size limit is reached, flush buffer and prepare new parquet file + self._flush_metadata_buffer() + chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.chunks_size) + self._close_writer() - # Write the resulting dataframe from RAM to disk - path = self.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx) - path.parent.mkdir(parents=True, exist_ok=True) - df.to_parquet(path, index=False) + # Update the existing pandas dataframe with new row + episode_dict["meta/episodes/chunk_index"] = [chunk_idx] + episode_dict["meta/episodes/file_index"] = [file_idx] + episode_dict["dataset_from_index"] = [self.latest_episode["dataset_to_index"][0]] + episode_dict["dataset_to_index"] = [self.latest_episode["dataset_to_index"][0] + num_frames] - if self.episodes is not None: - # Remove the episodes cache directory, necessary to avoid cache bloat - cached_dir = get_hf_dataset_cache_dir(self.episodes) - if cached_dir is not None: - shutil.rmtree(cached_dir) + # Add to buffer + self.metadata_buffer.append(episode_dict) + self.latest_episode = episode_dict - self.episodes = load_episodes(self.root) + if len(self.metadata_buffer) >= self.metadata_buffer_size: + self._flush_metadata_buffer() def save_episode( self, @@ -438,6 +502,7 @@ def create( robot_type: str | None = None, root: str | Path | None = None, use_videos: bool = True, + metadata_buffer_size: int = 10, chunks_size: int | None = None, data_files_size_in_mb: int | None = None, video_files_size_in_mb: int | None = None, @@ -469,6 +534,10 @@ def create( raise ValueError() write_json(obj.info, obj.root / INFO_PATH) obj.revision = None + obj.writer = None + obj.latest_episode = None + obj.metadata_buffer = [] + obj.metadata_buffer_size = metadata_buffer_size return obj @@ -615,6 +684,8 @@ def __init__( # Unused attributes self.image_writer = None self.episode_buffer = None + self.writer = None + self.latest_episode = None self.root.mkdir(exist_ok=True, parents=True) @@ -623,6 +694,11 @@ def __init__( self.repo_id, self.root, self.revision, force_cache_sync=force_cache_sync ) + # Track dataset state for efficient incremental writing + self._lazy_loading = False + self._recorded_frames = self.meta.total_frames + self._writer_closed_for_reading = False + # Load actual data try: if force_cache_sync: @@ -641,6 +717,19 @@ def __init__( check_delta_timestamps(self.delta_timestamps, self.fps, self.tolerance_s) self.delta_indices = get_delta_indices(self.delta_timestamps, self.fps) + def _close_writer(self) -> None: + """Close and cleanup the parquet writer if it exists.""" + writer = getattr(self, "writer", None) + if writer is not None: + writer.close() + self.writer = None + + def __del__(self): + """ + Trust the user to call .finalize() but as an added safety check call the parquet writer to stop when calling the destructor + """ + self._close_writer() + def push_to_hub( self, branch: str | None = None, @@ -781,8 +870,15 @@ def fps(self) -> int: @property def num_frames(self) -> int: - """Number of frames in selected episodes.""" - return len(self.hf_dataset) if self.hf_dataset is not None else self.meta.total_frames + """Number of frames in selected episodes. + + Note: When episodes a subset of the full dataset is requested, we must return the + actual loaded data length (len(self.hf_dataset)) rather than metadata total_frames. + self.meta.total_frames is the total number of frames in the full dataset. + """ + if self.episodes is not None and self.hf_dataset is not None: + return len(self.hf_dataset) + return self.meta.total_frames @property def num_episodes(self) -> int: @@ -860,10 +956,22 @@ def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) - return item + def _ensure_hf_dataset_loaded(self): + """Lazy load the HF dataset only when needed for reading.""" + if self._lazy_loading or self.hf_dataset is None: + # Close the writer before loading to ensure parquet file is properly finalized + if self.writer is not None: + self._close_writer() + self._writer_closed_for_reading = True + self.hf_dataset = self.load_hf_dataset() + self._lazy_loading = False + def __len__(self): return self.num_frames def __getitem__(self, idx) -> dict: + # Ensure dataset is loaded when we actually need to read from it + self._ensure_hf_dataset_loaded() item = self.hf_dataset[idx] ep_idx = item["episode_index"].item() @@ -902,6 +1010,14 @@ def __repr__(self): "})',\n" ) + def finalize(self): + """ + Close the parquet writers. This function needs to be called after data collection/conversion, else footer metadata won't be written to the parquet files. + The dataset won't be valid and can't be loaded as ds = LeRobotDataset(repo_id=repo, root=HF_LEROBOT_HOME.joinpath(repo)) + """ + self._close_writer() + self.meta._close_writer() + def create_episode_buffer(self, episode_index: int | None = None) -> dict: current_ep_idx = self.meta.total_episodes if episode_index is None else episode_index ep_buffer = {} @@ -1109,74 +1225,101 @@ def _save_episode_data(self, episode_buffer: dict) -> dict: ep_dict = {key: episode_buffer[key] for key in self.hf_features} ep_dataset = datasets.Dataset.from_dict(ep_dict, features=self.hf_features, split="train") ep_dataset = embed_images(ep_dataset) - ep_size_in_mb = get_hf_dataset_size_in_mb(ep_dataset) ep_num_frames = len(ep_dataset) - df = pd.DataFrame(ep_dataset) - if self.meta.episodes is None: + if self.latest_episode is None: # Initialize indices and frame count for a new dataset made of the first episode data chunk_idx, file_idx = 0, 0 - latest_num_frames = 0 + global_frame_index = 0 + # However, if the episodes already exists + # It means we are resuming recording, so we need to load the latest episode + # Update the indices to avoid overwriting the latest episode + if self.meta.episodes is not None and len(self.meta.episodes) > 0: + latest_ep = self.meta.episodes[-1] + global_frame_index = latest_ep["dataset_to_index"] + chunk_idx = latest_ep["data/chunk_index"] + file_idx = latest_ep["data/file_index"] + + # When resuming, move to the next file + chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.meta.chunks_size) else: # Retrieve information from the latest parquet file - latest_ep = self.meta.episodes[-1] + latest_ep = self.latest_episode chunk_idx = latest_ep["data/chunk_index"] file_idx = latest_ep["data/file_index"] + global_frame_index = latest_ep["index"][-1] + 1 latest_path = self.root / self.meta.data_path.format(chunk_index=chunk_idx, file_index=file_idx) - latest_size_in_mb = get_parquet_file_size_in_mb(latest_path) - latest_num_frames = get_parquet_num_frames(latest_path) + latest_size_in_mb = get_file_size_in_mb(latest_path) + + frames_in_current_file = global_frame_index - latest_ep["dataset_from_index"] + av_size_per_frame = ( + latest_size_in_mb / frames_in_current_file if frames_in_current_file > 0 else 0 + ) # Determine if a new parquet file is needed - if latest_size_in_mb + ep_size_in_mb >= self.meta.data_files_size_in_mb: - # Size limit is reached, prepare new parquet file + if ( + latest_size_in_mb + av_size_per_frame * ep_num_frames >= self.meta.data_files_size_in_mb + or self._writer_closed_for_reading + ): + # Size limit is reached or writer was closed for reading, prepare new parquet file chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.meta.chunks_size) - latest_num_frames = 0 - else: - # Update the existing parquet file with new rows - latest_df = pd.read_parquet(latest_path) - df = pd.concat([latest_df, df], ignore_index=True) + self._close_writer() + self._writer_closed_for_reading = False - # Memort optimization - del latest_df - gc.collect() + ep_dict["data/chunk_index"] = chunk_idx + ep_dict["data/file_index"] = file_idx # Write the resulting dataframe from RAM to disk path = self.root / self.meta.data_path.format(chunk_index=chunk_idx, file_index=file_idx) path.parent.mkdir(parents=True, exist_ok=True) - if len(self.meta.image_keys) > 0: - to_parquet_with_hf_images(df, path) - else: - df.to_parquet(path) - if self.hf_dataset is not None: - # Remove hf dataset cache directory, necessary to avoid cache bloat - cached_dir = get_hf_dataset_cache_dir(self.hf_dataset) - if cached_dir is not None: - shutil.rmtree(cached_dir) - - self.hf_dataset = self.load_hf_dataset() + table = ep_dataset.with_format("arrow")[:] + if not self.writer: + self.writer = pq.ParquetWriter( + path, schema=table.schema, compression="snappy", use_dictionary=True + ) + self.writer.write_table(table) metadata = { "data/chunk_index": chunk_idx, "data/file_index": file_idx, - "dataset_from_index": latest_num_frames, - "dataset_to_index": latest_num_frames + ep_num_frames, + "dataset_from_index": global_frame_index, + "dataset_to_index": global_frame_index + ep_num_frames, } + + # Store metadata with episode data for next episode + self.latest_episode = {**ep_dict, **metadata} + + # Mark that the HF dataset needs reloading (lazy loading approach) + # This avoids expensive reloading during sequential recording + self._lazy_loading = True + # Update recorded frames count for efficient length tracking + self._recorded_frames += ep_num_frames + return metadata def _save_episode_video(self, video_key: str, episode_index: int) -> dict: # Encode episode frames into a temporary video ep_path = self._encode_temporary_episode_video(video_key, episode_index) - ep_size_in_mb = get_video_size_in_mb(ep_path) + ep_size_in_mb = get_file_size_in_mb(ep_path) ep_duration_in_s = get_video_duration_in_s(ep_path) - if self.meta.episodes is None or ( - f"videos/{video_key}/chunk_index" not in self.meta.episodes.column_names - or f"videos/{video_key}/file_index" not in self.meta.episodes.column_names + if ( + episode_index == 0 + or self.meta.latest_episode is None + or f"videos/{video_key}/chunk_index" not in self.meta.latest_episode ): # Initialize indices for a new dataset made of the first episode data chunk_idx, file_idx = 0, 0 + if self.meta.episodes is not None and len(self.meta.episodes) > 0: + # It means we are resuming recording, so we need to load the latest episode + # Update the indices to avoid overwriting the latest episode + old_chunk_idx = self.meta.episodes[-1][f"videos/{video_key}/chunk_index"] + old_file_idx = self.meta.episodes[-1][f"videos/{video_key}/file_index"] + chunk_idx, file_idx = update_chunk_file_indices( + old_chunk_idx, old_file_idx, self.meta.chunks_size + ) latest_duration_in_s = 0.0 new_path = self.root / self.meta.video_path.format( video_key=video_key, chunk_index=chunk_idx, file_index=file_idx @@ -1184,16 +1327,16 @@ def _save_episode_video(self, video_key: str, episode_index: int) -> dict: new_path.parent.mkdir(parents=True, exist_ok=True) shutil.move(str(ep_path), str(new_path)) else: - # Retrieve information from the latest updated video file (possibly several episodes ago) - latest_ep = self.meta.episodes[episode_index - 1] - chunk_idx = latest_ep[f"videos/{video_key}/chunk_index"] - file_idx = latest_ep[f"videos/{video_key}/file_index"] + # Retrieve information from the latest updated video file using latest_episode + latest_ep = self.meta.latest_episode + chunk_idx = latest_ep[f"videos/{video_key}/chunk_index"][0] + file_idx = latest_ep[f"videos/{video_key}/file_index"][0] latest_path = self.root / self.meta.video_path.format( video_key=video_key, chunk_index=chunk_idx, file_index=file_idx ) - latest_size_in_mb = get_video_size_in_mb(latest_path) - latest_duration_in_s = get_video_duration_in_s(latest_path) + latest_size_in_mb = get_file_size_in_mb(latest_path) + latest_duration_in_s = latest_ep[f"videos/{video_key}/to_timestamp"][0] if latest_size_in_mb + ep_size_in_mb >= self.meta.video_files_size_in_mb: # Move temporary episode video to a new video file in the dataset @@ -1327,6 +1470,12 @@ def create( obj.delta_timestamps = None obj.delta_indices = None obj.video_backend = video_backend if video_backend is not None else get_safe_default_codec() + obj.writer = None + obj.latest_episode = None + # Initialize tracking for incremental recording + obj._lazy_loading = False + obj._recorded_frames = 0 + obj._writer_closed_for_reading = False return obj diff --git a/src/lerobot/datasets/utils.py b/src/lerobot/datasets/utils.py index 422a7010a6e..37d8432b2ba 100644 --- a/src/lerobot/datasets/utils.py +++ b/src/lerobot/datasets/utils.py @@ -94,12 +94,6 @@ def get_hf_dataset_size_in_mb(hf_ds: Dataset) -> int: return hf_ds.data.nbytes // (1024**2) -def get_hf_dataset_cache_dir(hf_ds: Dataset) -> Path | None: - if hf_ds.cache_files is None or len(hf_ds.cache_files) == 0: - return None - return Path(hf_ds.cache_files[0]["filename"]).parents[2] - - def update_chunk_file_indices(chunk_idx: int, file_idx: int, chunks_size: int) -> tuple[int, int]: if file_idx == chunks_size - 1: file_idx = 0 @@ -133,10 +127,14 @@ def get_parquet_num_frames(parquet_path: str | Path) -> int: return metadata.num_rows -def get_video_size_in_mb(mp4_path: Path) -> float: - file_size_bytes = mp4_path.stat().st_size - file_size_mb = file_size_bytes / (1024**2) - return file_size_mb +def get_file_size_in_mb(file_path: Path) -> float: + """Get file size on disk in megabytes. + + Args: + file_path (Path): Path to the file. + """ + file_size_bytes = file_path.stat().st_size + return file_size_bytes / (1024**2) def flatten_dict(d: dict, parent_key: str = "", sep: str = "/") -> dict: diff --git a/src/lerobot/datasets/video_utils.py b/src/lerobot/datasets/video_utils.py index 620ba863aca..740cdb6020d 100644 --- a/src/lerobot/datasets/video_utils.py +++ b/src/lerobot/datasets/video_utils.py @@ -642,6 +642,9 @@ def __exit__(self, exc_type, exc_val, exc_tb): ) self.dataset._batch_save_episode_video(start_ep, end_ep) + # Finalize the dataset to properly close all writers + self.dataset.finalize() + # Clean up episode images if recording was interrupted if exc_type is not None: interrupted_episode_index = self.dataset.num_episodes diff --git a/src/lerobot/rl/buffer.py b/src/lerobot/rl/buffer.py index 917e4e2cc96..81aa29c4803 100644 --- a/src/lerobot/rl/buffer.py +++ b/src/lerobot/rl/buffer.py @@ -607,6 +607,7 @@ def to_lerobot_dataset( lerobot_dataset.save_episode() lerobot_dataset.stop_image_writer() + lerobot_dataset.finalize() return lerobot_dataset diff --git a/tests/datasets/test_dataset_tools.py b/tests/datasets/test_dataset_tools.py index fe117b35b80..a9c04d6f244 100644 --- a/tests/datasets/test_dataset_tools.py +++ b/tests/datasets/test_dataset_tools.py @@ -55,6 +55,7 @@ def sample_dataset(tmp_path, empty_lerobot_dataset_factory): dataset.add_frame(frame) dataset.save_episode() + dataset.finalize() return dataset @@ -263,6 +264,7 @@ def test_merge_two_datasets(sample_dataset, tmp_path, empty_lerobot_dataset_fact } dataset2.add_frame(frame) dataset2.save_episode() + dataset2.finalize() with ( patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, @@ -685,6 +687,7 @@ def test_merge_three_datasets(sample_dataset, tmp_path, empty_lerobot_dataset_fa } dataset.add_frame(frame) dataset.save_episode() + dataset.finalize() datasets.append(dataset) @@ -728,6 +731,7 @@ def test_merge_preserves_stats(sample_dataset, tmp_path, empty_lerobot_dataset_f } dataset2.add_frame(frame) dataset2.save_episode() + dataset2.finalize() with ( patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, diff --git a/tests/datasets/test_datasets.py b/tests/datasets/test_datasets.py index 2bc3bea43be..e174c578968 100644 --- a/tests/datasets/test_datasets.py +++ b/tests/datasets/test_datasets.py @@ -806,6 +806,8 @@ def test_episode_index_distribution(tmp_path, empty_lerobot_dataset_factory): dataset.add_frame({"state": torch.randn(2), "task": f"task_{episode_idx}"}) dataset.save_episode() + dataset.finalize() + # Load the dataset and check episode indices loaded_dataset = LeRobotDataset(dataset.repo_id, root=dataset.root) @@ -855,6 +857,8 @@ def test_multi_episode_metadata_consistency(tmp_path, empty_lerobot_dataset_fact dataset.add_frame({"state": torch.randn(3), ACTION: torch.randn(2), "task": tasks[episode_idx]}) dataset.save_episode() + dataset.finalize() + # Load and validate episode metadata loaded_dataset = LeRobotDataset(dataset.repo_id, root=dataset.root) @@ -893,6 +897,8 @@ def test_data_consistency_across_episodes(tmp_path, empty_lerobot_dataset_factor dataset.add_frame({"state": torch.randn(1), "task": "consistency_test"}) dataset.save_episode() + dataset.finalize() + loaded_dataset = LeRobotDataset(dataset.repo_id, root=dataset.root) # Check data consistency - no gaps or overlaps @@ -944,6 +950,8 @@ def test_statistics_metadata_validation(tmp_path, empty_lerobot_dataset_factory) dataset.add_frame({"state": state_data, ACTION: action_data, "task": "stats_test"}) dataset.save_episode() + dataset.finalize() + loaded_dataset = LeRobotDataset(dataset.repo_id, root=dataset.root) # Check that statistics exist for all features @@ -989,6 +997,8 @@ def test_episode_boundary_integrity(tmp_path, empty_lerobot_dataset_factory): dataset.add_frame({"state": torch.tensor([float(frame_idx)]), "task": f"episode_{episode_idx}"}) dataset.save_episode() + dataset.finalize() + loaded_dataset = LeRobotDataset(dataset.repo_id, root=dataset.root) # Test episode boundaries @@ -1031,6 +1041,8 @@ def test_task_indexing_and_validation(tmp_path, empty_lerobot_dataset_factory): dataset.add_frame({"state": torch.randn(1), "task": task}) dataset.save_episode() + dataset.finalize() + loaded_dataset = LeRobotDataset(dataset.repo_id, root=dataset.root) # Check that all unique tasks are in the tasks metadata @@ -1056,3 +1068,134 @@ def test_task_indexing_and_validation(tmp_path, empty_lerobot_dataset_factory): # Check total number of tasks assert loaded_dataset.meta.total_tasks == len(unique_tasks) + + +def test_dataset_resume_recording(tmp_path, empty_lerobot_dataset_factory): + """Test that resuming dataset recording preserves previously recorded episodes. + + This test validates the critical resume functionality by: + 1. Recording initial episodes and finalizing + 2. Reopening the dataset + 3. Recording additional episodes + 4. Verifying all data (old + new) is intact + + This specifically tests the bug fix where parquet files were being overwritten + instead of appended to during resume. + """ + features = { + "observation.state": {"dtype": "float32", "shape": (2,), "names": ["x", "y"]}, + "action": {"dtype": "float32", "shape": (2,), "names": ["x", "y"]}, + } + + dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features, use_videos=False) + + initial_episodes = 2 + frames_per_episode = 3 + + for ep_idx in range(initial_episodes): + for frame_idx in range(frames_per_episode): + dataset.add_frame( + { + "observation.state": torch.tensor([float(ep_idx), float(frame_idx)]), + "action": torch.tensor([0.5, 0.5]), + "task": f"task_{ep_idx}", + } + ) + dataset.save_episode() + + assert dataset.meta.total_episodes == initial_episodes + assert dataset.meta.total_frames == initial_episodes * frames_per_episode + + dataset.finalize() + initial_root = dataset.root + initial_repo_id = dataset.repo_id + del dataset + + dataset_verify = LeRobotDataset(initial_repo_id, root=initial_root, revision="v3.0") + assert dataset_verify.meta.total_episodes == initial_episodes + assert dataset_verify.meta.total_frames == initial_episodes * frames_per_episode + assert len(dataset_verify.hf_dataset) == initial_episodes * frames_per_episode + + for idx in range(len(dataset_verify.hf_dataset)): + item = dataset_verify[idx] + expected_ep = idx // frames_per_episode + expected_frame = idx % frames_per_episode + assert item["episode_index"].item() == expected_ep + assert item["frame_index"].item() == expected_frame + assert item["index"].item() == idx + assert item["observation.state"][0].item() == float(expected_ep) + assert item["observation.state"][1].item() == float(expected_frame) + + del dataset_verify + + # Phase 3: Resume recording - add more episodes + dataset_resumed = LeRobotDataset(initial_repo_id, root=initial_root, revision="v3.0") + + assert dataset_resumed.meta.total_episodes == initial_episodes + assert dataset_resumed.meta.total_frames == initial_episodes * frames_per_episode + assert dataset_resumed.latest_episode is None # Not recording yet + assert dataset_resumed.writer is None + assert dataset_resumed.meta.writer is None + + additional_episodes = 2 + for ep_idx in range(initial_episodes, initial_episodes + additional_episodes): + for frame_idx in range(frames_per_episode): + dataset_resumed.add_frame( + { + "observation.state": torch.tensor([float(ep_idx), float(frame_idx)]), + "action": torch.tensor([0.5, 0.5]), + "task": f"task_{ep_idx}", + } + ) + dataset_resumed.save_episode() + + total_episodes = initial_episodes + additional_episodes + total_frames = total_episodes * frames_per_episode + assert dataset_resumed.meta.total_episodes == total_episodes + assert dataset_resumed.meta.total_frames == total_frames + + dataset_resumed.finalize() + del dataset_resumed + + dataset_final = LeRobotDataset(initial_repo_id, root=initial_root, revision="v3.0") + + assert dataset_final.meta.total_episodes == total_episodes + assert dataset_final.meta.total_frames == total_frames + assert len(dataset_final.hf_dataset) == total_frames + + for idx in range(total_frames): + item = dataset_final[idx] + expected_ep = idx // frames_per_episode + expected_frame = idx % frames_per_episode + + assert item["episode_index"].item() == expected_ep, ( + f"Frame {idx}: wrong episode_index. Expected {expected_ep}, got {item['episode_index'].item()}" + ) + assert item["frame_index"].item() == expected_frame, ( + f"Frame {idx}: wrong frame_index. Expected {expected_frame}, got {item['frame_index'].item()}" + ) + assert item["index"].item() == idx, ( + f"Frame {idx}: wrong index. Expected {idx}, got {item['index'].item()}" + ) + + # Verify data integrity + assert item["observation.state"][0].item() == float(expected_ep), ( + f"Frame {idx}: wrong observation.state[0]. Expected {float(expected_ep)}, " + f"got {item['observation.state'][0].item()}" + ) + assert item["observation.state"][1].item() == float(expected_frame), ( + f"Frame {idx}: wrong observation.state[1]. Expected {float(expected_frame)}, " + f"got {item['observation.state'][1].item()}" + ) + + assert len(dataset_final.meta.episodes) == total_episodes + for ep_idx in range(total_episodes): + ep_metadata = dataset_final.meta.episodes[ep_idx] + assert ep_metadata["episode_index"] == ep_idx + assert ep_metadata["length"] == frames_per_episode + assert ep_metadata["tasks"] == [f"task_{ep_idx}"] + + expected_from = ep_idx * frames_per_episode + expected_to = (ep_idx + 1) * frames_per_episode + assert ep_metadata["dataset_from_index"] == expected_from + assert ep_metadata["dataset_to_index"] == expected_to From 0c79cf8f4ed4baa98db878ee7d2d091df447d878 Mon Sep 17 00:00:00 2001 From: Michel Aractingi Date: Sat, 11 Oct 2025 21:15:43 +0200 Subject: [PATCH 10/12] Add missing finalize calls in example (#2175) - add missing calls to dataset.finalize in the example recording scripts - add section in the dataset docs on calling dataset.finalize --- docs/source/lerobot-dataset-v3.mdx | 33 ++++++++++++++++++++++++++ examples/lekiwi/evaluate.py | 2 ++ examples/lekiwi/record.py | 2 ++ examples/phone_to_so100/evaluate.py | 2 ++ examples/phone_to_so100/record.py | 2 ++ examples/port_datasets/port_droid.py | 2 ++ examples/so100_to_so100_EE/evaluate.py | 2 ++ examples/so100_to_so100_EE/record.py | 2 ++ 8 files changed, 47 insertions(+) diff --git a/docs/source/lerobot-dataset-v3.mdx b/docs/source/lerobot-dataset-v3.mdx index cf1942fdcfb..3521914f2f2 100644 --- a/docs/source/lerobot-dataset-v3.mdx +++ b/docs/source/lerobot-dataset-v3.mdx @@ -279,3 +279,36 @@ python -m lerobot.datasets.v30.convert_dataset_v21_to_v30 --repo-id= Date: Mon, 13 Oct 2025 10:44:53 +0200 Subject: [PATCH 11/12] fix: very minor fix but hey devil is in details (#2168) Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com> --- src/lerobot/policies/pi0/modeling_pi0.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lerobot/policies/pi0/modeling_pi0.py b/src/lerobot/policies/pi0/modeling_pi0.py index a2dcdaea37f..596b273d5fb 100644 --- a/src/lerobot/policies/pi0/modeling_pi0.py +++ b/src/lerobot/policies/pi0/modeling_pi0.py @@ -897,7 +897,7 @@ def from_pretrained( ) -> T: """Override the from_pretrained method to handle key remapping and display important disclaimer.""" print( - "The PI05 model is a direct port of the OpenPI implementation. \n" + "The PI0 model is a direct port of the OpenPI implementation. \n" "This implementation follows the original OpenPI structure for compatibility. \n" "Original implementation: https://github.com/Physical-Intelligence/openpi" ) From 6f5bb4d4a49fbdb47acfeaa2c190b5fa125f645a Mon Sep 17 00:00:00 2001 From: Francesco Capuano <74058581+fracapuano@users.noreply.github.com> Date: Mon, 13 Oct 2025 16:43:23 +0200 Subject: [PATCH 12/12] fix outdated example in docs (#2182) * fix outdated example Signed-off-by: Francesco Capuano <74058581+fracapuano@users.noreply.github.com> * Update docs/source/il_robots.mdx Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Signed-off-by: Francesco Capuano <74058581+fracapuano@users.noreply.github.com> --------- Signed-off-by: Francesco Capuano <74058581+fracapuano@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- docs/source/il_robots.mdx | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/docs/source/il_robots.mdx b/docs/source/il_robots.mdx index 91df14028cd..0d8fd56e5a8 100644 --- a/docs/source/il_robots.mdx +++ b/docs/source/il_robots.mdx @@ -513,13 +513,14 @@ from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.datasets.utils import hw_to_dataset_features from lerobot.policies.act.modeling_act import ACTPolicy +from lerobot.policies.factory import make_pre_post_processors from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig from lerobot.robots.so100_follower.so100_follower import SO100Follower +from lerobot.scripts.lerobot_record import record_loop from lerobot.utils.control_utils import init_keyboard_listener from lerobot.utils.utils import log_say from lerobot.utils.visualization_utils import init_rerun -from lerobot.record import record_loop -from lerobot.policies.factory import make_processor + NUM_EPISODES = 5 FPS = 30 @@ -562,7 +563,7 @@ init_rerun(session_name="recording") # Connect the robot robot.connect() -preprocessor, postprocessor = make_processor( +preprocessor, postprocessor = make_pre_post_processors( policy_cfg=policy, pretrained_path=HF_MODEL_ID, dataset_stats=dataset.meta.stats,