Skip to content

Commit cf15cba

Browse files
Remove redundant slicing operation in Diffusion Policy (#240)
1 parent 042e193 commit cf15cba

File tree

1 file changed

+1
-3
lines changed

1 file changed

+1
-3
lines changed

lerobot/common/policies/diffusion/modeling_diffusion.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -239,10 +239,8 @@ def generate_actions(self, batch: dict[str, Tensor]) -> Tensor:
239239
global_cond = torch.cat([batch["observation.state"], img_features], dim=-1).flatten(start_dim=1)
240240

241241
# run sampling
242-
sample = self.conditional_sample(batch_size, global_cond=global_cond)
242+
actions = self.conditional_sample(batch_size, global_cond=global_cond)
243243

244-
# `horizon` steps worth of actions (from the first observation).
245-
actions = sample[..., : self.config.output_shapes["action"][0]]
246244
# Extract `n_action_steps` steps worth of actions (from the current observation).
247245
start = n_obs_steps - 1
248246
end = start + self.config.n_action_steps

0 commit comments

Comments
 (0)