10
10
//! See [`LeRobotDataset`] for more information on the dataset format.
11
11
12
12
use std:: borrow:: Cow ;
13
+ use std:: fmt;
13
14
use std:: fs:: File ;
14
15
use std:: io:: BufReader ;
15
16
use std:: path:: { Path , PathBuf } ;
16
17
17
18
use ahash:: HashMap ;
18
19
use arrow:: array:: RecordBatch ;
19
20
use parquet:: arrow:: arrow_reader:: ParquetRecordBatchReaderBuilder ;
20
- use serde:: de:: DeserializeOwned ;
21
- use serde:: { Deserialize , Serialize } ;
21
+ use serde:: de:: { DeserializeOwned , MapAccess , SeqAccess , Visitor } ;
22
+ use serde:: { Deserialize , Deserializer , Serialize } ;
22
23
23
24
/// Check whether the provided path contains a `LeRobot` dataset.
24
25
pub fn is_lerobot_dataset ( path : impl AsRef < Path > ) -> bool {
@@ -388,6 +389,34 @@ pub struct Feature {
388
389
pub names : Option < Names > ,
389
390
}
390
391
392
+ impl Feature {
393
+ /// Get the channel dimension for this [`Feature`].
394
+ ///
395
+ /// Returns the number of channels in the feature's data representation.
396
+ ///
397
+ /// # Note
398
+ ///
399
+ /// This is primarily intended for [`DType::Image`] and [`DType::Video`] features,
400
+ /// where it represents color channels (e.g., 3 for RGB, 4 for RGBA).
401
+ /// For other feature types, this function returns the size of the last dimension
402
+ /// from the feature's shape.
403
+ pub fn channel_dim ( & self ) -> usize {
404
+ // first check if there's a "channels" name, if there is we can use that index.
405
+ if let Some ( names) = & self . names {
406
+ if let Some ( channel_idx) = names. 0 . iter ( ) . position ( |name| name == "channels" ) {
407
+ // If channel_idx is within bounds of shape, return that dimension
408
+ if channel_idx < self . shape . len ( ) {
409
+ return self . shape [ channel_idx] ;
410
+ }
411
+ }
412
+ }
413
+
414
+ // Default to the last dimension if no channels name is found
415
+ // or if the found index is out of bounds
416
+ self . shape . last ( ) . copied ( ) . unwrap_or ( 0 )
417
+ }
418
+ }
419
+
391
420
/// Data types supported for features in a `LeRobot` dataset.
392
421
#[ derive( Serialize , Deserialize , Debug , Clone , Copy , PartialEq , Eq ) ]
393
422
#[ serde( rename_all = "snake_case" ) ]
@@ -397,6 +426,7 @@ pub enum DType {
397
426
Bool ,
398
427
Float32 ,
399
428
Float64 ,
429
+ Int16 ,
400
430
Int64 ,
401
431
String ,
402
432
}
@@ -406,56 +436,129 @@ pub enum DType {
406
436
/// The name metadata can consist of
407
437
/// - A flat list of names for each dimension of a feature (e.g., `["height", "width", "channel"]`).
408
438
/// - A nested list of names for each dimension of a feature (e.g., `[["kLeftShoulderPitch", "kLeftShoulderRoll"]]`)
409
- /// - A list specific to motors (e.g., `{ "motors": ["motor_0", "motor_1", ...] }`).
410
- #[ derive( Debug , Serialize , Deserialize , Clone ) ]
411
- #[ serde( untagged) ]
412
- pub enum Names {
413
- Motors { motors : Vec < String > } ,
414
- List ( NamesList ) ,
415
- }
439
+ /// - A map with a string array value (e.g., `{ "motors": ["motor_0", "motor_1", ...] }` or `{ "axes": ["x", "y", "z"] }`).
440
+ #[ derive( Debug , Clone , PartialEq , Eq , Serialize ) ]
441
+ pub struct Names ( Vec < String > ) ;
416
442
417
443
impl Names {
418
- /// Retrieves the name corresponding to a specific index within the `names` field of a feature.
444
+ /// Retrieves the name corresponding to a specific index.
445
+ ///
446
+ /// Returns `None` if the index is out of bounds.
419
447
pub fn name_for_index ( & self , index : usize ) -> Option < & String > {
420
- match self {
421
- Self :: Motors { motors } => motors. get ( index) ,
422
- Self :: List ( NamesList ( items) ) => items. get ( index) ,
423
- }
448
+ self . 0 . get ( index)
424
449
}
425
450
}
426
451
427
- /// A wrapper struct that deserializes flat or nested lists of strings
428
- /// into a single flattened [`Vec`] of names for easy indexing.
429
- #[ derive( Debug , Serialize , Clone ) ]
430
- pub struct NamesList ( Vec < String > ) ;
452
/// Serde visitor that builds a [`Names`] value.
///
/// Accepted input shapes:
/// - a flat array of strings, e.g. `["x", "y", "z"]`
/// - an array of string arrays, e.g. `[["motor_1", "motor_2"]]`
/// - an object with a single entry, e.g. `{"motors": ["motor_1", "motor_2"]}` or `{"axes": null}`
///
/// The `Names` type documentation describes these formats in more detail.
struct NamesVisitor;
431
461
432
- impl < ' de > Deserialize < ' de > for NamesList {
433
- fn deserialize < D > ( deserializer : D ) -> Result < Self , D :: Error >
462
+ impl < ' de > Visitor < ' de > for NamesVisitor {
463
+ type Value = Names ;
464
+
465
+ fn expecting ( & self , formatter : & mut fmt:: Formatter < ' _ > ) -> fmt:: Result {
466
+ formatter. write_str (
467
+ "a flat string array, a nested string array, or a single-entry object with a string array or null value" ,
468
+ )
469
+ }
470
+
471
+ /// Handle sequences:
472
+ /// - Flat string arrays: `["x", "y", "z"]`
473
+ /// - Nested string arrays: `[["motor_1", "motor_2"]]`
474
+ fn visit_seq < A > ( self , mut seq : A ) -> Result < Self :: Value , A :: Error >
434
475
where
435
- D : serde :: Deserializer < ' de > ,
476
+ A : SeqAccess < ' de > ,
436
477
{
437
- let value = serde_json:: Value :: deserialize ( deserializer) ?;
438
- if let serde_json:: Value :: Array ( arr) = value {
439
- if arr. is_empty ( ) {
440
- return Ok ( Self ( vec ! [ ] ) ) ;
441
- }
442
- if let Some ( first) = arr. first ( ) {
443
- if first. is_string ( ) {
444
- let flat: Vec < String > = serde_json:: from_value ( serde_json:: Value :: Array ( arr) )
445
- . map_err ( serde:: de:: Error :: custom) ?;
446
- return Ok ( Self ( flat) ) ;
447
- } else if first. is_array ( ) {
448
- let nested: Vec < Vec < String > > =
449
- serde_json:: from_value ( serde_json:: Value :: Array ( arr) )
450
- . map_err ( serde:: de:: Error :: custom) ?;
451
- let flat = nested. into_iter ( ) . flatten ( ) . collect ( ) ;
452
- return Ok ( Self ( flat) ) ;
478
+ // Helper enum to deserialize sequence elements
479
+ #[ derive( Deserialize ) ]
480
+ #[ serde( untagged) ]
481
+ enum ListItem {
482
+ Str ( String ) ,
483
+ List ( Vec < String > ) ,
484
+ }
485
+
486
+ /// Enum to track the list type
487
+ #[ derive( PartialEq ) ]
488
+ enum ListType {
489
+ Undetermined ,
490
+ Flat ,
491
+ Nested ,
492
+ }
493
+
494
+ let mut names = Vec :: new ( ) ;
495
+ let mut determined_type = ListType :: Undetermined ;
496
+
497
+ while let Some ( item) = seq. next_element :: < ListItem > ( ) ? {
498
+ match item {
499
+ ListItem :: Str ( s) => {
500
+ if determined_type == ListType :: Nested {
501
+ return Err ( serde:: de:: Error :: custom (
502
+ "Cannot mix nested lists with flat strings within names array" ,
503
+ ) ) ;
504
+ }
505
+ determined_type = ListType :: Flat ;
506
+ names. push ( s) ;
507
+ }
508
+ ListItem :: List ( list) => {
509
+ if determined_type == ListType :: Flat {
510
+ return Err ( serde:: de:: Error :: custom (
511
+ "Cannot mix flat strings and nested lists within names array" ,
512
+ ) ) ;
513
+ }
514
+ determined_type = ListType :: Nested ;
515
+
516
+ // Flatten the nested list
517
+ names. extend ( list) ;
453
518
}
454
519
}
455
520
}
456
- Err ( serde:: de:: Error :: custom (
457
- "Unsupported name format in LeRobot dataset!" ,
458
- ) )
521
+
522
+ Ok ( Names ( names) )
523
+ }
524
+
525
+ /// Handle single-entry objects: `{"motors": ["motor_1", "motor_2"]}` or `{"axes": null}`
526
+ fn visit_map < A > ( self , mut map : A ) -> Result < Self :: Value , A :: Error >
527
+ where
528
+ A : MapAccess < ' de > ,
529
+ {
530
+ let mut names_vec: Option < Vec < String > > = None ;
531
+ let mut entry_count = 0 ;
532
+
533
+ // We expect exactly one entry.
534
+ while let Some ( ( _key, value) ) = map. next_entry :: < String , Option < Vec < String > > > ( ) ? {
535
+ entry_count += 1 ;
536
+ if entry_count > 1 {
537
+ // Consume remaining entries to be a good citizen before erroring
538
+ while map
539
+ . next_entry :: < serde:: de:: IgnoredAny , serde:: de:: IgnoredAny > ( ) ?
540
+ . is_some ( )
541
+ { }
542
+
543
+ return Err ( serde:: de:: Error :: invalid_length (
544
+ entry_count,
545
+ & "a Names object with exactly one entry." ,
546
+ ) ) ;
547
+ }
548
+
549
+ names_vec = Some ( value. unwrap_or_default ( ) ) ;
550
+ }
551
+
552
+ Ok ( Names ( names_vec. unwrap_or_default ( ) ) )
553
+ }
554
+ }
555
+
556
+ impl < ' de > Deserialize < ' de > for Names {
557
+ fn deserialize < D > ( deserializer : D ) -> Result < Self , D :: Error >
558
+ where
559
+ D : Deserializer < ' de > ,
560
+ {
561
+ deserializer. deserialize_any ( NamesVisitor )
459
562
}
460
563
}
461
564
@@ -504,3 +607,96 @@ pub struct LeRobotDatasetTask {
504
607
pub index : TaskIndex ,
505
608
pub task : String ,
506
609
}
610
+
611
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the expected [`Names`] value from string literals.
    fn names(items: &[&str]) -> Names {
        Names(items.iter().map(|&item| item.to_owned()).collect())
    }

    /// Parses `json` as [`Names`], panicking if deserialization fails.
    fn parse(json: &str) -> Names {
        serde_json::from_str(json).expect("JSON should deserialize into `Names`")
    }

    /// Parses `json` as [`Names`] and returns the error message, panicking on success.
    fn parse_error(json: &str) -> String {
        serde_json::from_str::<Names>(json)
            .expect_err("JSON should be rejected")
            .to_string()
    }

    #[test]
    fn test_deserialize_flat_list() {
        assert_eq!(parse(r#"["a", "b", "c"]"#), names(&["a", "b", "c"]));
    }

    #[test]
    fn test_deserialize_nested_list() {
        assert_eq!(parse(r#"[["a", "b"], ["c"]]"#), names(&["a", "b", "c"]));
    }

    #[test]
    fn test_deserialize_empty_nested_list() {
        assert_eq!(parse(r#"[[], []]"#), names(&[]));
    }

    #[test]
    fn test_deserialize_empty_list() {
        assert_eq!(parse(r#"[]"#), names(&[]));
    }

    #[test]
    fn test_deserialize_object_with_list() {
        assert_eq!(parse(r#"{ "axes": ["x", "y", "z"] }"#), names(&["x", "y", "z"]));
    }

    #[test]
    fn test_deserialize_object_with_empty_list() {
        assert_eq!(parse(r#"{ "motors": [] }"#), names(&[]));
    }

    #[test]
    fn test_deserialize_object_with_null() {
        // Null results in an empty list
        assert_eq!(parse(r#"{ "axes": null }"#), names(&[]));
    }

    #[test]
    fn test_deserialize_empty_object() {
        // Empty object results in empty list.
        assert_eq!(parse(r#"{}"#), names(&[]));
    }

    #[test]
    fn test_deserialize_error_mixed_list() {
        // Mixed flat and nested
        let message = parse_error(r#"["a", ["b"]]"#);
        assert!(message.contains("Cannot mix flat strings and nested lists"));
    }

    #[test]
    fn test_deserialize_error_mixed_list_nested_first() {
        // Mixed nested and flat, with the nested list coming first
        let message = parse_error(r#"[["a"], "b"]"#);
        assert!(message.contains("Cannot mix nested lists with flat strings"));
    }

    #[test]
    fn test_deserialize_error_object_multiple_entries() {
        let message = parse_error(r#"{ "axes": ["x"], "motors": ["m"] }"#);
        assert!(message.contains("a Names object with exactly one entry"));
    }
}
0 commit comments