Skip to content

Commit 3d37124

Browse files
committed
Fix anomaly map shapes to work with tiling
Signed-off-by: blaz-r <[email protected]>
1 parent 5390d7a commit 3d37124

File tree

3 files changed

+6
-3
lines changed

3 files changed

+6
-3
lines changed

src/anomalib/models/image/padim/torch_model.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,7 @@ def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
126126
torch.Size([32, 128, 28, 28]),
127127
torch.Size([32, 256, 14, 14])]
128128
"""
129+
output_size = input_tensor.shape[-2:]
129130
if self.tiler:
130131
input_tensor = self.tiler.tile(input_tensor)
131132

@@ -143,7 +144,7 @@ def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
143144
embedding=embeddings,
144145
mean=self.gaussian.mean,
145146
inv_covariance=self.gaussian.inv_covariance,
146-
image_size=input_tensor.shape[-2:],
147+
image_size=output_size,
147148
)
148149
return output
149150

src/anomalib/models/image/patchcore/torch_model.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ def forward(self, input_tensor: torch.Tensor) -> torch.Tensor | dict[str, torch.
7070
Returns:
7171
Tensor | dict[str, torch.Tensor]: Embedding for training, anomaly map and anomaly score for testing.
7272
"""
73+
output_size = input_tensor.shape[-2:]
7374
if self.tiler:
7475
input_tensor = self.tiler.tile(input_tensor)
7576

@@ -98,7 +99,7 @@ def forward(self, input_tensor: torch.Tensor) -> torch.Tensor | dict[str, torch.
9899
# reshape to w, h
99100
patch_scores = patch_scores.reshape((batch_size, 1, width, height))
100101
# get anomaly map
101-
anomaly_map = self.anomaly_map_generator(patch_scores, input_tensor.shape[-2:])
102+
anomaly_map = self.anomaly_map_generator(patch_scores, output_size)
102103

103104
output = {"anomaly_map": anomaly_map, "pred_score": pred_score}
104105

src/anomalib/models/image/stfpm/torch_model.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ def forward(self, images: torch.Tensor) -> torch.Tensor | dict[str, torch.Tensor
6161
Returns:
6262
Teacher and student features when in training mode, otherwise the predicted anomaly maps.
6363
"""
64+
output_size = images.shape[-2:]
6465
if self.tiler:
6566
images = self.tiler.tile(images)
6667
teacher_features: dict[str, torch.Tensor] = self.teacher_model(images)
@@ -78,7 +79,7 @@ def forward(self, images: torch.Tensor) -> torch.Tensor | dict[str, torch.Tensor
7879
output = self.anomaly_map_generator(
7980
teacher_features=teacher_features,
8081
student_features=student_features,
81-
image_size=images.shape[-2:],
82+
image_size=output_size,
8283
)
8384

8485
return output

0 commit comments

Comments
 (0)