Skip to content

Commit 9d6886d

Browse files
Cadenealiberts
andauthored
Add frame level task (#693)
Co-authored-by: Simon Alibert <[email protected]>
1 parent d67ca34 commit 9d6886d

File tree

6 files changed

+104
-49
lines changed

6 files changed

+104
-49
lines changed

examples/port_datasets/pusht_zarr.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,7 @@ def main(raw_dir: Path, repo_id: str, mode: str = "video", push_to_hub: bool = T
180180
# Shift reward and success by +1 until the last item of the episode
181181
"next.reward": reward[i + (frame_idx < num_frames - 1)],
182182
"next.success": success[i + (frame_idx < num_frames - 1)],
183+
"task": PUSHT_TASK,
183184
}
184185

185186
frame["observation.state"] = torch.from_numpy(agent_pos[i])
@@ -191,7 +192,7 @@ def main(raw_dir: Path, repo_id: str, mode: str = "video", push_to_hub: bool = T
191192

192193
dataset.add_frame(frame)
193194

194-
dataset.save_episode(task=PUSHT_TASK)
195+
dataset.save_episode()
195196

196197
dataset.consolidate()
197198

lerobot/common/datasets/lerobot_dataset.py

Lines changed: 71 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ def __init__(
8787
self.pull_from_repo(allow_patterns="meta/")
8888
self.info = load_info(self.root)
8989
self.stats = load_stats(self.root)
90-
self.tasks = load_tasks(self.root)
90+
self.tasks, self.task_to_task_index = load_tasks(self.root)
9191
self.episodes = load_episodes(self.root)
9292

9393
def pull_from_repo(
@@ -202,31 +202,35 @@ def chunks_size(self) -> int:
202202
"""Max number of episodes per chunk."""
203203
return self.info["chunks_size"]
204204

205-
@property
206-
def task_to_task_index(self) -> dict:
207-
return {task: task_idx for task_idx, task in self.tasks.items()}
208-
209-
def get_task_index(self, task: str) -> int:
205+
def get_task_index(self, task: str) -> int | None:
210206
"""
211207
Given a task in natural language, returns its task_index if the task already exists in the dataset,
212-
otherwise creates a new task_index.
208+
otherwise return None.
209+
"""
210+
return self.task_to_task_index.get(task, None)
211+
212+
def add_task(self, task: str):
213213
"""
214-
task_index = self.task_to_task_index.get(task, None)
215-
return task_index if task_index is not None else self.total_tasks
214+
Given a task in natural language, add it to the dictionnary of tasks.
215+
"""
216+
if task in self.task_to_task_index:
217+
raise ValueError(f"The task '{task}' already exists and can't be added twice.")
218+
219+
task_index = self.info["total_tasks"]
220+
self.task_to_task_index[task] = task_index
221+
self.tasks[task_index] = task
222+
self.info["total_tasks"] += 1
223+
224+
task_dict = {
225+
"task_index": task_index,
226+
"task": task,
227+
}
228+
append_jsonlines(task_dict, self.root / TASKS_PATH)
216229

217-
def save_episode(self, episode_index: int, episode_length: int, task: str, task_index: int) -> None:
230+
def save_episode(self, episode_index: int, episode_length: int, episode_tasks: list[str]) -> None:
218231
self.info["total_episodes"] += 1
219232
self.info["total_frames"] += episode_length
220233

221-
if task_index not in self.tasks:
222-
self.info["total_tasks"] += 1
223-
self.tasks[task_index] = task
224-
task_dict = {
225-
"task_index": task_index,
226-
"task": task,
227-
}
228-
append_jsonlines(task_dict, self.root / TASKS_PATH)
229-
230234
chunk = self.get_episode_chunk(episode_index)
231235
if chunk >= self.total_chunks:
232236
self.info["total_chunks"] += 1
@@ -237,7 +241,7 @@ def save_episode(self, episode_index: int, episode_length: int, task: str, task_
237241

238242
episode_dict = {
239243
"episode_index": episode_index,
240-
"tasks": [task],
244+
"tasks": episode_tasks,
241245
"length": episode_length,
242246
}
243247
self.episodes.append(episode_dict)
@@ -313,7 +317,8 @@ def create(
313317

314318
features = {**features, **DEFAULT_FEATURES}
315319

316-
obj.tasks, obj.stats, obj.episodes = {}, {}, []
320+
obj.tasks, obj.task_to_task_index = {}, {}
321+
obj.stats, obj.episodes = {}, []
317322
obj.info = create_empty_dataset_info(CODEBASE_VERSION, fps, robot_type, features, use_videos)
318323
if len(obj.video_keys) > 0 and not use_videos:
319324
raise ValueError()
@@ -691,10 +696,13 @@ def __repr__(self):
691696

692697
def create_episode_buffer(self, episode_index: int | None = None) -> dict:
693698
current_ep_idx = self.meta.total_episodes if episode_index is None else episode_index
694-
return {
695-
"size": 0,
696-
**{key: current_ep_idx if key == "episode_index" else [] for key in self.features},
697-
}
699+
ep_buffer = {}
700+
# size and task are special cases that are not in self.features
701+
ep_buffer["size"] = 0
702+
ep_buffer["task"] = []
703+
for key in self.features:
704+
ep_buffer[key] = current_ep_idx if key == "episode_index" else []
705+
return ep_buffer
698706

699707
def _get_image_file_path(self, episode_index: int, image_key: str, frame_index: int) -> Path:
700708
fpath = DEFAULT_IMAGE_PATH.format(
@@ -718,6 +726,8 @@ def add_frame(self, frame: dict) -> None:
718726
"""
719727
# TODO(aliberts, rcadene): Add sanity check for the input, check it's numpy or torch,
720728
# check the dtype and shape matches, etc.
729+
if "task" not in frame:
730+
raise ValueError("The mandatory feature 'task' wasn't found in `frame` dictionnary.")
721731

722732
if self.episode_buffer is None:
723733
self.episode_buffer = self.create_episode_buffer()
@@ -728,24 +738,31 @@ def add_frame(self, frame: dict) -> None:
728738
self.episode_buffer["timestamp"].append(timestamp)
729739

730740
for key in frame:
741+
if key == "task":
742+
# Note: we associate the task in natural language to its task index during `save_episode`
743+
self.episode_buffer["task"].append(frame["task"])
744+
continue
745+
731746
if key not in self.features:
732-
raise ValueError(key)
747+
raise ValueError(
748+
f"An element of the frame is not in the features. '{key}' not in '{self.features.keys()}'."
749+
)
733750

734-
if self.features[key]["dtype"] not in ["image", "video"]:
735-
item = frame[key].numpy() if isinstance(frame[key], torch.Tensor) else frame[key]
736-
self.episode_buffer[key].append(item)
737-
elif self.features[key]["dtype"] in ["image", "video"]:
751+
if self.features[key]["dtype"] in ["image", "video"]:
738752
img_path = self._get_image_file_path(
739753
episode_index=self.episode_buffer["episode_index"], image_key=key, frame_index=frame_index
740754
)
741755
if frame_index == 0:
742756
img_path.parent.mkdir(parents=True, exist_ok=True)
743757
self._save_image(frame[key], img_path)
744758
self.episode_buffer[key].append(str(img_path))
759+
else:
760+
item = frame[key].numpy() if isinstance(frame[key], torch.Tensor) else frame[key]
761+
self.episode_buffer[key].append(item)
745762

746763
self.episode_buffer["size"] += 1
747764

748-
def save_episode(self, task: str, encode_videos: bool = True, episode_data: dict | None = None) -> None:
765+
def save_episode(self, encode_videos: bool = True, episode_data: dict | None = None) -> None:
749766
"""
750767
This will save to disk the current episode in self.episode_buffer. Note that since it affects files on
751768
disk, it sets self.consolidated to False to ensure proper consolidation later on before uploading to
@@ -758,7 +775,11 @@ def save_episode(self, task: str, encode_videos: bool = True, episode_data: dict
758775
if not episode_data:
759776
episode_buffer = self.episode_buffer
760777

778+
# size and task are special cases that won't be added to hf_dataset
761779
episode_length = episode_buffer.pop("size")
780+
tasks = episode_buffer.pop("task")
781+
episode_tasks = list(set(tasks))
782+
762783
episode_index = episode_buffer["episode_index"]
763784
if episode_index != self.meta.total_episodes:
764785
# TODO(aliberts): Add option to use existing episode_index
@@ -772,21 +793,27 @@ def save_episode(self, task: str, encode_videos: bool = True, episode_data: dict
772793
"You must add one or several frames with `add_frame` before calling `add_episode`."
773794
)
774795

775-
task_index = self.meta.get_task_index(task)
776-
777796
if not set(episode_buffer.keys()) == set(self.features):
778-
raise ValueError()
797+
raise ValueError(
798+
f"Features from `episode_buffer` don't match the ones in `self.features`: '{set(episode_buffer.keys())}' vs '{set(self.features)}'"
799+
)
800+
801+
episode_buffer["index"] = np.arange(self.meta.total_frames, self.meta.total_frames + episode_length)
802+
episode_buffer["episode_index"] = np.full((episode_length,), episode_index)
803+
804+
# Add new tasks to the tasks dictionnary
805+
for task in episode_tasks:
806+
task_index = self.meta.get_task_index(task)
807+
if task_index is None:
808+
self.meta.add_task(task)
809+
810+
# Given tasks in natural language, find their corresponding task indices
811+
episode_buffer["task_index"] = np.array([self.meta.get_task_index(task) for task in tasks])
779812

780813
for key, ft in self.features.items():
781-
if key == "index":
782-
episode_buffer[key] = np.arange(
783-
self.meta.total_frames, self.meta.total_frames + episode_length
784-
)
785-
elif key == "episode_index":
786-
episode_buffer[key] = np.full((episode_length,), episode_index)
787-
elif key == "task_index":
788-
episode_buffer[key] = np.full((episode_length,), task_index)
789-
elif ft["dtype"] in ["image", "video"]:
814+
# index, episode_index, task_index are already processed above, and image and video
815+
# are processed separately by storing image path and frame info as meta data
816+
if key in ["index", "episode_index", "task_index"] or ft["dtype"] in ["image", "video"]:
790817
continue
791818
elif len(ft["shape"]) == 1 and ft["shape"][0] == 1:
792819
episode_buffer[key] = np.array(episode_buffer[key], dtype=ft["dtype"])
@@ -798,7 +825,7 @@ def save_episode(self, task: str, encode_videos: bool = True, episode_data: dict
798825
self._wait_image_writer()
799826
self._save_episode_table(episode_buffer, episode_index)
800827

801-
self.meta.save_episode(episode_index, episode_length, task, task_index)
828+
self.meta.save_episode(episode_index, episode_length, episode_tasks)
802829

803830
if encode_videos and len(self.meta.video_keys) > 0:
804831
video_paths = self.encode_episode_videos(episode_index)

lerobot/common/datasets/utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,9 @@ def load_stats(local_dir: Path) -> dict:
170170

171171
def load_tasks(local_dir: Path) -> dict:
172172
tasks = load_jsonlines(local_dir / TASKS_PATH)
173-
return {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])}
173+
tasks = {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])}
174+
task_to_task_index = {task: task_index for task_index, task in tasks.items()}
175+
return tasks, task_to_task_index
174176

175177

176178
def load_episodes(local_dir: Path) -> dict:

lerobot/common/robot_devices/control_utils.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,7 @@ def record_episode(
183183
device,
184184
use_amp,
185185
fps,
186+
single_task,
186187
):
187188
control_loop(
188189
robot=robot,
@@ -195,6 +196,7 @@ def record_episode(
195196
use_amp=use_amp,
196197
fps=fps,
197198
teleoperate=policy is None,
199+
single_task=single_task,
198200
)
199201

200202

@@ -210,6 +212,7 @@ def control_loop(
210212
device: torch.device | str | None = None,
211213
use_amp: bool | None = None,
212214
fps: int | None = None,
215+
single_task: str | None = None,
213216
):
214217
# TODO(rcadene): Add option to record logs
215218
if not robot.is_connected:
@@ -224,6 +227,9 @@ def control_loop(
224227
if teleoperate and policy is not None:
225228
raise ValueError("When `teleoperate` is True, `policy` should be None.")
226229

230+
if dataset is not None and single_task is None:
231+
raise ValueError("You need to provide a task as argument in `single_task`.")
232+
227233
if dataset is not None and fps is not None and dataset.fps != fps:
228234
raise ValueError(f"The dataset fps should be equal to requested fps ({dataset['fps']} != {fps}).")
229235

@@ -248,7 +254,7 @@ def control_loop(
248254
action = {"action": action}
249255

250256
if dataset is not None:
251-
frame = {**observation, **action}
257+
frame = {**observation, **action, "task": single_task}
252258
dataset.add_frame(frame)
253259

254260
if display_cameras and not is_headless():

lerobot/scripts/control_robot.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -263,15 +263,16 @@ def record(
263263

264264
log_say(f"Recording episode {dataset.num_episodes}", cfg.play_sounds)
265265
record_episode(
266-
dataset=dataset,
267266
robot=robot,
267+
dataset=dataset,
268268
events=events,
269269
episode_time_s=cfg.episode_time_s,
270270
display_cameras=cfg.display_cameras,
271271
policy=policy,
272272
device=cfg.device,
273273
use_amp=cfg.use_amp,
274274
fps=cfg.fps,
275+
single_task=cfg.single_task,
275276
)
276277

277278
# Execute a few seconds without recording to give time to manually reset the environment
@@ -291,7 +292,7 @@ def record(
291292
dataset.clear_episode_buffer()
292293
continue
293294

294-
dataset.save_episode(cfg.single_task)
295+
dataset.save_episode()
295296
recorded_episodes += 1
296297

297298
if events["stop_recording"]:

tests/test_datasets.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,24 @@ def test_dataset_initialization(lerobot_dataset_factory, tmp_path):
9393
assert dataset.num_frames == len(dataset)
9494

9595

96+
def test_add_frame_no_task(tmp_path):
97+
features = {"1d": {"dtype": "float32", "shape": (1,), "names": None}}
98+
dataset = LeRobotDataset.create(repo_id=DUMMY_REPO_ID, fps=30, root=tmp_path / "test", features=features)
99+
with pytest.raises(ValueError, match="The mandatory feature 'task' wasn't found in `frame` dictionnary."):
100+
dataset.add_frame({"1d": torch.randn(1)})
101+
102+
103+
def test_add_frame(tmp_path):
104+
features = {"1d": {"dtype": "float32", "shape": (1,), "names": None}}
105+
dataset = LeRobotDataset.create(repo_id=DUMMY_REPO_ID, fps=30, root=tmp_path / "test", features=features)
106+
dataset.add_frame({"1d": torch.randn(1), "task": "dummy"})
107+
dataset.save_episode(encode_videos=False)
108+
dataset.consolidate(run_compute_stats=False)
109+
assert len(dataset) == 1
110+
assert dataset[0]["task"] == "dummy"
111+
assert dataset[0]["task_index"] == 0
112+
113+
96114
# TODO(aliberts):
97115
# - [ ] test various attributes & state from init and create
98116
# - [ ] test init with episodes and check num_frames

0 commit comments

Comments
 (0)