Commit 111cd58

Add MultiLeRobotDataset for training with multiple LeRobotDatasets (#229)
1 parent: 265b0ec

File tree

8 files changed: +352 −72 lines

lerobot/common/datasets/push_dataset_to_hub/compute_stats.py renamed to lerobot/common/datasets/compute_stats.py

Lines changed: 54 additions & 6 deletions
@@ -16,17 +16,15 @@
 from copy import deepcopy
 from math import ceil

-import datasets
 import einops
 import torch
 import tqdm
 from datasets import Image

-from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
 from lerobot.common.datasets.video_utils import VideoFrame


-def get_stats_einops_patterns(dataset: LeRobotDataset | datasets.Dataset, num_workers=0):
+def get_stats_einops_patterns(dataset, num_workers=0):
     """These einops patterns will be used to aggregate batches and compute statistics.

     Note: We assume the images are in channel first format

@@ -66,9 +64,8 @@ def get_stats_einops_patterns(dataset: LeRobotDataset | datasets.Dataset, num_workers=0):
     return stats_patterns


-def compute_stats(
-    dataset: LeRobotDataset | datasets.Dataset, batch_size=32, num_workers=16, max_num_samples=None
-):
+def compute_stats(dataset, batch_size=32, num_workers=16, max_num_samples=None):
+    """Compute mean/std and min/max statistics of all data keys in a LeRobotDataset."""
     if max_num_samples is None:
         max_num_samples = len(dataset)

@@ -159,3 +156,54 @@ def create_seeded_dataloader(dataset, batch_size, seed):
             "min": min[key],
         }
     return stats
+
+
+def aggregate_stats(ls_datasets) -> dict[str, torch.Tensor]:
+    """Aggregate stats of multiple LeRobot datasets into one set of stats without recomputing from scratch.
+
+    The final stats will have the union of all data keys from each of the datasets. For instance:
+    - new_max = max(max_dataset_0, max_dataset_1, ...)
+    - new_min = min(min_dataset_0, min_dataset_1, ...)
+    - new_mean = (mean of all data)
+    - new_std = (std of all data)
+    """
+    data_keys = set()
+    for dataset in ls_datasets:
+        data_keys.update(dataset.stats.keys())
+    stats = {k: {} for k in data_keys}
+    for data_key in data_keys:
+        for stat_key in ["min", "max"]:
+            # compute `max(dataset_0["max"], dataset_1["max"], ...)`
+            stats[data_key][stat_key] = einops.reduce(
+                torch.stack([d.stats[data_key][stat_key] for d in ls_datasets if data_key in d.stats], dim=0),
+                "n ... -> ...",
+                stat_key,
+            )
+        total_samples = sum(d.num_samples for d in ls_datasets if data_key in d.stats)
+        # Compute the "sum" statistic by multiplying each mean by the number of samples in the respective
+        # dataset, then divide by total_samples to get the overall "mean".
+        # NOTE: the brackets around (d.num_samples / total_samples) are needed to minimize the risk of
+        # numerical overflow!
+        stats[data_key]["mean"] = sum(
+            d.stats[data_key]["mean"] * (d.num_samples / total_samples)
+            for d in ls_datasets
+            if data_key in d.stats
+        )
+        # The derivation for standard deviation is a little more involved but is much in the same spirit as
+        # the computation of the mean.
+        # Given two sets of data where the statistics are known:
+        # σ_combined = sqrt[ (n1 * (σ1^2 + d1^2) + n2 * (σ2^2 + d2^2)) / (n1 + n2) ]
+        # where d1 = μ1 - μ_combined, d2 = μ2 - μ_combined
+        # NOTE: the brackets around (d.num_samples / total_samples) are needed to minimize the risk of
+        # numerical overflow!
+        stats[data_key]["std"] = torch.sqrt(
+            sum(
+                (d.stats[data_key]["std"] ** 2 + (d.stats[data_key]["mean"] - stats[data_key]["mean"]) ** 2)
+                * (d.num_samples / total_samples)
+                for d in ls_datasets
+                if data_key in d.stats
+            )
+        )
+    return stats
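
The pooled mean/std arithmetic above can be sanity-checked numerically. Below is a minimal sketch, not part of the commit: the `SimpleNamespace` stand-ins are hypothetical objects that only mimic the `stats` and `num_samples` attributes read by `aggregate_stats`, and the data key name is arbitrary.

# Minimal sketch (not from this commit): check the aggregated statistics against a direct
# computation on the concatenated data.
from types import SimpleNamespace

import torch

from lerobot.common.datasets.compute_stats import aggregate_stats


def fake_dataset(x: torch.Tensor) -> SimpleNamespace:
    # LeRobot's stats use population (uncorrected) std, so unbiased=False keeps the check consistent.
    return SimpleNamespace(
        num_samples=len(x),
        stats={
            "observation.state": {
                "mean": x.mean(0),
                "std": x.std(0, unbiased=False),
                "min": x.min(0).values,
                "max": x.max(0).values,
            }
        },
    )


x0 = torch.randn(100, 3)       # pretend raw samples of dataset 0
x1 = torch.randn(50, 3) + 1.0  # dataset 1 with a shifted mean

agg = aggregate_stats([fake_dataset(x0), fake_dataset(x1)])

# The aggregated stats should match stats computed directly on the concatenation.
x = torch.cat([x0, x1])
assert torch.allclose(agg["observation.state"]["mean"], x.mean(0), atol=1e-4)
assert torch.allclose(agg["observation.state"]["std"], x.std(0, unbiased=False), atol=1e-4)
assert torch.equal(agg["observation.state"]["max"], x.max(0).values)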

lerobot/common/datasets/factory.py

Lines changed: 33 additions & 12 deletions
@@ -16,9 +16,9 @@
 import logging

 import torch
-from omegaconf import OmegaConf
+from omegaconf import ListConfig, OmegaConf

-from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
+from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, MultiLeRobotDataset


 def resolve_delta_timestamps(cfg):

@@ -35,11 +35,27 @@ def resolve_delta_timestamps(cfg):
                 cfg.training.delta_timestamps[key] = eval(delta_timestamps[key])


-def make_dataset(
-    cfg,
-    split="train",
-):
-    if cfg.env.name not in cfg.dataset_repo_id:
+def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotDataset:
+    """
+    Args:
+        cfg: A Hydra config as per the LeRobot config scheme.
+        split: Select the data subset used to create an instance of LeRobotDataset.
+            All datasets hosted on [lerobot](https://huggingface.co/lerobot) contain only one subset: "train".
+            Thus, by default, `split="train"` selects all the available data. `split` aims to work like the
+            slicer in the hugging face datasets:
+            https://huggingface.co/docs/datasets/v2.19.0/loading#slice-splits
+            As of now, it only supports `split="train[:n]"` to load the first n frames of the dataset or
+            `split="train[n:]"` to load the last n frames. For instance `split="train[:1000]"`.
+    Returns:
+        The LeRobotDataset.
+    """
+    if not isinstance(cfg.dataset_repo_id, (str, ListConfig)):
+        raise ValueError(
+            "Expected cfg.dataset_repo_id to be either a single string to load one dataset or a list of "
+            "strings to load multiple datasets."
+        )
+
+    if isinstance(cfg.dataset_repo_id, str) and cfg.env.name not in cfg.dataset_repo_id:
         logging.warning(
             f"There might be a mismatch between your training dataset ({cfg.dataset_repo_id=}) and your "
             f"environment ({cfg.env.name=})."

@@ -49,11 +65,16 @@ def make_dataset(

     # TODO(rcadene): add data augmentations

-    dataset = LeRobotDataset(
-        cfg.dataset_repo_id,
-        split=split,
-        delta_timestamps=cfg.training.get("delta_timestamps"),
-    )
+    if isinstance(cfg.dataset_repo_id, str):
+        dataset = LeRobotDataset(
+            cfg.dataset_repo_id,
+            split=split,
+            delta_timestamps=cfg.training.get("delta_timestamps"),
+        )
+    else:
+        dataset = MultiLeRobotDataset(
+            cfg.dataset_repo_id, split=split, delta_timestamps=cfg.training.get("delta_timestamps")
+        )

     if cfg.get("override_dataset_stats"):
         for key, stats_dict in cfg.override_dataset_stats.items():
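
The `(str, ListConfig)` check exists because values coming out of a Hydra/OmegaConf config are not plain Python lists. A minimal sketch, not part of the commit, of how a list-valued `dataset_repo_id` looks to this code (the repo ids are just placeholders):

from omegaconf import ListConfig, OmegaConf

# A list in the yaml config (or a CLI override such as dataset_repo_id=[repo_a,repo_b])
# arrives as a ListConfig, which is why make_dataset checks for (str, ListConfig) rather than (str, list).
cfg = OmegaConf.create(
    {"dataset_repo_id": ["lerobot/aloha_sim_insertion_human", "lerobot/aloha_sim_transfer_cube_human"]}
)
assert isinstance(cfg.dataset_repo_id, ListConfig)

# A single repo id keeps the original single-dataset behaviour.
cfg_single = OmegaConf.create({"dataset_repo_id": "lerobot/aloha_sim_insertion_human"})
assert isinstance(cfg_single.dataset_repo_id, str)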

lerobot/common/datasets/lerobot_dataset.py

Lines changed: 202 additions & 4 deletions
@@ -13,12 +13,16 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import logging
 import os
 from pathlib import Path
+from typing import Callable

 import datasets
 import torch
+import torch.utils

+from lerobot.common.datasets.compute_stats import aggregate_stats
 from lerobot.common.datasets.utils import (
     calculate_episode_data_index,
     load_episode_data_index,

@@ -42,7 +46,7 @@ def __init__(
         version: str | None = CODEBASE_VERSION,
         root: Path | None = DATA_DIR,
         split: str = "train",
-        transform: callable = None,
+        transform: Callable | None = None,
         delta_timestamps: dict[list[float]] | None = None,
     ):
         super().__init__()

@@ -171,7 +175,7 @@ def __repr__(self):
     @classmethod
     def from_preloaded(
         cls,
-        repo_id: str,
+        repo_id: str = "from_preloaded",
         version: str | None = CODEBASE_VERSION,
         root: Path | None = None,
         split: str = "train",

@@ -183,7 +187,15 @@ def from_preloaded(
         stats=None,
         info=None,
         videos_dir=None,
-    ):
+    ) -> "LeRobotDataset":
+        """Create a LeRobot Dataset from existing data and attributes instead of loading from the filesystem.
+
+        It is especially useful when converting raw data into LeRobotDataset before saving the dataset
+        on the filesystem or uploading to the hub.
+
+        Note: Meta-data attributes like `repo_id`, `version`, `root`, etc. are optional and potentially
+        meaningless depending on the downstream usage of the returned dataset.
+        """
         # create an empty object of type LeRobotDataset
         obj = cls.__new__(cls)
         obj.repo_id = repo_id
@@ -195,6 +207,192 @@ def from_preloaded(
         obj.hf_dataset = hf_dataset
         obj.episode_data_index = episode_data_index
         obj.stats = stats
-        obj.info = info
+        obj.info = info if info is not None else {}
         obj.videos_dir = videos_dir
         return obj
+
+
+class MultiLeRobotDataset(torch.utils.data.Dataset):
+    """A dataset consisting of multiple underlying `LeRobotDataset`s.
+
+    The underlying `LeRobotDataset`s are effectively concatenated, and this class adopts much of the API
+    structure of `LeRobotDataset`.
+    """
+
+    def __init__(
+        self,
+        repo_ids: list[str],
+        version: str | None = CODEBASE_VERSION,
+        root: Path | None = DATA_DIR,
+        split: str = "train",
+        transform: Callable | None = None,
+        delta_timestamps: dict[list[float]] | None = None,
+    ):
+        super().__init__()
+        self.repo_ids = repo_ids
+        # Construct the underlying datasets passing everything but `transform` and `delta_timestamps` which
+        # are handled by this class.
+        self._datasets = [
+            LeRobotDataset(
+                repo_id,
+                version=version,
+                root=root,
+                split=split,
+                delta_timestamps=delta_timestamps,
+                transform=transform,
+            )
+            for repo_id in repo_ids
+        ]
+        # Check that some properties are consistent across datasets. Note: We may relax some of these
+        # consistency requirements in future iterations of this class.
+        for repo_id, dataset in zip(self.repo_ids, self._datasets, strict=True):
+            if dataset.info != self._datasets[0].info:
+                raise ValueError(
+                    f"Detected a mismatch in dataset info between {self.repo_ids[0]} and {repo_id}. This is "
+                    "not yet supported."
+                )
+        # Disable any data keys that are not common across all of the datasets. Note: we may relax this
+        # restriction in future iterations of this class. For now, this is necessary at least for being able
+        # to use PyTorch's default DataLoader collate function.
+        self.disabled_data_keys = set()
+        intersection_data_keys = set(self._datasets[0].hf_dataset.features)
+        for dataset in self._datasets:
+            intersection_data_keys.intersection_update(dataset.hf_dataset.features)
+        if len(intersection_data_keys) == 0:
+            raise RuntimeError(
+                "Multiple datasets were provided but they had no keys common to all of them. The "
+                "multi-dataset functionality currently only keeps common keys."
+            )
+        for repo_id, dataset in zip(self.repo_ids, self._datasets, strict=True):
+            extra_keys = set(dataset.hf_dataset.features).difference(intersection_data_keys)
+            logging.warning(
+                f"keys {extra_keys} of {repo_id} were disabled as they are not contained in all the "
+                "other datasets."
+            )
+            self.disabled_data_keys.update(extra_keys)
+
+        self.version = version
+        self.root = root
+        self.split = split
+        self.transform = transform
+        self.delta_timestamps = delta_timestamps
+        self.stats = aggregate_stats(self._datasets)
+
+    @property
+    def repo_id_to_index(self):
+        """Return a mapping from dataset repo_id to a dataset index automatically created by this class.
+
+        This index is incorporated as a data key in the dictionary returned by `__getitem__`.
+        """
+        return {repo_id: i for i, repo_id in enumerate(self.repo_ids)}
+
+    @property
+    def repo_index_to_id(self):
+        """Return the inverse mapping of `repo_id_to_index`."""
+        return {v: k for k, v in self.repo_id_to_index.items()}
+
+    @property
+    def fps(self) -> int:
+        """Frames per second used during data collection.
+
+        NOTE: For now, this relies on a check in __init__ to make sure all sub-datasets have the same info.
+        """
+        return self._datasets[0].info["fps"]
+
+    @property
+    def video(self) -> bool:
+        """Returns True if this dataset loads video frames from mp4 files.
+
+        Returns False if it only loads images from png files.
+
+        NOTE: For now, this relies on a check in __init__ to make sure all sub-datasets have the same info.
+        """
+        return self._datasets[0].info.get("video", False)
+
+    @property
+    def features(self) -> datasets.Features:
+        features = {}
+        for dataset in self._datasets:
+            features.update({k: v for k, v in dataset.features.items() if k not in self.disabled_data_keys})
+        return features
+
+    @property
+    def camera_keys(self) -> list[str]:
+        """Keys to access image and video streams from cameras."""
+        keys = []
+        for key, feats in self.features.items():
+            if isinstance(feats, (datasets.Image, VideoFrame)):
+                keys.append(key)
+        return keys
+
+    @property
+    def video_frame_keys(self) -> list[str]:
+        """Keys to access video frames that need to be decoded into images.
+
+        Note: It is empty if the dataset contains images only,
+        or equal to `self.camera_keys` if the dataset contains videos only,
+        or can even be a subset of `self.camera_keys` in the case of a mixed image/video dataset.
+        """
+        video_frame_keys = []
+        for key, feats in self.features.items():
+            if isinstance(feats, VideoFrame):
+                video_frame_keys.append(key)
+        return video_frame_keys
+
+    @property
+    def num_samples(self) -> int:
+        """Number of samples/frames."""
+        return sum(d.num_samples for d in self._datasets)
+
+    @property
+    def num_episodes(self) -> int:
+        """Number of episodes."""
+        return sum(d.num_episodes for d in self._datasets)
+
+    @property
+    def tolerance_s(self) -> float:
+        """Tolerance in seconds used to discard loaded frames when their timestamps
+        are not close enough to the requested frames. It is only used when `delta_timestamps`
+        is provided or when loading video frames from mp4 files.
+        """
+        # 1e-4 to account for possible numerical error
+        return 1 / self.fps - 1e-4
+
+    def __len__(self):
+        return self.num_samples
+
+    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
+        if idx >= len(self):
+            raise IndexError(f"Index {idx} out of bounds.")
+        # Determine which dataset to get an item from based on the index.
+        start_idx = 0
+        dataset_idx = 0
+        for dataset in self._datasets:
+            if idx >= start_idx + dataset.num_samples:
+                start_idx += dataset.num_samples
+                dataset_idx += 1
+                continue
+            break
+        else:
+            raise AssertionError("We expect the loop to break out as long as the index is within bounds.")
+        item = self._datasets[dataset_idx][idx - start_idx]
+        item["dataset_index"] = torch.tensor(dataset_idx)
+        for data_key in self.disabled_data_keys:
+            if data_key in item:
+                del item[data_key]
+        return item
+
+    def __repr__(self):
+        return (
+            f"{self.__class__.__name__}(\n"
+            f"  Repository IDs: '{self.repo_ids}',\n"
+            f"  Version: '{self.version}',\n"
+            f"  Split: '{self.split}',\n"
+            f"  Number of Samples: {self.num_samples},\n"
+            f"  Number of Episodes: {self.num_episodes},\n"
+            f"  Type: {'video (.mp4)' if self.video else 'image (.png)'},\n"
+            f"  Recorded Frames per Second: {self.fps},\n"
+            f"  Camera Keys: {self.camera_keys},\n"
+            f"  Video Frame Keys: {self.video_frame_keys if self.video else 'N/A'},\n"
+            f"  Transformations: {self.transform},\n"
+            f")"
+        )
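
A short usage sketch of the new class with a standard PyTorch DataLoader. This is not part of the commit; the repo ids below are placeholders for any two LeRobot datasets with compatible `info` (same fps, same video flag, ...).

import torch

from lerobot.common.datasets.lerobot_dataset import MultiLeRobotDataset

dataset = MultiLeRobotDataset(
    ["lerobot/aloha_sim_insertion_human", "lerobot/aloha_sim_transfer_cube_human"]
)
print(dataset)               # aggregated sample/episode counts across both repos
print(dataset.stats.keys())  # stats pooled by aggregate_stats()

dataloader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)
batch = next(iter(dataloader))

# Each item carries a "dataset_index" tensor identifying the underlying dataset it came from,
# which can be mapped back to a repo id via repo_id_to_index / repo_index_to_id.
print(batch["dataset_index"][:8])
print(dataset.repo_id_to_index)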
