Skip to content

Commit f473df8

Browse files
blaz-rsamet-akcay
andauthored
Add check before loading metrics data from checkpoint (#2323)
Add check before loading from checkpoint Signed-off-by: Blaz Rolih <[email protected]> Co-authored-by: Samet Akcay <[email protected]>
1 parent 983ec58 commit f473df8

File tree

1 file changed

+10
-11
lines changed

1 file changed

+10
-11
lines changed

src/anomalib/models/components/base/anomaly_module.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -168,20 +168,19 @@ def load_state_dict(self, state_dict: OrderedDict[str, Any], strict: bool = True
168168
if "pixel_threshold_class" in state_dict:
169169
self.pixel_threshold = self._get_instance(state_dict, "pixel_threshold_class")
170170

171-
if "anomaly_maps_normalization_class" in state_dict:
172-
self.anomaly_maps_normalization_metrics = self._get_instance(state_dict, "anomaly_maps_normalization_class")
173-
if "box_scores_normalization_class" in state_dict:
174-
self.box_scores_normalization_metrics = self._get_instance(state_dict, "box_scores_normalization_class")
171+
# check only for pred score normalization metrics, because if this one is present, all others are too
175172
if "pred_scores_normalization_class" in state_dict:
173+
self.box_scores_normalization_metrics = self._get_instance(state_dict, "box_scores_normalization_class")
174+
self.anomaly_maps_normalization_metrics = self._get_instance(state_dict, "anomaly_maps_normalization_class")
176175
self.pred_scores_normalization_metrics = self._get_instance(state_dict, "pred_scores_normalization_class")
177176

178-
self.normalization_metrics = MetricCollection(
179-
{
180-
"anomaly_maps": self.anomaly_maps_normalization_metrics,
181-
"box_scores": self.box_scores_normalization_metrics,
182-
"pred_scores": self.pred_scores_normalization_metrics,
183-
},
184-
)
177+
self.normalization_metrics = MetricCollection(
178+
{
179+
"anomaly_maps": self.anomaly_maps_normalization_metrics,
180+
"box_scores": self.box_scores_normalization_metrics,
181+
"pred_scores": self.pred_scores_normalization_metrics,
182+
},
183+
)
185184
# Used to load metrics if there is any related data in state_dict
186185
self._load_metrics(state_dict)
187186

0 commit comments

Comments
 (0)