 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import logging
 import os
 from pathlib import Path
+from typing import Callable
 
 import datasets
 import torch
+import torch.utils
 
+from lerobot.common.datasets.compute_stats import aggregate_stats
 from lerobot.common.datasets.utils import (
     calculate_episode_data_index,
     load_episode_data_index,
@@ -42,7 +46,7 @@ def __init__(
         version: str | None = CODEBASE_VERSION,
         root: Path | None = DATA_DIR,
         split: str = "train",
-        transform: callable = None,
+        transform: Callable | None = None,
         delta_timestamps: dict[list[float]] | None = None,
     ):
         super().__init__()
@@ -171,7 +175,7 @@ def __repr__(self):
     @classmethod
     def from_preloaded(
         cls,
-        repo_id: str,
+        repo_id: str = "from_preloaded",
         version: str | None = CODEBASE_VERSION,
         root: Path | None = None,
         split: str = "train",
@@ -183,7 +187,19 @@ def from_preloaded(
         stats=None,
         info=None,
         videos_dir=None,
-    ):
+    ) -> "LeRobotDataset":
+        """Create a LeRobotDataset from existing data and attributes instead of loading from the filesystem.
+
+        It is especially useful when converting raw data into a LeRobotDataset before saving the dataset
+        to the filesystem or uploading it to the hub.
+
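+        A minimal usage sketch (the keyword arguments shown are illustrative, not exhaustive):
+
+            dataset = LeRobotDataset.from_preloaded(hf_dataset=hf_dataset, episode_data_index=episode_data_index)
+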
+        Note: Meta-data attributes like `repo_id`, `version`, `root`, etc. are optional and potentially
+        meaningless depending on the downstream usage of the returned dataset.
+        """
         # create an empty object of type LeRobotDataset
         obj = cls.__new__(cls)
         obj.repo_id = repo_id
@@ -195,6 +207,202 @@ def from_preloaded(
         obj.hf_dataset = hf_dataset
         obj.episode_data_index = episode_data_index
         obj.stats = stats
-        obj.info = info
+        obj.info = info if info is not None else {}
         obj.videos_dir = videos_dir
         return obj
+
+
+class MultiLeRobotDataset(torch.utils.data.Dataset):
+    """A dataset consisting of multiple underlying `LeRobotDataset`s.
+
+    The underlying `LeRobotDataset`s are effectively concatenated, and this class adopts much of the API
+    structure of `LeRobotDataset`.
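+
+    A minimal usage sketch (the repo ids below are illustrative):
+
+        dataset = MultiLeRobotDataset(["user/dataset_a", "user/dataset_b"])
+        item = dataset[0]  # item["dataset_index"] identifies the originating sub-dataset.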
+    """
+
+    def __init__(
+        self,
+        repo_ids: list[str],
+        version: str | None = CODEBASE_VERSION,
+        root: Path | None = DATA_DIR,
+        split: str = "train",
+        transform: Callable | None = None,
+        delta_timestamps: dict[list[float]] | None = None,
+    ):
+        super().__init__()
+        self.repo_ids = repo_ids
+        # Construct the underlying datasets passing everything but `transform` and `delta_timestamps` which
+        # are handled by this class.
+        self._datasets = [
+            LeRobotDataset(
+                repo_id,
+                version=version,
+                root=root,
+                split=split,
+                delta_timestamps=delta_timestamps,
+                transform=transform,
+            )
+            for repo_id in repo_ids
+        ]
+        # Check that some properties are consistent across datasets. Note: We may relax some of these
+        # consistency requirements in future iterations of this class.
+        for repo_id, dataset in zip(self.repo_ids, self._datasets, strict=True):
+            if dataset.info != self._datasets[0].info:
+                raise ValueError(
+                    f"Detected a mismatch in dataset info between {self.repo_ids[0]} and {repo_id}. This is "
+                    "not yet supported."
+                )
+        # Disable any data keys that are not common across all of the datasets. Note: we may relax this
+        # restriction in future iterations of this class. For now, this is necessary at least for being able
+        # to use PyTorch's default DataLoader collate function.
+        self.disabled_data_keys = set()
+        intersection_data_keys = set(self._datasets[0].hf_dataset.features)
+        for dataset in self._datasets:
+            intersection_data_keys.intersection_update(dataset.hf_dataset.features)
+        if len(intersection_data_keys) == 0:
+            raise RuntimeError(
+                "Multiple datasets were provided but they had no keys common to all of them. The "
+                "multi-dataset functionality currently only keeps common keys."
+            )
+        for repo_id, dataset in zip(self.repo_ids, self._datasets, strict=True):
+            extra_keys = set(dataset.hf_dataset.features).difference(intersection_data_keys)
+            logging.warning(
+                f"keys {extra_keys} of {repo_id} were disabled as they are not contained in all the "
+                "other datasets."
+            )
+            self.disabled_data_keys.update(extra_keys)
+
+        self.version = version
+        self.root = root
+        self.split = split
+        self.transform = transform
+        self.delta_timestamps = delta_timestamps
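+        # Aggregate the per-dataset statistics into a single set of stats for the whole collection.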
+        self.stats = aggregate_stats(self._datasets)
+
+    @property
+    def repo_id_to_index(self):
+        """Return a mapping from dataset repo_id to a dataset index automatically created by this class.
+
+        This index is incorporated as a data key in the dictionary returned by `__getitem__`.
+        """
+        return {repo_id: i for i, repo_id in enumerate(self.repo_ids)}
+
+    @property
+    def repo_index_to_id(self):
+        """Return the inverse mapping of `repo_id_to_index`."""
+        return {v: k for k, v in self.repo_id_to_index.items()}
+
+    @property
+    def fps(self) -> int:
+        """Frames per second used during data collection.
+
+        NOTE: For now, this relies on a check in __init__ to make sure all sub-datasets have the same info.
+        """
+        return self._datasets[0].info["fps"]
+
+    @property
+    def video(self) -> bool:
+        """Returns True if this dataset loads video frames from mp4 files.
+
+        Returns False if it only loads images from png files.
+
+        NOTE: For now, this relies on a check in __init__ to make sure all sub-datasets have the same info.
+        """
+        return self._datasets[0].info.get("video", False)
+
+    @property
+    def features(self) -> datasets.Features:
+        features = {}
+        for dataset in self._datasets:
+            features.update({k: v for k, v in dataset.features.items() if k not in self.disabled_data_keys})
+        return features
+
+    @property
+    def camera_keys(self) -> list[str]:
+        """Keys to access image and video streams from cameras."""
+        keys = []
+        for key, feats in self.features.items():
+            if isinstance(feats, (datasets.Image, VideoFrame)):
+                keys.append(key)
+        return keys
+
+    @property
+    def video_frame_keys(self) -> list[str]:
+        """Keys to access video frames that require decoding into images.
+
+        Note: It is empty if the dataset contains images only,
+        or equal to `self.camera_keys` if the dataset contains videos only,
+        or can even be a subset of `self.camera_keys` in the case of a mixed image/video dataset.
+        """
+        video_frame_keys = []
+        for key, feats in self.features.items():
+            if isinstance(feats, VideoFrame):
+                video_frame_keys.append(key)
+        return video_frame_keys
+
+    @property
+    def num_samples(self) -> int:
+        """Number of samples/frames."""
+        return sum(d.num_samples for d in self._datasets)
+
+    @property
+    def num_episodes(self) -> int:
+        """Number of episodes."""
+        return sum(d.num_episodes for d in self._datasets)
+
+    @property
+    def tolerance_s(self) -> float:
+        """Tolerance in seconds used to discard loaded frames when their timestamps
+        are not close enough to the requested frames. It is only used when `delta_timestamps`
+        is provided or when loading video frames from mp4 files.
+        """
+        # 1e-4 to account for possible numerical error
+        return 1 / self.fps - 1e-4
+
+    def __len__(self):
+        return self.num_samples
+
+    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
+        if idx >= len(self):
+            raise IndexError(f"Index {idx} out of bounds.")
+        # Determine which dataset to get an item from based on the index.
+        start_idx = 0
+        dataset_idx = 0
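+        # Walk the sub-datasets in order, offsetting the global index until it lands inside one of them.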
+        for dataset in self._datasets:
+            if idx >= start_idx + dataset.num_samples:
+                # The index belongs to a later dataset: skip past this one.
+                start_idx += dataset.num_samples
+                dataset_idx += 1
+                continue
+            break
+        else:
+            raise AssertionError("We expect the loop to break out as long as the index is within bounds.")
+        item = self._datasets[dataset_idx][idx - start_idx]
+        item["dataset_index"] = torch.tensor(dataset_idx)
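+        # Drop any keys that are not common to every sub-dataset (see the check in __init__).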
+        for data_key in self.disabled_data_keys:
+            if data_key in item:
+                del item[data_key]
+        return item
+
+    def __repr__(self):
+        return (
+            f"{self.__class__.__name__}(\n"
+            f"  Repository IDs: '{self.repo_ids}',\n"
+            f"  Version: '{self.version}',\n"
+            f"  Split: '{self.split}',\n"
+            f"  Number of Samples: {self.num_samples},\n"
+            f"  Number of Episodes: {self.num_episodes},\n"
+            f"  Type: {'video (.mp4)' if self.video else 'image (.png)'},\n"
+            f"  Recorded Frames per Second: {self.fps},\n"
+            f"  Camera Keys: {self.camera_keys},\n"
+            f"  Video Frame Keys: {self.video_frame_keys if self.video else 'N/A'},\n"
+            f"  Transformations: {self.transform},\n"
+            f")"
+        )