Skip to content

Commit fe08458

Browse files
oxkitsune authored and jprochazk committed
Infer image/video channel index from LeRobot metadata (#9435)
### What
Most LeRobot datasets come with a name for each dimension of an `image` or `video` feature. This means we can attempt to infer the channel index from those names and do less guessing. Additionally, I rewrote the `Names` parsing to be a bit more general and to support even more possible ways of naming the dimensions of a feature.
1 parent f37a266 commit fe08458

File tree

2 files changed

+240
-48
lines changed

2 files changed

+240
-48
lines changed

crates/store/re_data_loader/src/lerobot.rs

+236-40
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,16 @@
1010
//! See [`LeRobotDataset`] for more information on the dataset format.
1111
1212
use std::borrow::Cow;
13+
use std::fmt;
1314
use std::fs::File;
1415
use std::io::BufReader;
1516
use std::path::{Path, PathBuf};
1617

1718
use ahash::HashMap;
1819
use arrow::array::RecordBatch;
1920
use parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder;
20-
use serde::de::DeserializeOwned;
21-
use serde::{Deserialize, Serialize};
21+
use serde::de::{DeserializeOwned, MapAccess, SeqAccess, Visitor};
22+
use serde::{Deserialize, Deserializer, Serialize};
2223

2324
/// Check whether the provided path contains a `LeRobot` dataset.
2425
pub fn is_lerobot_dataset(path: impl AsRef<Path>) -> bool {
@@ -388,6 +389,34 @@ pub struct Feature {
388389
pub names: Option<Names>,
389390
}
390391

392+
impl Feature {
393+
/// Get the channel dimension for this [`Feature`].
394+
///
395+
/// Returns the number of channels in the feature's data representation.
396+
///
397+
/// # Note
398+
///
399+
/// This is primarily intended for [`DType::Image`] and [`DType::Video`] features,
400+
/// where it represents color channels (e.g., 3 for RGB, 4 for RGBA).
401+
/// For other feature types, this function returns the size of the last dimension
402+
/// from the feature's shape.
403+
pub fn channel_dim(&self) -> usize {
404+
// first check if there's a "channels" name, if there is we can use that index.
405+
if let Some(names) = &self.names {
406+
if let Some(channel_idx) = names.0.iter().position(|name| name == "channels") {
407+
// If channel_idx is within bounds of shape, return that dimension
408+
if channel_idx < self.shape.len() {
409+
return self.shape[channel_idx];
410+
}
411+
}
412+
}
413+
414+
// Default to the last dimension if no channels name is found
415+
// or if the found index is out of bounds
416+
self.shape.last().copied().unwrap_or(0)
417+
}
418+
}
419+
391420
/// Data types supported for features in a `LeRobot` dataset.
392421
#[derive(Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq)]
393422
#[serde(rename_all = "snake_case")]
@@ -397,6 +426,7 @@ pub enum DType {
397426
Bool,
398427
Float32,
399428
Float64,
429+
Int16,
400430
Int64,
401431
String,
402432
}
@@ -406,56 +436,129 @@ pub enum DType {
406436
/// The name metadata can consist of
407437
/// - A flat list of names for each dimension of a feature (e.g., `["height", "width", "channel"]`).
408438
/// - A nested list of names for each dimension of a feature (e.g., `[["kLeftShoulderPitch", "kLeftShoulderRoll"]]`).
409-
/// - A list specific to motors (e.g., `{ "motors": ["motor_0", "motor_1", ...] }`).
410-
#[derive(Debug, Serialize, Deserialize, Clone)]
411-
#[serde(untagged)]
412-
pub enum Names {
413-
Motors { motors: Vec<String> },
414-
List(NamesList),
415-
}
439+
/// - A map with a string array value (e.g., `{ "motors": ["motor_0", "motor_1", ...] }` or `{ "axes": ["x", "y", "z"] }`).
440+
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
441+
pub struct Names(Vec<String>);
416442

417443
impl Names {
418-
/// Retrieves the name corresponding to a specific index within the `names` field of a feature.
444+
/// Retrieves the name corresponding to a specific index.
445+
///
446+
/// Returns `None` if the index is out of bounds.
419447
pub fn name_for_index(&self, index: usize) -> Option<&String> {
420-
match self {
421-
Self::Motors { motors } => motors.get(index),
422-
Self::List(NamesList(items)) => items.get(index),
423-
}
448+
self.0.get(index)
424449
}
425450
}
426451

427-
/// A wrapper struct that deserializes flat or nested lists of strings
428-
/// into a single flattened [`Vec`] of names for easy indexing.
429-
#[derive(Debug, Serialize, Clone)]
430-
pub struct NamesList(Vec<String>);
452+
/// Serde visitor that builds a [`Names`] value.
///
/// Accepted input shapes:
/// - a flat array of strings, e.g. `["x", "y", "z"]`
/// - an array of string arrays, e.g. `[["motor_1", "motor_2"]]`
/// - an object with a single entry whose value is a string array or `null`,
///   e.g. `{"motors": ["motor_1", "motor_2"]}` or `{"axes": null}`
///
/// The `Names` type documentation describes these formats in more detail.
struct NamesVisitor;
431461

432-
impl<'de> Deserialize<'de> for NamesList {
433-
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
462+
impl<'de> Visitor<'de> for NamesVisitor {
463+
type Value = Names;
464+
465+
fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
466+
formatter.write_str(
467+
"a flat string array, a nested string array, or a single-entry object with a string array or null value",
468+
)
469+
}
470+
471+
/// Handle sequences:
472+
/// - Flat string arrays: `["x", "y", "z"]`
473+
/// - Nested string arrays: `[["motor_1", "motor_2"]]`
474+
fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
434475
where
435-
D: serde::Deserializer<'de>,
476+
A: SeqAccess<'de>,
436477
{
437-
let value = serde_json::Value::deserialize(deserializer)?;
438-
if let serde_json::Value::Array(arr) = value {
439-
if arr.is_empty() {
440-
return Ok(Self(vec![]));
441-
}
442-
if let Some(first) = arr.first() {
443-
if first.is_string() {
444-
let flat: Vec<String> = serde_json::from_value(serde_json::Value::Array(arr))
445-
.map_err(serde::de::Error::custom)?;
446-
return Ok(Self(flat));
447-
} else if first.is_array() {
448-
let nested: Vec<Vec<String>> =
449-
serde_json::from_value(serde_json::Value::Array(arr))
450-
.map_err(serde::de::Error::custom)?;
451-
let flat = nested.into_iter().flatten().collect();
452-
return Ok(Self(flat));
478+
// Helper enum to deserialize sequence elements
479+
#[derive(Deserialize)]
480+
#[serde(untagged)]
481+
enum ListItem {
482+
Str(String),
483+
List(Vec<String>),
484+
}
485+
486+
/// Enum to track the list type
487+
#[derive(PartialEq)]
488+
enum ListType {
489+
Undetermined,
490+
Flat,
491+
Nested,
492+
}
493+
494+
let mut names = Vec::new();
495+
let mut determined_type = ListType::Undetermined;
496+
497+
while let Some(item) = seq.next_element::<ListItem>()? {
498+
match item {
499+
ListItem::Str(s) => {
500+
if determined_type == ListType::Nested {
501+
return Err(serde::de::Error::custom(
502+
"Cannot mix nested lists with flat strings within names array",
503+
));
504+
}
505+
determined_type = ListType::Flat;
506+
names.push(s);
507+
}
508+
ListItem::List(list) => {
509+
if determined_type == ListType::Flat {
510+
return Err(serde::de::Error::custom(
511+
"Cannot mix flat strings and nested lists within names array",
512+
));
513+
}
514+
determined_type = ListType::Nested;
515+
516+
// Flatten the nested list
517+
names.extend(list);
453518
}
454519
}
455520
}
456-
Err(serde::de::Error::custom(
457-
"Unsupported name format in LeRobot dataset!",
458-
))
521+
522+
Ok(Names(names))
523+
}
524+
525+
/// Handle single-entry objects: `{"motors": ["motor_1", "motor_2"]}` or `{"axes": null}`
526+
fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
527+
where
528+
A: MapAccess<'de>,
529+
{
530+
let mut names_vec: Option<Vec<String>> = None;
531+
let mut entry_count = 0;
532+
533+
// We expect exactly one entry.
534+
while let Some((_key, value)) = map.next_entry::<String, Option<Vec<String>>>()? {
535+
entry_count += 1;
536+
if entry_count > 1 {
537+
// Consume remaining entries to be a good citizen before erroring
538+
while map
539+
.next_entry::<serde::de::IgnoredAny, serde::de::IgnoredAny>()?
540+
.is_some()
541+
{}
542+
543+
return Err(serde::de::Error::invalid_length(
544+
entry_count,
545+
&"a Names object with exactly one entry.",
546+
));
547+
}
548+
549+
names_vec = Some(value.unwrap_or_default());
550+
}
551+
552+
Ok(Names(names_vec.unwrap_or_default()))
553+
}
554+
}
555+
556+
impl<'de> Deserialize<'de> for Names {
557+
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
558+
where
559+
D: Deserializer<'de>,
560+
{
561+
deserializer.deserialize_any(NamesVisitor)
459562
}
460563
}
461564

@@ -504,3 +607,96 @@ pub struct LeRobotDatasetTask {
504607
pub index: TaskIndex,
505608
pub task: String,
506609
}
610+
611+
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json;

    /// Builds the expected [`Names`] value from string literals.
    fn names_of(items: &[&str]) -> Names {
        Names(items.iter().map(|&item| item.to_owned()).collect())
    }

    /// Parses `json` as [`Names`], panicking if deserialization fails.
    fn parse(json: &str) -> Names {
        serde_json::from_str(json).unwrap()
    }

    /// Parses `json` as [`Names`] and returns the error message, panicking if it succeeds.
    fn parse_error(json: &str) -> String {
        let result: Result<Names, _> = serde_json::from_str(json);
        assert!(result.is_err());
        result.unwrap_err().to_string()
    }

    #[test]
    fn test_deserialize_flat_list() {
        assert_eq!(parse(r#"["a", "b", "c"]"#), names_of(&["a", "b", "c"]));
    }

    #[test]
    fn test_deserialize_nested_list() {
        // Nested lists are flattened in order.
        assert_eq!(parse(r#"[["a", "b"], ["c"]]"#), names_of(&["a", "b", "c"]));
    }

    #[test]
    fn test_deserialize_empty_nested_list() {
        assert_eq!(parse(r#"[[], []]"#), names_of(&[]));
    }

    #[test]
    fn test_deserialize_empty_list() {
        assert_eq!(parse(r#"[]"#), names_of(&[]));
    }

    #[test]
    fn test_deserialize_object_with_list() {
        assert_eq!(
            parse(r#"{ "axes": ["x", "y", "z"] }"#),
            names_of(&["x", "y", "z"])
        );
    }

    #[test]
    fn test_deserialize_object_with_empty_list() {
        assert_eq!(parse(r#"{ "motors": [] }"#), names_of(&[]));
    }

    #[test]
    fn test_deserialize_object_with_null() {
        // A null value is treated as an empty list.
        assert_eq!(parse(r#"{ "axes": null }"#), names_of(&[]));
    }

    #[test]
    fn test_deserialize_empty_object() {
        // An object without entries is treated as an empty list.
        assert_eq!(parse(r#"{}"#), names_of(&[]));
    }

    #[test]
    fn test_deserialize_error_mixed_list() {
        // A flat string followed by a nested list must be rejected.
        let message = parse_error(r#"["a", ["b"]]"#);
        assert!(message.contains("Cannot mix flat strings and nested lists"));
    }

    #[test]
    fn test_deserialize_error_object_multiple_entries() {
        let message = parse_error(r#"{ "axes": ["x"], "motors": ["m"] }"#);
        assert!(message.contains("a Names object with exactly one entry"));
    }
}

crates/store/re_data_loader/src/loader_lerobot.rs

+4-8
Original file line numberDiff line numberDiff line change
@@ -247,24 +247,20 @@ pub fn load_episode(
247247
}
248248

249249
DType::Image => {
250-
let num_channels = feature.shape.last().with_context(|| {
251-
format!(
252-
"Image feature '{feature_key}' in LeRobot dataset is missing channel dimension",
253-
)
254-
})?;
250+
let num_channels = feature.channel_dim();
255251

256-
match *num_channels {
252+
match num_channels {
257253
1 => chunks.extend(load_episode_depth_images(feature_key, &timeline, &data)?),
258254
3 => chunks.extend(load_episode_images(feature_key, &timeline, &data)?),
259-
_ => re_log::warn_once!("Unsupported channel count {num_channels} for LeRobot dataset; Only 1- and 3-channel images are supported")
255+
_ => re_log::warn_once!("Unsupported channel count {num_channels} (shape: {:?}) for LeRobot dataset; Only 1- and 3-channel images are supported", feature.shape)
260256
};
261257
}
262258
DType::Int64 if feature_key == "task_index" => {
263259
// special case int64 task_index columns
264260
// this always refers to the task description in the dataset metadata.
265261
chunks.extend(log_episode_task(dataset, &timeline, &data)?);
266262
}
267-
DType::Int64 | DType::Bool | DType::String => {
263+
DType::Int16 | DType::Int64 | DType::Bool | DType::String => {
268264
re_log::warn_once!(
269265
"Loading LeRobot feature ({feature_key}) of dtype `{:?}` into Rerun is not yet implemented",
270266
feature.dtype

0 commit comments

Comments
 (0)