Skip to content

Commit 232be1c

Browse files
committed
Refactor OverlapPatchEmbed to inherit from Convolution and streamline forward method
1 parent 1683b14 commit 232be1c

File tree

1 file changed

+4
-5
lines changed

1 file changed

+4
-5
lines changed

monai/networks/nets/restormer.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
5757
return x
5858

5959

60-
class OverlapPatchEmbed(nn.Module):
60+
class OverlapPatchEmbed(Convolution):
6161
"""Initial feature extraction using overlapped convolutions.
6262
Unlike standard patch embeddings that use non-overlapping patches,
6363
this approach maintains spatial continuity through 3x3 convolutions.
@@ -70,8 +70,7 @@ class OverlapPatchEmbed(nn.Module):
7070
"""
7171

7272
def __init__(self, spatial_dims: int, in_channels: int = 3, embed_dim: int = 48, bias: bool = False):
73-
super().__init__()
74-
self.proj = Convolution(
73+
super().__init__(
7574
spatial_dims=spatial_dims,
7675
in_channels=in_channels,
7776
out_channels=embed_dim,
@@ -82,8 +81,8 @@ def __init__(self, spatial_dims: int, in_channels: int = 3, embed_dim: int = 48,
8281
conv_only=True,
8382
)
8483

85-
def forward(self, x: torch.Tensor) -> torch.Tensor:
86-
return self.proj(x)
84+
def forward(self, x: torch.Tensor) -> torch.Tensor:
85+
return super().forward(x)
8786

8887

8988
class Restormer(nn.Module):

0 commit comments

Comments
 (0)