Skip to content

Commit 3d625ae

Browse files
Handle crop_shape=None in Diffusion Policy (#219)
1 parent e3b9f1c commit 3d625ae

File tree

3 files changed

+8
-6
lines changed

3 files changed

+8
-6
lines changed

lerobot/common/policies/diffusion/configuration_diffusion.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ def __post_init__(self):
155155
f"{self.__class__.__name__} only handles one image for now. Got image keys {image_keys}."
156156
)
157157
image_key = next(iter(image_keys))
158-
if (
158+
if self.crop_shape is not None and (
159159
self.crop_shape[0] > self.input_shapes[image_key][1]
160160
or self.crop_shape[1] > self.input_shapes[image_key][2]
161161
):

lerobot/common/policies/diffusion/modeling_diffusion.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -427,11 +427,15 @@ def __init__(self, config: DiffusionConfig):
427427
# Set up pooling and final layers.
428428
# Use a dry run to get the feature map shape.
429429
# The dummy input should take the number of image channels from `config.input_shapes` and it should
430-
# use the height and width from `config.crop_shape`.
430+
# use the height and width from `config.crop_shape` if it is provided, otherwise it should use the
431+
# height and width from `config.input_shapes`.
431432
image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
432433
assert len(image_keys) == 1
433434
image_key = image_keys[0]
434-
dummy_input = torch.zeros(size=(1, config.input_shapes[image_key][0], *config.crop_shape))
435+
dummy_input_h_w = (
436+
config.crop_shape if config.crop_shape is not None else config.input_shapes[image_key][1:]
437+
)
438+
dummy_input = torch.zeros(size=(1, config.input_shapes[image_key][0], *dummy_input_h_w))
435439
with torch.inference_mode():
436440
dummy_feature_map = self.backbone(dummy_input)
437441
feature_map_shape = tuple(dummy_feature_map.shape[1:])

poetry.lock

Lines changed: 1 addition & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)