
Commit 23fbc19

Merge pull request #2 from huggingface/main
Merge from main

2 parents: 44edc6f + b187942

15 files changed (+90 / -26 lines)

Makefile

Lines changed: 3 additions & 4 deletions

@@ -22,9 +22,8 @@ test-end-to-end:
 	${MAKE} test-act-ete-eval
 	${MAKE} test-diffusion-ete-train
 	${MAKE} test-diffusion-ete-eval
-	# TODO(rcadene, alexander-soare): enable end-to-end tests for tdmpc
-	# ${MAKE} test-tdmpc-ete-train
-	# ${MAKE} test-tdmpc-ete-eval
+	${MAKE} test-tdmpc-ete-train
+	${MAKE} test-tdmpc-ete-eval
 	${MAKE} test-default-ete-eval

 test-act-ete-train:

@@ -80,7 +79,7 @@ test-tdmpc-ete-train:
 		policy=tdmpc \
 		env=xarm \
 		env.task=XarmLift-v0 \
-		dataset_repo_id=lerobot/xarm_lift_medium_replay \
+		dataset_repo_id=lerobot/xarm_lift_medium \
 		wandb.enable=False \
 		training.offline_steps=2 \
 		training.online_steps=2 \

lerobot/common/utils/utils.py

Lines changed: 27 additions & 0 deletions

@@ -1,8 +1,10 @@
 import logging
 import os.path as osp
 import random
+from contextlib import contextmanager
 from datetime import datetime
 from pathlib import Path
+from typing import Generator

 import hydra
 import numpy as np

@@ -39,6 +41,31 @@ def set_global_seed(seed):
     torch.cuda.manual_seed_all(seed)


+@contextmanager
+def seeded_context(seed: int) -> Generator[None, None, None]:
+    """Set the seed when entering a context, and restore the prior random state at exit.
+
+    Example usage:
+
+    ```
+    a = random.random()  # produces some random number
+    with seeded_context(1337):
+        b = random.random()  # produces some other random number
+    c = random.random()  # produces yet another random number, but the same it would have if we never made `b`
+    ```
+    """
+    random_state = random.getstate()
+    np_random_state = np.random.get_state()
+    torch_random_state = torch.random.get_rng_state()
+    torch_cuda_random_state = torch.cuda.random.get_rng_state()
+    set_global_seed(seed)
+    yield None
+    random.setstate(random_state)
+    np.random.set_state(np_random_state)
+    torch.random.set_rng_state(torch_random_state)
+    torch.cuda.random.set_rng_state(torch_cuda_random_state)
+
+
 def init_logging():
     def custom_format(record):
         dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
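The new `seeded_context` helper is useful when you want a reproducible sub-computation (for example an evaluation rollout in the middle of training) without disturbing the surrounding RNG stream. A minimal sketch of that pattern, assuming a CUDA-capable machine since the helper also snapshots the CUDA RNG state:

```python
import torch

from lerobot.common.utils.utils import seeded_context, set_global_seed

set_global_seed(0)  # make the overall run reproducible

train_noise = torch.randn(3)  # a draw from the "training" RNG stream

with seeded_context(1337):
    # Draws in here are reproducible under seed 1337 and do not advance
    # the RNG state observed outside the block.
    eval_noise = torch.randn(3)

# Identical to what this draw would have been had the block above never run.
more_train_noise = torch.randn(3)
```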

lerobot/configs/policy/act.yaml

Lines changed: 6 additions & 6 deletions

@@ -3,6 +3,12 @@
 seed: 1000
 dataset_repo_id: lerobot/aloha_sim_insertion_human

+override_dataset_stats:
+  observation.images.top:
+    # stats from imagenet, since we use a pretrained vision model
+    mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
+    std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
+
 training:
   offline_steps: 80000
   online_steps: 0

@@ -18,12 +24,6 @@ training:
   grad_clip_norm: 10
   online_steps_between_rollouts: 1

-override_dataset_stats:
-  observation.images.top:
-    # stats from imagenet, since we use a pretrained vision model
-    mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
-    std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
-
 delta_timestamps:
   action: "[i / ${fps} for i in range(${policy.chunk_size})]"

lerobot/configs/policy/diffusion.yaml

Lines changed: 14 additions & 14 deletions

@@ -7,6 +7,20 @@
 seed: 100000
 dataset_repo_id: lerobot/pusht

+override_dataset_stats:
+  # TODO(rcadene, alexander-soare): should we remove image stats as well? do we use a pretrained vision model?
+  observation.image:
+    mean: [[[0.5]], [[0.5]], [[0.5]]] # (c,1,1)
+    std: [[[0.5]], [[0.5]], [[0.5]]] # (c,1,1)
+  # TODO(rcadene, alexander-soare): we override state and action stats to use the same as the pretrained model
+  # from the original codebase, but we should remove these and train our own pretrained model
+  observation.state:
+    min: [13.456424, 32.938293]
+    max: [496.14618, 510.9579]
+  action:
+    min: [12.0, 25.0]
+    max: [511.0, 511.0]
+
 training:
   offline_steps: 200000
   online_steps: 0

@@ -34,20 +48,6 @@ eval:
   n_episodes: 50
   batch_size: 50

-override_dataset_stats:
-  # TODO(rcadene, alexander-soare): should we remove image stats as well? do we use a pretrained vision model?
-  observation.image:
-    mean: [[[0.5]], [[0.5]], [[0.5]]] # (c,1,1)
-    std: [[[0.5]], [[0.5]], [[0.5]]] # (c,1,1)
-  # TODO(rcadene, alexander-soare): we override state and action stats to use the same as the pretrained model
-  # from the original codebase, but we should remove these and train our own pretrained model
-  observation.state:
-    min: [13.456424, 32.938293]
-    max: [496.14618, 510.9579]
-  action:
-    min: [12.0, 25.0]
-    max: [511.0, 511.0]
-
 policy:
   name: diffusion
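Unlike the image stats, the state and action overrides carry min/max rather than mean/std; bounds like these are typically used to rescale values into [-1, 1]. A rough sketch under that assumption (illustrative only, not lerobot's normalization code):

```python
import torch

# Action bounds copied from the config above (PushT actions are 2D).
action_min = torch.tensor([12.0, 25.0])
action_max = torch.tensor([511.0, 511.0])

def to_unit_range(x: torch.Tensor) -> torch.Tensor:
    """Rescale raw values into [-1, 1] using the dataset min/max bounds."""
    return (x - action_min) / (action_max - action_min) * 2.0 - 1.0

print(to_unit_range(torch.tensor([12.0, 25.0])))    # tensor([-1., -1.])
print(to_unit_range(torch.tensor([511.0, 511.0])))  # tensor([1., 1.])
```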

lerobot/configs/policy/tdmpc.yaml

Lines changed: 1 addition & 1 deletion

@@ -1,7 +1,7 @@
 # @package _global_

 seed: 1
-dataset_repo_id: lerobot/xarm_lift_medium_replay
+dataset_repo_id: lerobot/xarm_lift_medium

 training:
   offline_steps: 25000
8 binary files changed (not shown).

tests/test_policies.py

Lines changed: 1 addition & 1 deletion

@@ -237,7 +237,7 @@ def test_normalize(insert_temporal_dim):
 @pytest.mark.parametrize(
     "env_name, policy_name, extra_overrides",
     [
-        # ("xarm", "tdmpc", ["policy.n_action_repeats=2"]),
+        ("xarm", "tdmpc", []),
         (
             "pusht",
             "diffusion",

tests/test_utils.py

Lines changed: 38 additions & 0 deletions

@@ -0,0 +1,38 @@
+import random
+from typing import Callable
+
+import numpy as np
+import pytest
+import torch
+
+from lerobot.common.utils.utils import seeded_context, set_global_seed
+
+
+@pytest.mark.parametrize(
+    "rand_fn",
+    [
+        random.random,
+        np.random.random,
+        lambda: torch.rand(1).item(),
+    ]
+    + [lambda: torch.rand(1, device="cuda")]
+    if torch.cuda.is_available()
+    else [],
+)
+def test_seeding(rand_fn: Callable[[], int]):
+    set_global_seed(0)
+    a = rand_fn()
+    with seeded_context(1337):
+        c = rand_fn()
+    b = rand_fn()
+    set_global_seed(0)
+    a_ = rand_fn()
+    b_ = rand_fn()
+    # Check that `set_global_seed` lets us reproduce a and b.
+    assert a_ == a
+    # Additionally, check that the `seeded_context` didn't interrupt the global RNG.
+    assert b_ == b
+    set_global_seed(1337)
+    c_ = rand_fn()
+    # Check that `seeded_context` and `global_seed` give the same reproducibility.
+    assert c_ == c
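These checks can be run locally with `pytest tests/test_utils.py`; on a machine with CUDA available, the parametrization additionally covers the CUDA RNG path.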
