
Commit 0b51a33

Add a test for MultiLeRobotDataset making sure it produces all frames. (#230)
Co-authored-by: Remi <[email protected]>
1 parent 111cd58 commit 0b51a33

File tree: 1 file changed (+28 −1 lines)


tests/test_datasets.py

Lines changed: 28 additions & 1 deletion
@@ -16,6 +16,7 @@
 import json
 import logging
 from copy import deepcopy
+from itertools import chain
 from pathlib import Path
 
 import einops
@@ -31,7 +32,7 @@
     get_stats_einops_patterns,
 )
 from lerobot.common.datasets.factory import make_dataset
-from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
+from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, MultiLeRobotDataset
 from lerobot.common.datasets.utils import (
     flatten_dict,
     hf_transform_to_torch,
@@ -113,6 +114,32 @@ def test_factory(env_name, repo_id, policy_name):
         assert key in item, f"{key}"
 
 
+def test_multilerobotdataset_frames():
+    """Check that all dataset frames are incorporated."""
+    # Note: use the image variants of the dataset to make the test approx 3x faster.
+    repo_ids = ["lerobot/aloha_sim_insertion_human_image", "lerobot/aloha_sim_transfer_cube_human_image"]
+    sub_datasets = [LeRobotDataset(repo_id) for repo_id in repo_ids]
+    dataset = MultiLeRobotDataset(repo_ids)
+    assert len(dataset) == sum(len(d) for d in sub_datasets)
+    assert dataset.num_samples == sum(d.num_samples for d in sub_datasets)
+    assert dataset.num_episodes == sum(d.num_episodes for d in sub_datasets)
+
+    # Run through all items of the LeRobotDatasets in parallel with the items of the MultiLerobotDataset and
+    # check they match.
+    expected_dataset_indices = []
+    for i, sub_dataset in enumerate(sub_datasets):
+        expected_dataset_indices.extend([i] * len(sub_dataset))
+
+    for expected_dataset_index, sub_dataset_item, dataset_item in zip(
+        expected_dataset_indices, chain(*sub_datasets), dataset, strict=True
+    ):
+        dataset_index = dataset_item.pop("dataset_index")
+        assert dataset_index == expected_dataset_index
+        assert sub_dataset_item.keys() == dataset_item.keys()
+        for k in sub_dataset_item:
+            assert torch.equal(sub_dataset_item[k], dataset_item[k])
+
+
 def test_compute_stats_on_xarm():
     """Check that the statistics are computed correctly according to the stats_patterns property.
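For context, here is a minimal usage sketch (not part of this commit) of how MultiLeRobotDataset combines several dataset repos into one dataset. It only uses calls that appear in the test above (LeRobotDataset, MultiLeRobotDataset, len, and iteration over items); the final print is purely illustrative.

# Minimal sketch, assuming only the APIs exercised by the test above.
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, MultiLeRobotDataset

repo_ids = [
    "lerobot/aloha_sim_insertion_human_image",
    "lerobot/aloha_sim_transfer_cube_human_image",
]

# Each repo can also be loaded on its own.
sub_datasets = [LeRobotDataset(repo_id) for repo_id in repo_ids]

# The combined dataset concatenates the frames of every sub-dataset, in repo_ids order.
dataset = MultiLeRobotDataset(repo_ids)
assert len(dataset) == sum(len(d) for d in sub_datasets)

# Each item additionally carries a "dataset_index" entry identifying its source dataset.
first_item = next(iter(dataset))
print(first_item["dataset_index"])  # 0 -> first repo in repo_ids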
