Skip to content

Commit 3b86050

Browse files
throw an error if config.do_maks_loss and action_is_pad not provided in batch (#213)
Co-authored-by: Alexander Soare <[email protected]>
1 parent 6d39b73 commit 3b86050

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

lerobot/common/policies/diffusion/modeling_diffusion.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -304,7 +304,11 @@ def compute_loss(self, batch: dict[str, Tensor]) -> Tensor:
304304
loss = F.mse_loss(pred, target, reduction="none")
305305

306306
# Mask loss wherever the action is padded with copies (edges of the dataset trajectory).
307-
if self.config.do_mask_loss_for_padding and "action_is_pad" in batch:
307+
if self.config.do_mask_loss_for_padding:
308+
if "action_is_pad" not in batch:
309+
raise ValueError(
310+
f"You need to provide 'action_is_pad' in the batch when {self.config.do_mask_loss_for_padding=}."
311+
)
308312
in_episode_bound = ~batch["action_is_pad"]
309313
loss = loss * in_episode_bound.unsqueeze(-1)
310314

0 commit comments

Comments
 (0)