Skip to content

Commit 64b0d6a

Browse files
authored
LeRobot: Add support for List datatype (#9958)
Used by Phospho robots, for instance
1 parent 2d1b386 commit 64b0d6a

File tree

1 file changed

+32
-3
lines changed

1 file changed

+32
-3
lines changed

crates/store/re_data_loader/src/loader_lerobot.rs

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -509,7 +509,23 @@ fn load_scalar(
509509
make_scalar_batch_entity_chunks(entity_path, feature, timelines, fixed_size_array)?;
510510
Ok(ScalarChunkIterator::Batch(Box::new(batch_chunks)))
511511
}
512-
DataType::Float32 => {
512+
DataType::List(_field) => {
513+
let list_array = data
514+
.column_by_name(feature_key)
515+
.and_then(|col| col.downcast_array_ref::<arrow::array::ListArray>())
516+
.ok_or_else(|| {
517+
DataLoaderError::Other(anyhow!("Failed to downcast feature to ListArray"))
518+
})?;
519+
520+
let sliced = extract_list_array_elements_as_f64(list_array).with_context(|| {
521+
format!("Failed to cast scalar feature {entity_path} to Float64")
522+
})?;
523+
524+
Ok(ScalarChunkIterator::Single(std::iter::once(
525+
make_scalar_entity_chunk(entity_path, timelines, &sliced)?,
526+
)))
527+
}
528+
DataType::Float32 | DataType::Float64 => {
513529
let feature_data = data.column_by_name(feature_key).ok_or_else(|| {
514530
DataLoaderError::Other(anyhow!(
515531
"Failed to get LeRobot dataset column data for: {:?}",
@@ -546,7 +562,7 @@ fn make_scalar_batch_entity_chunks(
546562

547563
let mut chunks = Vec::with_capacity(num_elements);
548564

549-
let sliced = extract_list_elements_as_f64(data)
565+
let sliced = extract_fixed_size_list_array_elements_as_f64(data)
550566
.with_context(|| format!("Failed to cast scalar feature {entity_path} to Float64"))?;
551567

552568
chunks.push(make_scalar_entity_chunk(
@@ -612,7 +628,20 @@ fn extract_scalar_slices_as_f64(data: &ArrayRef) -> anyhow::Result<Vec<ArrayRef>
612628
.collect::<Vec<_>>())
613629
}
614630

615-
fn extract_list_elements_as_f64(data: &FixedSizeListArray) -> anyhow::Result<Vec<ArrayRef>> {
631+
fn extract_fixed_size_list_array_elements_as_f64(
632+
data: &FixedSizeListArray,
633+
) -> anyhow::Result<Vec<ArrayRef>> {
634+
(0..data.len())
635+
.map(|idx| {
636+
cast(&data.value(idx), &DataType::Float64)
637+
.with_context(|| format!("Failed to cast {:?} to Float64", data.data_type()))
638+
})
639+
.collect::<Result<Vec<_>, _>>()
640+
}
641+
642+
fn extract_list_array_elements_as_f64(
643+
data: &arrow::array::ListArray,
644+
) -> anyhow::Result<Vec<ArrayRef>> {
616645
(0..data.len())
617646
.map(|idx| {
618647
cast(&data.value(idx), &DataType::Float64)

0 commit comments

Comments
 (0)