Skip to content

Commit 25a8597

Browse files
authored
[viz] Fixes & updates to html visualizer (#617)
1 parent b8b3683 commit 25a8597

File tree

2 files changed

+56
-46
lines changed

2 files changed

+56
-46
lines changed

lerobot/scripts/visualize_dataset_html.py

Lines changed: 26 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -232,69 +232,54 @@ def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index)
232232
"""Get a csv str containing timeseries data of an episode (e.g. state and action).
233233
This file will be loaded by Dygraph javascript to plot data in real time."""
234234
columns = []
235-
has_state = "observation.state" in dataset.features
236-
has_action = "action" in dataset.features
235+
236+
selected_columns = [col for col, ft in dataset.features.items() if ft["dtype"] == "float32"]
237+
selected_columns.remove("timestamp")
237238

238239
# init header of csv with state and action names
239240
header = ["timestamp"]
240-
if has_state:
241+
242+
for column_name in selected_columns:
241243
dim_state = (
242-
dataset.meta.shapes["observation.state"][0]
243-
if isinstance(dataset, LeRobotDataset)
244-
else dataset.features["observation.state"].shape[0]
245-
)
246-
header += [f"state_{i}" for i in range(dim_state)]
247-
column_names = dataset.features["observation.state"]["names"]
248-
while not isinstance(column_names, list):
249-
column_names = list(column_names.values())[0]
250-
columns.append({"key": "state", "value": column_names})
251-
if has_action:
252-
dim_action = (
253-
dataset.meta.shapes["action"][0]
244+
dataset.meta.shapes[column_name][0]
254245
if isinstance(dataset, LeRobotDataset)
255-
else dataset.features.action.shape[0]
246+
else dataset.features[column_name].shape[0]
256247
)
257-
header += [f"action_{i}" for i in range(dim_action)]
258-
column_names = dataset.features["action"]["names"]
259-
while not isinstance(column_names, list):
260-
column_names = list(column_names.values())[0]
261-
columns.append({"key": "action", "value": column_names})
248+
header += [f"{column_name}_{i}" for i in range(dim_state)]
249+
250+
if "names" in dataset.features[column_name] and dataset.features[column_name]["names"]:
251+
column_names = dataset.features[column_name]["names"]
252+
while not isinstance(column_names, list):
253+
column_names = list(column_names.values())[0]
254+
else:
255+
column_names = [f"motor_{i}" for i in range(dim_state)]
256+
columns.append({"key": column_name, "value": column_names})
257+
258+
selected_columns.insert(0, "timestamp")
262259

263260
if isinstance(dataset, LeRobotDataset):
264261
from_idx = dataset.episode_data_index["from"][episode_index]
265262
to_idx = dataset.episode_data_index["to"][episode_index]
266-
selected_columns = ["timestamp"]
267-
if has_state:
268-
selected_columns += ["observation.state"]
269-
if has_action:
270-
selected_columns += ["action"]
271263
data = (
272264
dataset.hf_dataset.select(range(from_idx, to_idx))
273265
.select_columns(selected_columns)
274-
.with_format("numpy")
266+
.with_format("pandas")
275267
)
276-
rows = np.hstack(
277-
(np.expand_dims(data["timestamp"], axis=1), *[data[col] for col in selected_columns[1:]])
278-
).tolist()
279268
else:
280269
repo_id = dataset.repo_id
281-
selected_columns = ["timestamp"]
282-
if "observation.state" in dataset.features:
283-
selected_columns.append("observation.state")
284-
if "action" in dataset.features:
285-
selected_columns.append("action")
286270

287271
url = f"https://huggingface.co/datasets/{repo_id}/resolve/main/" + dataset.data_path.format(
288272
episode_chunk=int(episode_index) // dataset.chunks_size, episode_index=episode_index
289273
)
290274
df = pd.read_parquet(url)
291275
data = df[selected_columns] # Select specific columns
292-
rows = np.hstack(
293-
(
294-
np.expand_dims(data["timestamp"], axis=1),
295-
*[np.vstack(data[col]) for col in selected_columns[1:]],
296-
)
297-
).tolist()
276+
277+
rows = np.hstack(
278+
(
279+
np.expand_dims(data["timestamp"], axis=1),
280+
*[np.vstack(data[col]) for col in selected_columns[1:]],
281+
)
282+
).tolist()
298283

299284
# Convert data to CSV string
300285
csv_buffer = StringIO()
@@ -379,10 +364,6 @@ def visualize_dataset_html(
379364
template_folder=template_dir,
380365
)
381366
else:
382-
image_keys = dataset.meta.image_keys if isinstance(dataset, LeRobotDataset) else []
383-
if len(image_keys) > 0:
384-
raise NotImplementedError(f"Image keys ({image_keys=}) are currently not supported.")
385-
386367
# Create a simlink from the dataset video folder containg mp4 files to the output directory
387368
# so that the http server can get access to the mp4 files.
388369
if isinstance(dataset, LeRobotDataset):

lerobot/templates/visualize_dataset_template.html

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,9 +98,34 @@ <h1 class="text-xl font-bold mt-4 font-mono">
9898
</div>
9999

100100
<!-- Videos -->
101+
<div class="max-w-32 relative text-sm mb-4 select-none"
102+
@click.outside="isVideosDropdownOpen = false">
103+
<div
104+
@click="isVideosDropdownOpen = !isVideosDropdownOpen"
105+
class="p-2 border border-slate-500 rounded flex justify-between items-center cursor-pointer"
106+
>
107+
<span class="truncate">filter videos</span>
108+
<div class="transition-transform" :class="{ 'rotate-180': isVideosDropdownOpen }">🔽</div>
109+
</div>
110+
111+
<div x-show="isVideosDropdownOpen"
112+
class="absolute mt-1 border border-slate-500 rounded shadow-lg z-10">
113+
<div>
114+
<template x-for="option in videosKeys" :key="option">
115+
<div
116+
@click="videosKeysSelected = videosKeysSelected.includes(option) ? videosKeysSelected.filter(v => v !== option) : [...videosKeysSelected, option]"
117+
class="p-2 cursor-pointer bg-slate-900"
118+
:class="{ 'bg-slate-700': videosKeysSelected.includes(option) }"
119+
x-text="option"
120+
></div>
121+
</template>
122+
</div>
123+
</div>
124+
</div>
125+
101126
<div class="flex flex-wrap gap-x-2 gap-y-6">
102127
{% for video_info in videos_info %}
103-
<div x-show="!videoCodecError" class="max-w-96 relative">
128+
<div x-show="!videoCodecError && videosKeysSelected.includes('{{ video_info.filename }}')" class="max-w-96 relative">
104129
<p class="absolute inset-x-0 -top-4 text-sm text-gray-300 bg-gray-800 px-2 rounded-t-xl truncate">{{ video_info.filename }}</p>
105130
<video muted loop type="video/mp4" class="object-contain w-full h-full" @canplaythrough="videoCanPlay" @timeupdate="() => {
106131
if (video.duration) {
@@ -250,6 +275,9 @@ <h1 class="text-xl font-bold mt-4 font-mono">
250275
nVideos: {{ videos_info | length }},
251276
nVideoReadyToPlay: 0,
252277
videoCodecError: false,
278+
isVideosDropdownOpen: false,
279+
videosKeys: {{ videos_info | map(attribute='filename') | list | tojson }},
280+
videosKeysSelected: [],
253281
columns: {{ columns | tojson }},
254282
rowLabels: {{ columns | tojson }}.reduce((colA, colB) => colA.value.length > colB.value.length ? colA : colB).value,
255283

@@ -261,6 +289,7 @@ <h1 class="text-xl font-bold mt-4 font-mono">
261289
if(!canPlayVideos){
262290
this.videoCodecError = true;
263291
}
292+
this.videosKeysSelected = this.videosKeys.map(opt => opt)
264293

265294
// process CSV data
266295
const csvDataStr = {{ episode_data_csv_str|tojson|safe }};

0 commit comments

Comments
 (0)