@@ -232,69 +232,54 @@ def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index)
232
232
"""Get a csv str containing timeseries data of an episode (e.g. state and action).
233
233
This file will be loaded by Dygraph javascript to plot data in real time."""
234
234
columns = []
235
- has_state = "observation.state" in dataset .features
236
- has_action = "action" in dataset .features
235
+
236
+ selected_columns = [col for col , ft in dataset .features .items () if ft ["dtype" ] == "float32" ]
237
+ selected_columns .remove ("timestamp" )
237
238
238
239
# init header of csv with "timestamp" followed by one name per dimension of each selected column
239
240
header = ["timestamp" ]
240
- if has_state :
241
+
242
+ for column_name in selected_columns :
241
243
dim_state = (
242
- dataset .meta .shapes ["observation.state" ][0 ]
243
- if isinstance (dataset , LeRobotDataset )
244
- else dataset .features ["observation.state" ].shape [0 ]
245
- )
246
- header += [f"state_{ i } " for i in range (dim_state )]
247
- column_names = dataset .features ["observation.state" ]["names" ]
248
- while not isinstance (column_names , list ):
249
- column_names = list (column_names .values ())[0 ]
250
- columns .append ({"key" : "state" , "value" : column_names })
251
- if has_action :
252
- dim_action = (
253
- dataset .meta .shapes ["action" ][0 ]
244
+ dataset .meta .shapes [column_name ][0 ]
254
245
if isinstance (dataset , LeRobotDataset )
255
- else dataset .features . action .shape [0 ]
246
+ else dataset .features [ column_name ] .shape [0 ]
256
247
)
257
- header += [f"action_{ i } " for i in range (dim_action )]
258
- column_names = dataset .features ["action" ]["names" ]
259
- while not isinstance (column_names , list ):
260
- column_names = list (column_names .values ())[0 ]
261
- columns .append ({"key" : "action" , "value" : column_names })
248
+ header += [f"{ column_name } _{ i } " for i in range (dim_state )]
249
+
250
+ if "names" in dataset .features [column_name ] and dataset .features [column_name ]["names" ]:
251
+ column_names = dataset .features [column_name ]["names" ]
252
+ while not isinstance (column_names , list ):
253
+ column_names = list (column_names .values ())[0 ]
254
+ else :
255
+ column_names = [f"motor_{ i } " for i in range (dim_state )]
256
+ columns .append ({"key" : column_name , "value" : column_names })
257
+
258
+ selected_columns .insert (0 , "timestamp" )
262
259
263
260
if isinstance (dataset , LeRobotDataset ):
264
261
from_idx = dataset .episode_data_index ["from" ][episode_index ]
265
262
to_idx = dataset .episode_data_index ["to" ][episode_index ]
266
- selected_columns = ["timestamp" ]
267
- if has_state :
268
- selected_columns += ["observation.state" ]
269
- if has_action :
270
- selected_columns += ["action" ]
271
263
data = (
272
264
dataset .hf_dataset .select (range (from_idx , to_idx ))
273
265
.select_columns (selected_columns )
274
- .with_format ("numpy " )
266
+ .with_format ("pandas " )
275
267
)
276
- rows = np .hstack (
277
- (np .expand_dims (data ["timestamp" ], axis = 1 ), * [data [col ] for col in selected_columns [1 :]])
278
- ).tolist ()
279
268
else :
280
269
repo_id = dataset .repo_id
281
- selected_columns = ["timestamp" ]
282
- if "observation.state" in dataset .features :
283
- selected_columns .append ("observation.state" )
284
- if "action" in dataset .features :
285
- selected_columns .append ("action" )
286
270
287
271
url = f"https://huggingface.co/datasets/{ repo_id } /resolve/main/" + dataset .data_path .format (
288
272
episode_chunk = int (episode_index ) // dataset .chunks_size , episode_index = episode_index
289
273
)
290
274
df = pd .read_parquet (url )
291
275
data = df [selected_columns ] # Select specific columns
292
- rows = np .hstack (
293
- (
294
- np .expand_dims (data ["timestamp" ], axis = 1 ),
295
- * [np .vstack (data [col ]) for col in selected_columns [1 :]],
296
- )
297
- ).tolist ()
276
+
277
+ rows = np .hstack (
278
+ (
279
+ np .expand_dims (data ["timestamp" ], axis = 1 ),
280
+ * [np .vstack (data [col ]) for col in selected_columns [1 :]],
281
+ )
282
+ ).tolist ()
298
283
299
284
# Convert data to CSV string
300
285
csv_buffer = StringIO ()
@@ -379,10 +364,6 @@ def visualize_dataset_html(
379
364
template_folder = template_dir ,
380
365
)
381
366
else :
382
- image_keys = dataset .meta .image_keys if isinstance (dataset , LeRobotDataset ) else []
383
- if len (image_keys ) > 0 :
384
- raise NotImplementedError (f"Image keys ({ image_keys = } ) are currently not supported." )
385
-
386
367
# Create a symlink from the dataset video folder containing mp4 files to the output directory
387
368
# so that the http server can get access to the mp4 files.
388
369
if isinstance (dataset , LeRobotDataset ):
0 commit comments