Skip to content

Add frame level task #693

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion examples/port_datasets/pusht_zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ def main(raw_dir: Path, repo_id: str, mode: str = "video", push_to_hub: bool = T
# Shift reward and success by +1 until the last item of the episode
"next.reward": reward[i + (frame_idx < num_frames - 1)],
"next.success": success[i + (frame_idx < num_frames - 1)],
"task": PUSHT_TASK,
}

frame["observation.state"] = torch.from_numpy(agent_pos[i])
Expand All @@ -191,7 +192,7 @@ def main(raw_dir: Path, repo_id: str, mode: str = "video", push_to_hub: bool = T

dataset.add_frame(frame)

dataset.save_episode(task=PUSHT_TASK)
dataset.save_episode()

dataset.consolidate()

Expand Down
115 changes: 71 additions & 44 deletions lerobot/common/datasets/lerobot_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def __init__(
self.pull_from_repo(allow_patterns="meta/")
self.info = load_info(self.root)
self.stats = load_stats(self.root)
self.tasks = load_tasks(self.root)
self.tasks, self.task_to_task_index = load_tasks(self.root)
self.episodes = load_episodes(self.root)

def pull_from_repo(
Expand Down Expand Up @@ -202,31 +202,35 @@ def chunks_size(self) -> int:
"""Max number of episodes per chunk."""
return self.info["chunks_size"]

@property
def task_to_task_index(self) -> dict:
return {task: task_idx for task_idx, task in self.tasks.items()}

def get_task_index(self, task: str) -> int:
def get_task_index(self, task: str) -> int | None:
"""
Given a task in natural language, returns its task_index if the task already exists in the dataset,
otherwise creates a new task_index.
otherwise return None.
"""
return self.task_to_task_index.get(task, None)

def add_task(self, task: str):
"""
task_index = self.task_to_task_index.get(task, None)
return task_index if task_index is not None else self.total_tasks
Given a task in natural language, add it to the dictionnary of tasks.
"""
if task in self.task_to_task_index:
raise ValueError(f"The task '{task}' already exists and can't be added twice.")

task_index = self.info["total_tasks"]
self.task_to_task_index[task] = task_index
self.tasks[task_index] = task
self.info["total_tasks"] += 1

task_dict = {
"task_index": task_index,
"task": task,
}
append_jsonlines(task_dict, self.root / TASKS_PATH)

def save_episode(self, episode_index: int, episode_length: int, task: str, task_index: int) -> None:
def save_episode(self, episode_index: int, episode_length: int, episode_tasks: list[str]) -> None:
self.info["total_episodes"] += 1
self.info["total_frames"] += episode_length

if task_index not in self.tasks:
self.info["total_tasks"] += 1
self.tasks[task_index] = task
task_dict = {
"task_index": task_index,
"task": task,
}
append_jsonlines(task_dict, self.root / TASKS_PATH)

chunk = self.get_episode_chunk(episode_index)
if chunk >= self.total_chunks:
self.info["total_chunks"] += 1
Expand All @@ -237,7 +241,7 @@ def save_episode(self, episode_index: int, episode_length: int, task: str, task_

episode_dict = {
"episode_index": episode_index,
"tasks": [task],
"tasks": episode_tasks,
"length": episode_length,
}
self.episodes.append(episode_dict)
Expand Down Expand Up @@ -313,7 +317,8 @@ def create(

features = {**features, **DEFAULT_FEATURES}

obj.tasks, obj.stats, obj.episodes = {}, {}, []
obj.tasks, obj.task_to_task_index = {}, {}
obj.stats, obj.episodes = {}, []
obj.info = create_empty_dataset_info(CODEBASE_VERSION, fps, robot_type, features, use_videos)
if len(obj.video_keys) > 0 and not use_videos:
raise ValueError()
Expand Down Expand Up @@ -691,10 +696,13 @@ def __repr__(self):

def create_episode_buffer(self, episode_index: int | None = None) -> dict:
current_ep_idx = self.meta.total_episodes if episode_index is None else episode_index
return {
"size": 0,
**{key: current_ep_idx if key == "episode_index" else [] for key in self.features},
}
ep_buffer = {}
# size and task are special cases that are not in self.features
ep_buffer["size"] = 0
ep_buffer["task"] = []
for key in self.features:
ep_buffer[key] = current_ep_idx if key == "episode_index" else []
return ep_buffer

def _get_image_file_path(self, episode_index: int, image_key: str, frame_index: int) -> Path:
fpath = DEFAULT_IMAGE_PATH.format(
Expand All @@ -718,6 +726,8 @@ def add_frame(self, frame: dict) -> None:
"""
# TODO(aliberts, rcadene): Add sanity check for the input, check it's numpy or torch,
# check the dtype and shape matches, etc.
if "task" not in frame:
raise ValueError("The mandatory feature 'task' wasn't found in `frame` dictionnary.")

if self.episode_buffer is None:
self.episode_buffer = self.create_episode_buffer()
Expand All @@ -728,24 +738,31 @@ def add_frame(self, frame: dict) -> None:
self.episode_buffer["timestamp"].append(timestamp)

for key in frame:
if key == "task":
# Note: we associate the task in natural language to its task index during `save_episode`
self.episode_buffer["task"].append(frame["task"])
continue

if key not in self.features:
raise ValueError(key)
raise ValueError(
f"An element of the frame is not in the features. '{key}' not in '{self.features.keys()}'."
)

if self.features[key]["dtype"] not in ["image", "video"]:
item = frame[key].numpy() if isinstance(frame[key], torch.Tensor) else frame[key]
self.episode_buffer[key].append(item)
elif self.features[key]["dtype"] in ["image", "video"]:
if self.features[key]["dtype"] in ["image", "video"]:
img_path = self._get_image_file_path(
episode_index=self.episode_buffer["episode_index"], image_key=key, frame_index=frame_index
)
if frame_index == 0:
img_path.parent.mkdir(parents=True, exist_ok=True)
self._save_image(frame[key], img_path)
self.episode_buffer[key].append(str(img_path))
else:
item = frame[key].numpy() if isinstance(frame[key], torch.Tensor) else frame[key]
self.episode_buffer[key].append(item)

self.episode_buffer["size"] += 1

def save_episode(self, task: str, encode_videos: bool = True, episode_data: dict | None = None) -> None:
def save_episode(self, encode_videos: bool = True, episode_data: dict | None = None) -> None:
"""
This will save to disk the current episode in self.episode_buffer. Note that since it affects files on
disk, it sets self.consolidated to False to ensure proper consolidation later on before uploading to
Expand All @@ -758,7 +775,11 @@ def save_episode(self, task: str, encode_videos: bool = True, episode_data: dict
if not episode_data:
episode_buffer = self.episode_buffer

# size and task are special cases that won't be added to hf_dataset
episode_length = episode_buffer.pop("size")
tasks = episode_buffer.pop("task")
episode_tasks = list(set(tasks))

episode_index = episode_buffer["episode_index"]
if episode_index != self.meta.total_episodes:
# TODO(aliberts): Add option to use existing episode_index
Expand All @@ -772,21 +793,27 @@ def save_episode(self, task: str, encode_videos: bool = True, episode_data: dict
"You must add one or several frames with `add_frame` before calling `add_episode`."
)

task_index = self.meta.get_task_index(task)

if not set(episode_buffer.keys()) == set(self.features):
raise ValueError()
raise ValueError(
f"Features from `episode_buffer` don't match the ones in `self.features`: '{set(episode_buffer.keys())}' vs '{set(self.features)}'"
)

episode_buffer["index"] = np.arange(self.meta.total_frames, self.meta.total_frames + episode_length)
episode_buffer["episode_index"] = np.full((episode_length,), episode_index)

# Add new tasks to the tasks dictionnary
for task in episode_tasks:
task_index = self.meta.get_task_index(task)
if task_index is None:
self.meta.add_task(task)

# Given tasks in natural language, find their corresponding task indices
episode_buffer["task_index"] = np.array([self.meta.get_task_index(task) for task in tasks])

for key, ft in self.features.items():
if key == "index":
episode_buffer[key] = np.arange(
self.meta.total_frames, self.meta.total_frames + episode_length
)
elif key == "episode_index":
episode_buffer[key] = np.full((episode_length,), episode_index)
elif key == "task_index":
episode_buffer[key] = np.full((episode_length,), task_index)
elif ft["dtype"] in ["image", "video"]:
# index, episode_index, task_index are already processed above, and image and video
# are processed separately by storing image path and frame info as meta data
if key in ["index", "episode_index", "task_index"] or ft["dtype"] in ["image", "video"]:
continue
elif len(ft["shape"]) == 1 and ft["shape"][0] == 1:
episode_buffer[key] = np.array(episode_buffer[key], dtype=ft["dtype"])
Expand All @@ -798,7 +825,7 @@ def save_episode(self, task: str, encode_videos: bool = True, episode_data: dict
self._wait_image_writer()
self._save_episode_table(episode_buffer, episode_index)

self.meta.save_episode(episode_index, episode_length, task, task_index)
self.meta.save_episode(episode_index, episode_length, episode_tasks)

if encode_videos and len(self.meta.video_keys) > 0:
video_paths = self.encode_episode_videos(episode_index)
Expand Down
4 changes: 3 additions & 1 deletion lerobot/common/datasets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,9 @@ def load_stats(local_dir: Path) -> dict:

def load_tasks(local_dir: Path) -> dict:
tasks = load_jsonlines(local_dir / TASKS_PATH)
return {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])}
tasks = {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])}
task_to_task_index = {task: task_index for task_index, task in tasks.items()}
return tasks, task_to_task_index


def load_episodes(local_dir: Path) -> dict:
Expand Down
8 changes: 7 additions & 1 deletion lerobot/common/robot_devices/control_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,7 @@ def record_episode(
device,
use_amp,
fps,
single_task,
):
control_loop(
robot=robot,
Expand All @@ -195,6 +196,7 @@ def record_episode(
use_amp=use_amp,
fps=fps,
teleoperate=policy is None,
single_task=single_task,
)


Expand All @@ -210,6 +212,7 @@ def control_loop(
device: torch.device | str | None = None,
use_amp: bool | None = None,
fps: int | None = None,
single_task: str | None = None,
):
# TODO(rcadene): Add option to record logs
if not robot.is_connected:
Expand All @@ -224,6 +227,9 @@ def control_loop(
if teleoperate and policy is not None:
raise ValueError("When `teleoperate` is True, `policy` should be None.")

if dataset is not None and single_task is None:
raise ValueError("You need to provide a task as argument in `single_task`.")

if dataset is not None and fps is not None and dataset.fps != fps:
raise ValueError(f"The dataset fps should be equal to requested fps ({dataset['fps']} != {fps}).")

Expand All @@ -248,7 +254,7 @@ def control_loop(
action = {"action": action}

if dataset is not None:
frame = {**observation, **action}
frame = {**observation, **action, "task": single_task}
dataset.add_frame(frame)

if display_cameras and not is_headless():
Expand Down
5 changes: 3 additions & 2 deletions lerobot/scripts/control_robot.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,15 +263,16 @@ def record(

log_say(f"Recording episode {dataset.num_episodes}", cfg.play_sounds)
record_episode(
dataset=dataset,
robot=robot,
dataset=dataset,
events=events,
episode_time_s=cfg.episode_time_s,
display_cameras=cfg.display_cameras,
policy=policy,
device=cfg.device,
use_amp=cfg.use_amp,
fps=cfg.fps,
single_task=cfg.single_task,
)

# Execute a few seconds without recording to give time to manually reset the environment
Expand All @@ -291,7 +292,7 @@ def record(
dataset.clear_episode_buffer()
continue

dataset.save_episode(cfg.single_task)
dataset.save_episode()
recorded_episodes += 1

if events["stop_recording"]:
Expand Down
18 changes: 18 additions & 0 deletions tests/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,24 @@ def test_dataset_initialization(lerobot_dataset_factory, tmp_path):
assert dataset.num_frames == len(dataset)


def test_add_frame_no_task(tmp_path):
    """`add_frame` must reject a frame dict that lacks the mandatory 'task' key."""
    # Minimal feature spec: a single 1-D float32 feature is enough to create a dataset.
    features = {"1d": {"dtype": "float32", "shape": (1,), "names": None}}
    dataset = LeRobotDataset.create(repo_id=DUMMY_REPO_ID, fps=30, root=tmp_path / "test", features=features)
    # NOTE: the match string intentionally reproduces the exact wording (including the
    # "dictionnary" misspelling) of the ValueError raised by LeRobotDataset.add_frame.
    with pytest.raises(ValueError, match="The mandatory feature 'task' wasn't found in `frame` dictionnary."):
        dataset.add_frame({"1d": torch.randn(1)})


def test_add_frame(tmp_path):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test is quite long for what it's doing (~0.5s)
Did you check if this new add_frame is slower than before?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

0.13s on my side

Screenshot 2025-02-13 at 18 54 49

features = {"1d": {"dtype": "float32", "shape": (1,), "names": None}}
dataset = LeRobotDataset.create(repo_id=DUMMY_REPO_ID, fps=30, root=tmp_path / "test", features=features)
dataset.add_frame({"1d": torch.randn(1), "task": "dummy"})
dataset.save_episode(encode_videos=False)
dataset.consolidate(run_compute_stats=False)
assert len(dataset) == 1
assert dataset[0]["task"] == "dummy"
assert dataset[0]["task_index"] == 0


# TODO(aliberts):
# - [ ] test various attributes & state from init and create
# - [ ] test init with episodes and check num_frames
Expand Down