Skip to content

Commit 2c2e4e1

Browse files
Cadene and thomwolf authored
Add aloha_dora_format.py (#201)
Co-authored-by: Thomas Wolf <[email protected]>
1 parent 1331068 commit 2c2e4e1

File tree

2 files changed

+233
-1
lines changed

2 files changed

+233
-1
lines changed
Lines changed: 228 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,228 @@
1+
#!/usr/bin/env python
2+
3+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
"""
17+
Contains utilities to process the raw data format produced by dora-record.
18+
"""
19+
20+
import logging
21+
import re
22+
from pathlib import Path
23+
24+
import pandas as pd
25+
import torch
26+
from datasets import Dataset, Features, Image, Sequence, Value
27+
28+
from lerobot.common.datasets.utils import (
29+
hf_transform_to_torch,
30+
)
31+
from lerobot.common.datasets.video_utils import VideoFrame
32+
from lerobot.common.utils.utils import init_logging
33+
34+
35+
def check_format(raw_dir) -> bool:
    """Check that `raw_dir` exists and directly contains at least one parquet file.

    Args:
        raw_dir: Directory holding the raw recording (a `pathlib.Path`-like
            object providing `exists()` and `glob()`).

    Returns:
        `True` when both checks pass.

    Raises:
        AssertionError: If `raw_dir` does not exist.
        ValueError: If no `*.parquet` file is found in `raw_dir`.
    """
    assert raw_dir.exists()

    # One match is enough to accept the directory.
    found_parquet = any(raw_dir.glob("*.parquet"))
    if not found_parquet:
        raise ValueError(f"Missing parquet files in '{raw_dir}'")
    return True
42+
43+
44+
def load_from_raw(raw_dir: Path, out_dir: Path, fps: int):
    """Load a dora-record dump and synchronize all streams on one camera's timestamps.

    The function:
      1. picks the first `observation.images.cam_*.parquet` file (alphanumeric
         order) as the reference stream;
      2. merges every other `*.parquet` stream of `raw_dir` onto it with
         `pd.merge_asof`, matching the nearest `timestamp_utc` within
         `1 / fps` seconds (streams whose name contains `failed_episode_index`
         are ignored);
      3. drops the rows whose `episode_index` is -1 (data recorded in-between
         episodes);
      4. recomputes `episode_index` from the video file names and derives
         `frame_index`, `index`, `next.done` and a per-episode `timestamp`;
      5. symlinks `out_dir / "videos"` to `raw_dir / "videos"` and checks that
         the expected video files exist;
      6. converts each column to a list of video-frame dicts or a torch tensor.

    Args:
        raw_dir: Directory containing the parquet streams and a `videos` folder.
        out_dir: Output directory; created if needed, receives the `videos` symlink.
        fps: Frame rate; only used to derive the merge tolerance (`1 / fps` seconds).

    Returns:
        A tuple `(data_dict, episode_data_index)` where `data_dict` maps each
        column name to its converted values and `episode_data_index` holds the
        `"from"` / `"to"` row boundaries of each episode.

    Raises:
        ValueError: If no reference camera file is found, if the merged data
            contains NaN values, if a video path does not match the expected
            `_<6 digits>.mp4` pattern, or if a column has an unsupported shape.
        AssertionError: If cameras disagree on the episode of a row, if episode
            indices are not `0..n-1`, or if an expected video file is missing.
        FileExistsError: If `out_dir / "videos"` already exists and is not a
            symlink to `raw_dir / "videos"`.
    """
    # Load the data stream that is used as reference for the timestamps synchronization.
    reference_files = sorted(raw_dir.glob("observation.images.cam_*.parquet"))
    if not reference_files:
        raise ValueError(
            f"Missing reference files for camera, starting with 'observation.images.cam_' in '{raw_dir}'"
        )
    # Select the first camera in alphanumeric order.
    reference_path = reference_files[0]
    reference_key = reference_path.stem
    reference_df = pd.read_parquet(reference_path)
    reference_df = reference_df[["timestamp_utc", reference_key]]

    # Merge all data streams onto the reference one using the "nearest" strategy.
    # Paths are sorted so that the resulting column order is deterministic.
    df = reference_df
    for path in sorted(raw_dir.glob("*.parquet")):
        key = path.stem  # action or observation.state or ...
        if key == reference_key:
            continue
        if "failed_episode_index" in key:
            # TODO(rcadene): add support for removing episodes that are tagged as "failed"
            continue
        modality_df = pd.read_parquet(path)
        modality_df = modality_df[["timestamp_utc", key]]
        df = pd.merge_asof(
            df,
            modality_df,
            on="timestamp_utc",
            # "nearest" is the best option over "backward", since the latter can desynchronize camera
            # timestamps by matching timestamps that are too far apart, in order to fit the backward
            # constraint. It's not the case for "nearest".
            # However, note that "nearest" might synchronize the reference camera with other cameras on
            # slightly future timestamps. This is not a problem when the tolerance is set low enough to
            # avoid matching timestamps that are too far apart.
            direction="nearest",
            tolerance=pd.Timedelta(f"{1/fps} seconds"),
        )

    # Remove rows with episode_index -1 which indicates data that correspond to in-between episodes
    df = df[df["episode_index"] != -1]

    image_keys = [key for key in df if "observation.images." in key]

    # A camera stream without any frame inside the tolerance leaves NaN in its column. Fail here with
    # an explicit message instead of an opaque TypeError when the frame dict is indexed below.
    if df[image_keys].isna().any().any():
        raise ValueError("Dataset contains Nan values.")

    # Video file names end with `_<episode index on 6 digits>.mp4`.
    episode_pattern = re.compile(r"_(\d{6})\.mp4")

    def get_episode_index(row):
        """Return the episode index shared by all the camera video paths of `row`."""
        episode_index_per_cam = {}
        for key in image_keys:
            # `[0]` because each cell wraps the video frame dict into a one-element sequence
            path = row[key][0]["path"]
            match = episode_pattern.search(path)
            if not match:
                raise ValueError(path)
            episode_index_per_cam[key] = int(match.group(1))
        assert (
            len(set(episode_index_per_cam.values())) == 1
        ), f"All cameras are expected to belong to the same episode, but getting {episode_index_per_cam}"
        # All values are identical at this point, so any camera can be used.
        return episode_index_per_cam[image_keys[0]]

    df["episode_index"] = df.apply(get_episode_index, axis=1)

    # frame_index counts the rows from 0 inside each episode
    df["frame_index"] = df.groupby("episode_index").cumcount()
    # index is the global row position once the in-between rows have been dropped
    df = df.reset_index()
    df["index"] = df.index

    # set 'next.done' to True for the last frame of each episode
    df["next.done"] = False
    df.loc[df.groupby("episode_index").tail(1).index, "next.done"] = True

    df["timestamp"] = df["timestamp_utc"].map(lambda x: x.timestamp())
    # each episode starts with timestamp 0 to match the ones from the video
    df["timestamp"] = df.groupby("episode_index")["timestamp"].transform(lambda x: x - x.iloc[0])

    del df["timestamp_utc"]

    # sanity check
    has_nan = df.isna().any().any()
    if has_nan:
        raise ValueError("Dataset contains Nan values.")

    # sanity check episode indices go from 0 to n-1
    ep_ids = [ep_idx for ep_idx, _ in df.groupby("episode_index")]
    expected_ep_ids = list(range(df["episode_index"].max() + 1))
    assert ep_ids == expected_ep_ids, f"Episodes indices go from {ep_ids} instead of {expected_ep_ids}"

    # Create symlink to raw videos directory (that needs to be absolute not relative)
    out_dir.mkdir(parents=True, exist_ok=True)
    videos_dir = out_dir / "videos"
    raw_videos_dir = (raw_dir / "videos").absolute()
    # Running the conversion again into the same `out_dir` must not fail: keep the link when it
    # already points to the same raw videos directory. Anything else found at that location still
    # makes `symlink_to` raise FileExistsError.
    already_linked = videos_dir.is_symlink() and videos_dir.resolve() == raw_videos_dir.resolve()
    if not already_linked:
        videos_dir.symlink_to(raw_videos_dir)

    # sanity check the video paths are well formatted
    for key in df:
        if "observation.images." not in key:
            continue
        for ep_idx in ep_ids:
            video_path = videos_dir / f"{key}_episode_{ep_idx:06d}.mp4"
            assert video_path.exists(), f"Video file not found in {video_path}"

    data_dict = {}
    for key in df:
        # is video frame
        if "observation.images." in key:
            # we need `[0]` because dora only uses arrays, so single values are encapsulated into a list.
            # it is the case for video_frame dictionary = [{"path": ..., "timestamp": ...}]
            data_dict[key] = [video_frame[0] for video_frame in df[key].values]

            # sanity check the video path is well formatted
            video_path = videos_dir.parent / data_dict[key][0]["path"]
            assert video_path.exists(), f"Video file not found in {video_path}"
        # is number
        # NOTE(review): when cells are 1-element arrays, `df[key].values` is presumably an object
        # array that `torch.from_numpy` may reject -- confirm against real recordings.
        elif df[key].iloc[0].ndim == 0 or df[key].iloc[0].shape[0] == 1:
            data_dict[key] = torch.from_numpy(df[key].values)
        # is vector
        elif df[key].iloc[0].shape[0] > 1:
            data_dict[key] = torch.stack([torch.from_numpy(x.copy()) for x in df[key].values])
        else:
            raise ValueError(key)

    # Get the first row index of each episode; an episode ends where the next one starts
    first_ep_index_df = df.groupby("episode_index").agg(start_index=("index", "first")).reset_index()
    from_ = first_ep_index_df["start_index"].tolist()
    to_ = from_[1:] + [len(df)]
    episode_data_index = {
        "from": from_,
        "to": to_,
    }

    return data_dict, episode_data_index
167+
168+
169+
def to_hf_dataset(data_dict, video) -> Dataset:
    """Build a Hugging Face `Dataset` from `data_dict` with an explicit feature schema.

    Args:
        data_dict: Mapping from column name to values. `"observation.state"` and
            `"action"` are required and must expose a 2-D `.shape`;
            `"observation.velocity"` and `"observation.effort"` are optional.
        video: When truthy, camera columns (names containing
            `"observation.images."`) are declared as `VideoFrame`, otherwise as `Image`.

    Returns:
        The dataset, with `hf_transform_to_torch` registered through `set_transform`.
    """
    features = {}

    # Camera columns
    for key in data_dict:
        if "observation.images." in key:
            features[key] = VideoFrame() if video else Image()

    # Vector columns as (name, required). A missing required column raises
    # KeyError; a missing optional one is simply skipped.
    vector_columns = (
        ("observation.state", True),
        ("observation.velocity", False),
        ("observation.effort", False),
        ("action", True),
    )
    for key, required in vector_columns:
        if required or key in data_dict:
            features[key] = Sequence(
                length=data_dict[key].shape[1], feature=Value(dtype="float32", id=None)
            )

    # Scalar bookkeeping columns
    features.update(
        {
            "episode_index": Value(dtype="int64", id=None),
            "frame_index": Value(dtype="int64", id=None),
            "timestamp": Value(dtype="float32", id=None),
            "next.done": Value(dtype="bool", id=None),
            "index": Value(dtype="int64", id=None),
        }
    )

    hf_dataset = Dataset.from_dict(data_dict, features=Features(features))
    hf_dataset.set_transform(hf_transform_to_torch)
    return hf_dataset
202+
203+
204+
def from_raw_to_lerobot_format(raw_dir: Path, out_dir: Path, fps=None, video=True, debug=False):
    """Convert a raw dora-record directory into the LeRobot dataset format.

    Args:
        raw_dir: Directory with the raw parquet streams; validated with `check_format`.
        out_dir: Output directory handed to `load_from_raw`.
        fps: Must be left to `None`, in which case 30 is used; any other value
            raises `NotImplementedError`.
        video: Must be truthy; a falsy value raises `NotImplementedError`.
        debug: Not implemented; when truthy a warning is logged and the flag is ignored.

    Returns:
        A tuple `(hf_dataset, episode_data_index, info)` where `info` holds the
        `"fps"` and `"video"` settings.
    """
    init_logging()

    if debug:
        logging.warning("debug=True not implemented. Falling back to debug=False.")

    # sanity check
    check_format(raw_dir)

    # Only the default frame rate is supported for now.
    if fps is not None:
        raise NotImplementedError()
    fps = 30

    if not video:
        raise NotImplementedError()

    data_dict, episode_data_index = load_from_raw(raw_dir, out_dir, fps)
    hf_dataset = to_hf_dataset(data_dict, video)

    info = {"fps": fps, "video": video}
    return hf_dataset, episode_data_index, info

lerobot/scripts/push_dataset_to_hub.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,10 +84,14 @@ def get_from_raw_to_lerobot_format_fn(raw_format):
8484
from lerobot.common.datasets.push_dataset_to_hub.umi_zarr_format import from_raw_to_lerobot_format
8585
elif raw_format == "aloha_hdf5":
8686
from lerobot.common.datasets.push_dataset_to_hub.aloha_hdf5_format import from_raw_to_lerobot_format
87+
elif raw_format == "aloha_dora":
88+
from lerobot.common.datasets.push_dataset_to_hub.aloha_dora_format import from_raw_to_lerobot_format
8789
elif raw_format == "xarm_pkl":
8890
from lerobot.common.datasets.push_dataset_to_hub.xarm_pkl_format import from_raw_to_lerobot_format
8991
else:
90-
raise ValueError(raw_format)
92+
raise ValueError(
93+
f"The selected {raw_format} can't be found. Did you add it to `lerobot/scripts/push_dataset_to_hub.py::get_from_raw_to_lerobot_format_fn`?"
94+
)
9195

9296
return from_raw_to_lerobot_format
9397

0 commit comments

Comments
 (0)