Commit 90e099b

aliberts and Cadene authored
Remove offline training, refactor train.py and logging/checkpointing (#670)
Co-authored-by: Remi <[email protected]>
1 parent 334deb9 · commit 90e099b

40 files changed: +1519 −939 lines

Makefile

Lines changed: 6 additions & 33 deletions

```diff
@@ -39,8 +39,8 @@ test-act-ete-train:
 	--dataset.image_transforms.enable=true \
 	--dataset.episodes="[0]" \
 	--batch_size=2 \
-	--offline.steps=4 \
-	--online.steps=0 \
+	--steps=4 \
+	--eval_freq=2 \
 	--eval.n_episodes=1 \
 	--eval.batch_size=1 \
 	--save_freq=2 \
@@ -76,8 +76,8 @@ test-diffusion-ete-train:
 	--dataset.image_transforms.enable=true \
 	--dataset.episodes="[0]" \
 	--batch_size=2 \
-	--offline.steps=2 \
-	--online.steps=0 \
+	--steps=2 \
+	--eval_freq=2 \
 	--eval.n_episodes=1 \
 	--eval.batch_size=1 \
 	--save_checkpoint=true \
@@ -106,8 +106,8 @@ test-tdmpc-ete-train:
 	--dataset.image_transforms.enable=true \
 	--dataset.episodes="[0]" \
 	--batch_size=2 \
-	--offline.steps=2 \
-	--online.steps=0 \
+	--steps=2 \
+	--eval_freq=2 \
 	--eval.n_episodes=1 \
 	--eval.batch_size=1 \
 	--save_checkpoint=true \
@@ -126,30 +126,3 @@ test-tdmpc-ete-eval:
 	--eval.n_episodes=1 \
 	--eval.batch_size=1 \
 	--device=$(DEVICE)
-
-# TODO(rcadene): fix online buffer to storing "task"
-# test-tdmpc-ete-train-with-online:
-#	python lerobot/scripts/train.py \
-#	--policy.type=tdmpc \
-#	--env.type=pusht \
-#	--env.obs_type=environment_state_agent_pos \
-#	--env.episode_length=5 \
-#	--dataset.repo_id=lerobot/pusht_keypoints \
-#	--dataset.image_transforms.enable=true \
-#	--dataset.episodes="[0]" \
-#	--batch_size=2 \
-#	--offline.steps=2 \
-#	--online.steps=20 \
-#	--online.rollout_n_episodes=2 \
-#	--online.rollout_batch_size=2 \
-#	--online.steps_between_rollouts=10 \
-#	--online.buffer_capacity=1000 \
-#	--online.env_seed=10000 \
-#	--save_checkpoint=false \
-#	--save_freq=10 \
-#	--log_freq=1 \
-#	--eval.use_async_envs=true \
-#	--eval.n_episodes=1 \
-#	--eval.batch_size=1 \
-#	--device=$(DEVICE) \
-#	--output_dir=tests/outputs/tdmpc_online/
```

examples/3_train_policy.py

Lines changed: 1 addition & 2 deletions

```diff
@@ -86,8 +86,7 @@ def main():
     while not done:
         for batch in dataloader:
             batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
-            output_dict = policy.forward(batch)
-            loss = output_dict["loss"]
+            loss, _ = policy.forward(batch)
             loss.backward()
             optimizer.step()
             optimizer.zero_grad()
```
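This diff reflects the commit's change to the policy API: `forward` now returns a `(loss, output_dict)` tuple instead of a dict holding a `"loss"` key. A minimal runnable sketch of the new calling convention, using a hypothetical stand-in module rather than a real `lerobot` policy class:

```python
import torch

class DummyPolicy(torch.nn.Module):
    """Hypothetical stand-in for a lerobot policy: forward returns a
    (loss, output_dict) tuple, mirroring the convention in this commit."""
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 2)

    def forward(self, batch):
        pred = self.linear(batch["observation"])
        loss = torch.nn.functional.mse_loss(pred, batch["action"])
        return loss, {"pred": pred}

policy = DummyPolicy()
optimizer = torch.optim.Adam(policy.parameters(), lr=1e-3)
batch = {"observation": torch.randn(2, 4), "action": torch.randn(2, 2)}

# New convention: unpack the tuple directly instead of indexing a dict.
loss, output_dict = policy.forward(batch)
loss.backward()
optimizer.step()
optimizer.zero_grad()
```

Callers that only need the loss can discard the second element, as the updated example does with `loss, _ = policy.forward(batch)`.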

examples/4_train_policy_with_script.md

Lines changed: 13 additions & 9 deletions

````diff
@@ -161,26 +161,30 @@ python lerobot/scripts/train.py \
 ```
 You should see from the logging that your training picks up from where it left off.
 
-Another reason for which you might want to resume a run is simply to extend training and add more training steps. The number of training steps is set by the option `--offline.steps`, which is 100 000 by default.
+Another reason for which you might want to resume a run is simply to extend training and add more training steps. The number of training steps is set by the option `--steps`, which is 100 000 by default.
 You could double the number of steps of the previous run with:
 ```bash
 python lerobot/scripts/train.py \
     --config_path=outputs/train/run_resumption/checkpoints/last/pretrained_model/ \
     --resume=true \
-    --offline.steps=200000
+    --steps=200000
 ```
 
 ## Outputs of a run
 In the output directory, there will be a folder called `checkpoints` with the following structure:
 ```bash
 outputs/train/run_resumption/checkpoints
 ├── 000100  # checkpoint_dir for training step 100
-│   ├── pretrained_model
-│   │   ├── config.json  # pretrained policy config
-│   │   ├── model.safetensors  # model weights
-│   │   ├── train_config.json  # train config
-│   │   └── README.md  # model card
-│   └── training_state.pth  # optimizer/scheduler/rng state and training step
+│   ├── pretrained_model/
+│   │   ├── config.json  # policy config
+│   │   ├── model.safetensors  # policy weights
+│   │   └── train_config.json  # train config
+│   └── training_state/
+│       ├── optimizer_param_groups.json  # optimizer param groups
+│       ├── optimizer_state.safetensors  # optimizer state
+│       ├── rng_state.safetensors  # rng states
+│       ├── scheduler_state.json  # scheduler state
+│       └── training_step.json  # training step
 ├── 000200
 └── last -> 000200  # symlink to the last available checkpoint
 ```
@@ -250,7 +254,7 @@ python lerobot/scripts/train.py \
 python lerobot/scripts/train.py \
     --config_path=checkpoint/pretrained_model/ \
     --resume=true \
-    --offline.steps=200000  # <- you can change some training parameters
+    --steps=200000  # <- you can change some training parameters
 ```
 
 #### Fine-tuning
````
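The documentation change above shows the monolithic `training_state.pth` being split into per-component files, so resumption code can read each piece independently. A sketch under that assumption, recreating just the `training_step.json` part of the documented tree with the standard library (the actual `lerobot` save/load helpers are not shown here):

```python
import json
import tempfile
from pathlib import Path

# Build a minimal checkpoint layout matching the documented tree.
root = Path(tempfile.mkdtemp())
state_dir = root / "checkpoints" / "000100" / "training_state"
state_dir.mkdir(parents=True)
(state_dir / "training_step.json").write_text(json.dumps({"step": 100}))

# On resume, the step is read back from its own small JSON file rather
# than unpickled from a single training_state.pth blob.
step = json.loads((state_dir / "training_step.json").read_text())["step"]
print(step)  # → 100
```

One upside of this split: optimizer and RNG state live in `.safetensors` files while scalar metadata stays in human-readable JSON.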

examples/advanced/2_calculate_validation_loss.py

Lines changed: 2 additions & 2 deletions

```diff
@@ -75,9 +75,9 @@ def main():
     n_examples_evaluated = 0
     for batch in val_dataloader:
         batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
-        output_dict = policy.forward(batch)
+        loss, _ = policy.forward(batch)
 
-        loss_cumsum += output_dict["loss"].item()
+        loss_cumsum += loss.item()
         n_examples_evaluated += batch["index"].shape[0]
 
     # Calculate the average loss over the validation set.
```
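The averaging logic itself is unchanged by the API switch: summed per-batch losses divided by the number of examples evaluated. A torch-free sketch of that arithmetic, with made-up values standing in for `loss.item()` and `batch["index"].shape[0]`:

```python
# Hypothetical (mean batch loss, batch size) pairs standing in for
# loss.item() and batch["index"].shape[0] in the loop above.
batches = [(0.50, 32), (0.30, 32), (0.40, 16)]

loss_cumsum = 0.0
n_examples_evaluated = 0
for loss_val, batch_size in batches:
    loss_cumsum += loss_val
    n_examples_evaluated += batch_size

# Average loss over the validation set, as in the example script.
average_loss = loss_cumsum / n_examples_evaluated
print(round(average_loss, 6))  # → 0.015
```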

lerobot/common/constants.py

Lines changed: 11 additions & 0 deletions

```diff
@@ -4,3 +4,14 @@
 OBS_IMAGE = "observation.image"
 OBS_IMAGES = "observation.images"
 ACTION = "action"
+
+# files & directories
+CHECKPOINTS_DIR = "checkpoints"
+LAST_CHECKPOINT_LINK = "last"
+PRETRAINED_MODEL_DIR = "pretrained_model"
+TRAINING_STATE_DIR = "training_state"
+RNG_STATE = "rng_state.safetensors"
+TRAINING_STEP = "training_step.json"
+OPTIMIZER_STATE = "optimizer_state.safetensors"
+OPTIMIZER_PARAM_GROUPS = "optimizer_param_groups.json"
+SCHEDULER_STATE = "scheduler_state.json"
```
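These constants name the pieces of the new checkpoint layout and can be composed into paths. A sketch of how they fit together, with the constants copied inline so it runs without importing `lerobot` (the composition itself is an illustration, not the repo's actual helper code):

```python
from pathlib import Path

# Mirror the constants added in this commit.
CHECKPOINTS_DIR = "checkpoints"
LAST_CHECKPOINT_LINK = "last"
TRAINING_STATE_DIR = "training_state"
OPTIMIZER_STATE = "optimizer_state.safetensors"

output_dir = Path("outputs/train/run_resumption")
optimizer_state_path = (
    output_dir / CHECKPOINTS_DIR / LAST_CHECKPOINT_LINK
    / TRAINING_STATE_DIR / OPTIMIZER_STATE
)
# e.g. outputs/train/run_resumption/checkpoints/last/training_state/optimizer_state.safetensors
print(optimizer_state_path)
```

Centralizing the file names here means the trainer, the checkpoint loader, and the docs all agree on one spelling of each path component.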

lerobot/common/logger.py

Lines changed: 0 additions & 240 deletions
This file was deleted.
