|
| 1 | +#!/usr/bin/env python |
| 2 | + |
| 3 | +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. |
| 4 | +# |
| 5 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 6 | +# you may not use this file except in compliance with the License. |
| 7 | +# You may obtain a copy of the License at |
| 8 | +# |
| 9 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 10 | +# |
| 11 | +# Unless required by applicable law or agreed to in writing, software |
| 12 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 13 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 14 | +# See the License for the specific language governing permissions and |
| 15 | +# limitations under the License. |
| 16 | +""" |
| 17 | +Contains utilities to process raw data format from dora-record |
| 18 | +""" |
| 19 | + |
| 20 | +import logging |
| 21 | +import re |
| 22 | +from pathlib import Path |
| 23 | + |
| 24 | +import pandas as pd |
| 25 | +import torch |
| 26 | +from datasets import Dataset, Features, Image, Sequence, Value |
| 27 | + |
| 28 | +from lerobot.common.datasets.utils import ( |
| 29 | + hf_transform_to_torch, |
| 30 | +) |
| 31 | +from lerobot.common.datasets.video_utils import VideoFrame |
| 32 | +from lerobot.common.utils.utils import init_logging |
| 33 | + |
| 34 | + |
def check_format(raw_dir) -> bool:
    """Check that `raw_dir` looks like a dora-record recording.

    Args:
        raw_dir: Directory expected to hold the recorded `*.parquet` streams.

    Returns:
        Always `True` when the checks pass.

    Raises:
        AssertionError: If `raw_dir` does not exist.
        ValueError: If `raw_dir` contains no `*.parquet` file at its top level.
    """
    assert raw_dir.exists()

    # One match is enough: stop scanning the directory at the first parquet file found.
    has_parquet = any(True for _ in raw_dir.glob("*.parquet"))
    if not has_parquet:
        raise ValueError(f"Missing parquet files in '{raw_dir}'")
    return True
| 42 | + |
| 43 | + |
def load_from_raw(raw_dir: Path, out_dir: Path, fps: int):
    """Load a dora-record recording and synchronize every stream on a reference camera.

    Each `*.parquet` file in `raw_dir` is one data stream whose value column is named after the
    file stem and which carries a `timestamp_utc` column. The first camera stream (in
    alphanumeric order) is the time reference; every other stream is aligned on it with
    `pd.merge_asof` using the nearest sample within `1 / fps` seconds.

    Args:
        raw_dir: Directory holding the parquet streams and a `videos` sub-directory.
        out_dir: Output directory; it is created if needed and receives a `videos` symlink
            pointing to the raw videos.
        fps: Frame rate used to derive the synchronization tolerance.

    Returns:
        A tuple `(data_dict, episode_data_index)`. `data_dict` maps each column to either a list
        of video-frame dictionaries (camera streams) or a `torch.Tensor`. `episode_data_index`
        holds the `"from"` / `"to"` frame indices delimiting each episode.

    Raises:
        ValueError: If no reference camera stream is found, if no frame belongs to an episode,
            if the synchronized data contains NaN values, if a video path does not embed an
            episode number, or if a column cannot be converted to a tensor.
        AssertionError: If cameras disagree on the episode of a frame, if episode indices are
            not `0..n-1`, or if an expected video file is missing.
    """
    # Load data stream that will be used as reference for the timestamps synchronization
    reference_files = sorted(raw_dir.glob("observation.images.cam_*.parquet"))
    if not reference_files:
        raise ValueError(
            f"Missing reference camera files matching 'observation.images.cam_*.parquet' in '{raw_dir}'"
        )
    # select first camera in alphanumeric order
    reference_path = reference_files[0]
    reference_key = reference_path.stem
    reference_df = pd.read_parquet(reference_path)
    reference_df = reference_df[["timestamp_utc", reference_key]]

    # "nearest" is the best option over "backward", since the latter can desynchronize camera timestamps by
    # matching timestamps that are too far apart, in order to fit the backward constraints. It's not the case
    # for "nearest". However, note that "nearest" might synchronize the reference camera with other cameras on
    # slightly future timestamps. This is not a problem when the tolerance is set to be low enough to avoid
    # matching timestamps that are too far apart.
    tolerance = pd.Timedelta(f"{1/fps} seconds")

    # Merge all data streams on the reference timestamps using the nearest strategy.
    # Sorting the paths makes the resulting column order independent of the filesystem listing order.
    df = reference_df
    for path in sorted(raw_dir.glob("*.parquet")):
        key = path.stem  # action or observation.state or ...
        if key == reference_key:
            continue
        if "failed_episode_index" in key:
            # TODO(rcadene): add support for removing episodes that are tagged as "failed"
            continue
        modality_df = pd.read_parquet(path)
        modality_df = modality_df[["timestamp_utc", key]]
        df = pd.merge_asof(
            df,
            modality_df,
            on="timestamp_utc",
            direction="nearest",
            tolerance=tolerance,
        )

    # Remove rows with episode_index -1 which indicates data that correspond to in-between episodes.
    # Copy so that the column assignments below operate on an independent frame, not on a slice.
    df = df[df["episode_index"] != -1].copy()
    if df.empty:
        raise ValueError(f"No frame belonging to an episode was found in '{raw_dir}'")

    image_keys = [key for key in df if "observation.images." in key]

    # A camera frame without a match inside the tolerance is NaN; fail with a clear message here rather
    # than with an obscure subscript error when reading its video path below.
    if df[image_keys].isna().any().any():
        raise ValueError("Dataset contains Nan values.")

    episode_pattern = re.compile(r"_(\d{6})\.mp4")

    def get_episode_index(row):
        """Return the episode number embedded in the video paths of `row`."""
        episode_index_per_cam = {}
        for key in image_keys:
            # dora only uses arrays, so the video frame dictionary is encapsulated into a list
            path = row[key][0]["path"]
            match = episode_pattern.search(path)
            if not match:
                raise ValueError(path)
            episode_index = int(match.group(1))
            episode_index_per_cam[key] = episode_index
        assert (
            len(set(episode_index_per_cam.values())) == 1
        ), f"All cameras are expected to belong to the same episode, but getting {episode_index_per_cam}"
        return episode_index

    # The episode index recorded by dora is replaced by the one encoded in the video file names
    df["episode_index"] = df.apply(get_episode_index, axis=1)

    df["frame_index"] = df.groupby("episode_index").cumcount()
    df = df.reset_index()
    df["index"] = df.index

    # set 'next.done' to True for the last frame of each episode
    df["next.done"] = False
    df.loc[df.groupby("episode_index").tail(1).index, "next.done"] = True

    df["timestamp"] = df["timestamp_utc"].map(lambda x: x.timestamp())
    # each episode starts with timestamp 0 to match the ones from the video
    df["timestamp"] = df.groupby("episode_index")["timestamp"].transform(lambda x: x - x.iloc[0])

    del df["timestamp_utc"]

    # sanity check
    has_nan = df.isna().any().any()
    if has_nan:
        raise ValueError("Dataset contains Nan values.")

    # sanity check episode indices go from 0 to n-1
    ep_ids = [ep_idx for ep_idx, _ in df.groupby("episode_index")]
    expected_ep_ids = list(range(df["episode_index"].max() + 1))
    assert ep_ids == expected_ep_ids, f"Episodes indices go from {ep_ids} instead of {expected_ep_ids}"

    # Create symlink to raw videos directory (that needs to be absolute not relative)
    out_dir.mkdir(parents=True, exist_ok=True)
    videos_dir = out_dir / "videos"
    raw_videos_dir = (raw_dir / "videos").absolute()
    # A previous run may already have created the very same link; only that case is tolerated,
    # anything else already present at this location still raises as before.
    already_linked = videos_dir.is_symlink() and videos_dir.resolve() == raw_videos_dir.resolve()
    if not already_linked:
        videos_dir.symlink_to(raw_videos_dir)

    # sanity check the video paths are well formatted
    for key in image_keys:
        for ep_idx in ep_ids:
            video_path = videos_dir / f"{key}_episode_{ep_idx:06d}.mp4"
            assert video_path.exists(), f"Video file not found in {video_path}"

    data_dict = {}
    for key in df:
        # is video frame
        if "observation.images." in key:
            # we need `[0]` because dora only uses arrays, so single values are encapsulated into a list.
            # it is the case for video_frame dictionary = [{"path": ..., "timestamp": ...}]
            data_dict[key] = [video_frame[0] for video_frame in df[key].values]

            # sanity check the video path is well formatted
            video_path = videos_dir.parent / data_dict[key][0]["path"]
            assert video_path.exists(), f"Video file not found in {video_path}"
            continue

        first_value = df[key].iloc[0]
        # is number
        # NOTE(review): a column of length-1 arrays presumably yields an object array here, which
        # `torch.from_numpy` would reject — confirm whether such columns can occur.
        if first_value.ndim == 0 or first_value.shape[0] == 1:
            data_dict[key] = torch.from_numpy(df[key].values)
        # is vector
        elif first_value.shape[0] > 1:
            data_dict[key] = torch.stack([torch.from_numpy(x.copy()) for x in df[key].values])
        else:
            raise ValueError(key)

    # Get the index of the first frame of each episode; an episode ends where the next one starts
    first_ep_index_df = df.groupby("episode_index").agg(start_index=("index", "first")).reset_index()
    from_ = first_ep_index_df["start_index"].tolist()
    to_ = from_[1:] + [len(df)]
    episode_data_index = {
        "from": from_,
        "to": to_,
    }

    return data_dict, episode_data_index
| 167 | + |
| 168 | + |
def to_hf_dataset(data_dict, video) -> Dataset:
    """Build a Hugging Face `Dataset` from `data_dict` with an explicit feature schema.

    Args:
        data_dict: Mapping from column name to column data. It must contain
            `"observation.state"` and `"action"`, whose `.shape[1]` gives the vector length;
            `"observation.velocity"` and `"observation.effort"` are declared only when present.
        video: When truthy, camera columns are declared as `VideoFrame`, otherwise as `Image`.

    Returns:
        The dataset, with `hf_transform_to_torch` registered as its transform.

    Raises:
        KeyError: If `"observation.state"` or `"action"` is missing from `data_dict`.
    """
    schema = {}

    # Camera columns come first, in the order they appear in `data_dict`.
    for name in data_dict:
        if "observation.images." in name:
            schema[name] = VideoFrame() if video else Image()

    # Fixed-length float vectors; the flag tells whether the column is mandatory.
    vector_columns = (
        ("observation.state", True),
        ("observation.velocity", False),
        ("observation.effort", False),
        ("action", True),
    )
    for name, required in vector_columns:
        if required or name in data_dict:
            schema[name] = Sequence(
                length=data_dict[name].shape[1], feature=Value(dtype="float32", id=None)
            )

    # Per-frame scalar columns.
    scalar_dtypes = {
        "episode_index": "int64",
        "frame_index": "int64",
        "timestamp": "float32",
        "next.done": "bool",
        "index": "int64",
    }
    for name, dtype in scalar_dtypes.items():
        schema[name] = Value(dtype=dtype, id=None)

    dataset = Dataset.from_dict(data_dict, features=Features(schema))
    dataset.set_transform(hf_transform_to_torch)
    return dataset
| 202 | + |
| 203 | + |
def from_raw_to_lerobot_format(raw_dir: Path, out_dir: Path, fps=None, video=True, debug=False):
    """Convert a dora-record recording into a Hugging Face dataset plus its episode index.

    Args:
        raw_dir: Directory holding the raw recording; validated with `check_format`.
        out_dir: Output directory handed to `load_from_raw`.
        fps: Frame rate of the recording. `None` selects the default of 30; any value other
            than 30 is not supported yet.
        video: Must be truthy; image-based datasets are not supported yet.
        debug: Not implemented; a warning is logged and the full conversion still runs.

    Returns:
        A tuple `(hf_dataset, episode_data_index, info)` where `info` records the `"fps"` and
        `"video"` settings that were used.

    Raises:
        NotImplementedError: If `fps` differs from 30 or if `video` is falsy.
    """
    init_logging()

    if debug:
        logging.warning("debug=True not implemented. Falling back to debug=False.")

    # sanity check
    check_format(raw_dir)

    default_fps = 30
    if fps is None:
        fps = default_fps
    elif fps != default_fps:
        # Explicitly passing the default value is accepted; only other frame rates are rejected.
        raise NotImplementedError(f"Only fps={default_fps} is supported for now, but got fps={fps}.")

    if not video:
        raise NotImplementedError("Only video=True is supported for now.")

    data_dict, episode_data_index = load_from_raw(raw_dir, out_dir, fps)
    hf_dataset = to_hf_dataset(data_dict, video)

    info = {
        "fps": fps,
        "video": video,
    }
    return hf_dataset, episode_data_index, info
0 commit comments