|
16 | 16 | import json
|
17 | 17 | import logging
|
18 | 18 | from copy import deepcopy
|
| 19 | +from itertools import chain |
19 | 20 | from pathlib import Path
|
20 | 21 |
|
21 | 22 | import einops
|
|
31 | 32 | get_stats_einops_patterns,
|
32 | 33 | )
|
33 | 34 | from lerobot.common.datasets.factory import make_dataset
|
34 |
| -from lerobot.common.datasets.lerobot_dataset import LeRobotDataset |
| 35 | +from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, MultiLeRobotDataset |
35 | 36 | from lerobot.common.datasets.utils import (
|
36 | 37 | flatten_dict,
|
37 | 38 | hf_transform_to_torch,
|
@@ -113,6 +114,32 @@ def test_factory(env_name, repo_id, policy_name):
|
113 | 114 | assert key in item, f"{key}"
|
114 | 115 |
|
115 | 116 |
|
def test_multilerobotdataset_frames():
    """Verify that a MultiLeRobotDataset exposes every frame of its constituent datasets.

    Uses the image variants of the datasets, which makes the test roughly 3x faster.
    """
    repo_ids = ["lerobot/aloha_sim_insertion_human_image", "lerobot/aloha_sim_transfer_cube_human_image"]
    individual_datasets = [LeRobotDataset(repo_id) for repo_id in repo_ids]
    aggregate = MultiLeRobotDataset(repo_ids)

    # Aggregate counts must equal the sums over the constituent datasets.
    assert len(aggregate) == sum(len(d) for d in individual_datasets)
    assert aggregate.num_samples == sum(d.num_samples for d in individual_datasets)
    assert aggregate.num_episodes == sum(d.num_episodes for d in individual_datasets)

    # Walk the aggregate in lockstep with the concatenation of the sub-datasets
    # and check each item matches, including which dataset it came from.
    # strict=True guards against a silent length mismatch between the two sides.
    expected_dataset_indices = []
    for i, sub_dataset in enumerate(individual_datasets):
        expected_dataset_indices += [i] * len(sub_dataset)

    for expected_idx, sub_item, agg_item in zip(
        expected_dataset_indices, chain(*individual_datasets), aggregate, strict=True
    ):
        # The aggregate adds a "dataset_index" key; remove it before comparing items.
        assert agg_item.pop("dataset_index") == expected_idx
        assert sub_item.keys() == agg_item.keys()
        for key in sub_item:
            assert torch.equal(sub_item[key], agg_item[key])
| 142 | + |
116 | 143 | def test_compute_stats_on_xarm():
|
117 | 144 | """Check that the statistics are computed correctly according to the stats_patterns property.
|
118 | 145 |
|
|
0 commit comments