Skip to content

Commit 504d2aa

Browse files
add EpisodeAwareSampler (#217)
Co-authored-by: Alexander Soare <[email protected]>
1 parent 83f4f7f commit 504d2aa

File tree

4 files changed

+168
-1
lines changed

4 files changed

+168
-1
lines changed

lerobot/common/datasets/sampler.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
#!/usr/bin/env python
2+
3+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
from typing import Iterator, Union
17+
18+
import torch
19+
20+
21+
class EpisodeAwareSampler:
22+
def __init__(
23+
self,
24+
episode_data_index: dict,
25+
episode_indices_to_use: Union[list, None] = None,
26+
drop_n_first_frames: int = 0,
27+
drop_n_last_frames: int = 0,
28+
shuffle: bool = False,
29+
):
30+
"""Sampler that optionally incorporates episode boundary information.
31+
32+
Args:
33+
episode_data_index: Dictionary with keys 'from' and 'to' containing the start and end indices of each episode.
34+
episode_indices_to_use: List of episode indices to use. If None, all episodes are used.
35+
Assumes that episodes are indexed from 0 to N-1.
36+
drop_n_first_frames: Number of frames to drop from the start of each episode.
37+
drop_n_last_frames: Number of frames to drop from the end of each episode.
38+
shuffle: Whether to shuffle the indices.
39+
"""
40+
indices = []
41+
for episode_idx, (start_index, end_index) in enumerate(
42+
zip(episode_data_index["from"], episode_data_index["to"], strict=True)
43+
):
44+
if episode_indices_to_use is None or episode_idx in episode_indices_to_use:
45+
indices.extend(
46+
range(start_index.item() + drop_n_first_frames, end_index.item() - drop_n_last_frames)
47+
)
48+
49+
self.indices = indices
50+
self.shuffle = shuffle
51+
52+
def __iter__(self) -> Iterator[int]:
53+
if self.shuffle:
54+
for i in torch.randperm(len(self.indices)):
55+
yield self.indices[i]
56+
else:
57+
for i in self.indices:
58+
yield i
59+
60+
def __len__(self) -> int:
61+
return len(self.indices)

lerobot/configs/policy/diffusion.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,10 @@ training:
4444
observation.state: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1)]"
4545
action: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1 - ${policy.n_obs_steps} + ${policy.horizon})]"
4646

47+
# The original implementation doesn't sample frames for the last 7 steps,
48+
# which avoids excessive padding and leads to improved training results.
49+
drop_n_last_frames: 7 # ${policy.horizon} - ${policy.n_action_steps} - ${policy.n_obs_steps} + 1
50+
4751
eval:
4852
n_episodes: 50
4953
batch_size: 50

lerobot/scripts/train.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828

2929
from lerobot.common.datasets.factory import make_dataset, resolve_delta_timestamps
3030
from lerobot.common.datasets.lerobot_dataset import MultiLeRobotDataset
31+
from lerobot.common.datasets.sampler import EpisodeAwareSampler
3132
from lerobot.common.datasets.utils import cycle
3233
from lerobot.common.envs.factory import make_env
3334
from lerobot.common.logger import Logger, log_output_dir
@@ -356,11 +357,22 @@ def evaluate_and_checkpoint_if_needed(step):
356357
logging.info("Resume training")
357358

358359
# create dataloader for offline training
360+
if cfg.training.get("drop_n_last_frames"):
361+
shuffle = False
362+
sampler = EpisodeAwareSampler(
363+
offline_dataset.episode_data_index,
364+
drop_n_last_frames=cfg.training.drop_n_last_frames,
365+
shuffle=True,
366+
)
367+
else:
368+
shuffle = True
369+
sampler = None
359370
dataloader = torch.utils.data.DataLoader(
360371
offline_dataset,
361372
num_workers=cfg.training.num_workers,
362373
batch_size=cfg.training.batch_size,
363-
shuffle=True,
374+
shuffle=shuffle,
375+
sampler=sampler,
364376
pin_memory=device.type != "cpu",
365377
drop_last=False,
366378
)

tests/test_sampler.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
#!/usr/bin/env python
2+
3+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
from datasets import Dataset
17+
18+
from lerobot.common.datasets.sampler import EpisodeAwareSampler
19+
from lerobot.common.datasets.utils import (
20+
calculate_episode_data_index,
21+
hf_transform_to_torch,
22+
)
23+
24+
25+
def test_drop_n_first_frames():
26+
dataset = Dataset.from_dict(
27+
{
28+
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
29+
"index": [0, 1, 2, 3, 4, 5],
30+
"episode_index": [0, 0, 1, 2, 2, 2],
31+
},
32+
)
33+
dataset.set_transform(hf_transform_to_torch)
34+
episode_data_index = calculate_episode_data_index(dataset)
35+
sampler = EpisodeAwareSampler(episode_data_index, drop_n_first_frames=1)
36+
assert sampler.indices == [1, 4, 5]
37+
assert len(sampler) == 3
38+
assert list(sampler) == [1, 4, 5]
39+
40+
41+
def test_drop_n_last_frames():
42+
dataset = Dataset.from_dict(
43+
{
44+
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
45+
"index": [0, 1, 2, 3, 4, 5],
46+
"episode_index": [0, 0, 1, 2, 2, 2],
47+
},
48+
)
49+
dataset.set_transform(hf_transform_to_torch)
50+
episode_data_index = calculate_episode_data_index(dataset)
51+
sampler = EpisodeAwareSampler(episode_data_index, drop_n_last_frames=1)
52+
assert sampler.indices == [0, 3, 4]
53+
assert len(sampler) == 3
54+
assert list(sampler) == [0, 3, 4]
55+
56+
57+
def test_episode_indices_to_use():
58+
dataset = Dataset.from_dict(
59+
{
60+
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
61+
"index": [0, 1, 2, 3, 4, 5],
62+
"episode_index": [0, 0, 1, 2, 2, 2],
63+
},
64+
)
65+
dataset.set_transform(hf_transform_to_torch)
66+
episode_data_index = calculate_episode_data_index(dataset)
67+
sampler = EpisodeAwareSampler(episode_data_index, episode_indices_to_use=[0, 2])
68+
assert sampler.indices == [0, 1, 3, 4, 5]
69+
assert len(sampler) == 5
70+
assert list(sampler) == [0, 1, 3, 4, 5]
71+
72+
73+
def test_shuffle():
74+
dataset = Dataset.from_dict(
75+
{
76+
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
77+
"index": [0, 1, 2, 3, 4, 5],
78+
"episode_index": [0, 0, 1, 2, 2, 2],
79+
},
80+
)
81+
dataset.set_transform(hf_transform_to_torch)
82+
episode_data_index = calculate_episode_data_index(dataset)
83+
sampler = EpisodeAwareSampler(episode_data_index, shuffle=False)
84+
assert sampler.indices == [0, 1, 2, 3, 4, 5]
85+
assert len(sampler) == 6
86+
assert list(sampler) == [0, 1, 2, 3, 4, 5]
87+
sampler = EpisodeAwareSampler(episode_data_index, shuffle=True)
88+
assert sampler.indices == [0, 1, 2, 3, 4, 5]
89+
assert len(sampler) == 6
90+
assert set(sampler) == {0, 1, 2, 3, 4, 5}

0 commit comments

Comments
 (0)