Skip to content

Commit e710959

Browse files
authored
Fixes following #670 (#719)
1 parent 90e099b commit e710959

File tree

3 files changed

+8
-7
lines changed

3 files changed

+8
-7
lines changed

lerobot/common/policies/pi0/modeling_pi0.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -300,7 +300,7 @@ def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -
300300
self._action_queue.extend(actions.transpose(0, 1))
301301
return self._action_queue.popleft()
302302

303-
def forward(self, batch: dict[str, Tensor], noise=None, time=None) -> dict[str, Tensor]:
303+
def forward(self, batch: dict[str, Tensor], noise=None, time=None) -> tuple[Tensor, dict[str, Tensor]]:
304304
"""Do a full training forward pass to compute the loss"""
305305
if self.config.adapt_to_pi_aloha:
306306
batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT])
@@ -328,12 +328,12 @@ def forward(self, batch: dict[str, Tensor], noise=None, time=None) -> dict[str,
328328
losses = losses[:, :, : self.config.max_action_dim]
329329
loss_dict["losses_after_rm_padding"] = losses.clone()
330330

331-
loss = losses.mean()
332331
# For backward pass
333-
loss_dict["loss"] = loss
332+
loss = losses.mean()
334333
# For logging
335334
loss_dict["l2_loss"] = loss.item()
336-
return loss_dict
335+
336+
return loss, loss_dict
337337

338338
def prepare_images(self, batch):
339339
"""Apply Pi0 preprocessing to the images, like resizing to 224x224 and padding to keep aspect ratio, and

lerobot/common/utils/wandb_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ def log_policy(self, checkpoint_dir: Path):
102102
self._wandb.log_artifact(artifact)
103103

104104
def log_dict(self, d: dict, step: int, mode: str = "train"):
105-
if mode in {"train", "eval"}:
105+
if mode not in {"train", "eval"}:
106106
raise ValueError(mode)
107107

108108
for k, v in d.items():
@@ -114,7 +114,7 @@ def log_dict(self, d: dict, step: int, mode: str = "train"):
114114
self._wandb.log({f"{mode}/{k}": v}, step=step)
115115

116116
def log_video(self, video_path: str, step: int, mode: str = "train"):
117-
if mode in {"train", "eval"}:
117+
if mode not in {"train", "eval"}:
118118
raise ValueError(mode)
119119

120120
wandb_video = self._wandb.Video(video_path, fps=self.env_fps, format="mp4")

lerobot/scripts/train.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,7 @@ def train(cfg: TrainPipelineConfig):
233233
logging.info(train_tracker)
234234
if wandb_logger:
235235
wandb_log_dict = {**train_tracker.to_dict(), **output_dict}
236-
wandb_logger.log_dict(wandb_log_dict)
236+
wandb_logger.log_dict(wandb_log_dict, step)
237237
train_tracker.reset_averages()
238238

239239
if cfg.save_checkpoint and is_saving_step:
@@ -271,6 +271,7 @@ def train(cfg: TrainPipelineConfig):
271271
logging.info(eval_tracker)
272272
if wandb_logger:
273273
wandb_log_dict = {**eval_tracker.to_dict(), **eval_info}
274+
wandb_logger.log_dict(wandb_log_dict, step, mode="eval")
274275
wandb_logger.log_video(eval_info["video_paths"][0], step, mode="eval")
275276

276277
if eval_env:

0 commit comments

Comments
 (0)