@@ -509,7 +509,23 @@ fn load_scalar(
509
509
make_scalar_batch_entity_chunks ( entity_path, feature, timelines, fixed_size_array) ?;
510
510
Ok ( ScalarChunkIterator :: Batch ( Box :: new ( batch_chunks) ) )
511
511
}
512
- DataType :: Float32 => {
512
+ DataType :: List ( _field) => {
513
+ let list_array = data
514
+ . column_by_name ( feature_key)
515
+ . and_then ( |col| col. downcast_array_ref :: < arrow:: array:: ListArray > ( ) )
516
+ . ok_or_else ( || {
517
+ DataLoaderError :: Other ( anyhow ! ( "Failed to downcast feature to ListArray" ) )
518
+ } ) ?;
519
+
520
+ let sliced = extract_list_array_elements_as_f64 ( list_array) . with_context ( || {
521
+ format ! ( "Failed to cast scalar feature {entity_path} to Float64" )
522
+ } ) ?;
523
+
524
+ Ok ( ScalarChunkIterator :: Single ( std:: iter:: once (
525
+ make_scalar_entity_chunk ( entity_path, timelines, & sliced) ?,
526
+ ) ) )
527
+ }
528
+ DataType :: Float32 | DataType :: Float64 => {
513
529
let feature_data = data. column_by_name ( feature_key) . ok_or_else ( || {
514
530
DataLoaderError :: Other ( anyhow ! (
515
531
"Failed to get LeRobot dataset column data for: {:?}" ,
@@ -546,7 +562,7 @@ fn make_scalar_batch_entity_chunks(
546
562
547
563
let mut chunks = Vec :: with_capacity ( num_elements) ;
548
564
549
- let sliced = extract_list_elements_as_f64 ( data)
565
+ let sliced = extract_fixed_size_list_array_elements_as_f64 ( data)
550
566
. with_context ( || format ! ( "Failed to cast scalar feature {entity_path} to Float64" ) ) ?;
551
567
552
568
chunks. push ( make_scalar_entity_chunk (
@@ -612,7 +628,20 @@ fn extract_scalar_slices_as_f64(data: &ArrayRef) -> anyhow::Result<Vec<ArrayRef>
612
628
. collect :: < Vec < _ > > ( ) )
613
629
}
614
630
615
- fn extract_list_elements_as_f64 ( data : & FixedSizeListArray ) -> anyhow:: Result < Vec < ArrayRef > > {
631
+ fn extract_fixed_size_list_array_elements_as_f64 (
632
+ data : & FixedSizeListArray ,
633
+ ) -> anyhow:: Result < Vec < ArrayRef > > {
634
+ ( 0 ..data. len ( ) )
635
+ . map ( |idx| {
636
+ cast ( & data. value ( idx) , & DataType :: Float64 )
637
+ . with_context ( || format ! ( "Failed to cast {:?} to Float64" , data. data_type( ) ) )
638
+ } )
639
+ . collect :: < Result < Vec < _ > , _ > > ( )
640
+ }
641
+
642
+ fn extract_list_array_elements_as_f64 (
643
+ data : & arrow:: array:: ListArray ,
644
+ ) -> anyhow:: Result < Vec < ArrayRef > > {
616
645
( 0 ..data. len ( ) )
617
646
. map ( |idx| {
618
647
cast ( & data. value ( idx) , & DataType :: Float64 )
0 commit comments