@@ -168,20 +168,19 @@ def load_state_dict(self, state_dict: OrderedDict[str, Any], strict: bool = True
168
168
if "pixel_threshold_class" in state_dict :
169
169
self .pixel_threshold = self ._get_instance (state_dict , "pixel_threshold_class" )
170
170
171
- if "anomaly_maps_normalization_class" in state_dict :
172
- self .anomaly_maps_normalization_metrics = self ._get_instance (state_dict , "anomaly_maps_normalization_class" )
173
- if "box_scores_normalization_class" in state_dict :
174
- self .box_scores_normalization_metrics = self ._get_instance (state_dict , "box_scores_normalization_class" )
171
+ # check only for pred score normalization metrics, because if this one is present, all others are too
175
172
if "pred_scores_normalization_class" in state_dict :
173
+ self .box_scores_normalization_metrics = self ._get_instance (state_dict , "box_scores_normalization_class" )
174
+ self .anomaly_maps_normalization_metrics = self ._get_instance (state_dict , "anomaly_maps_normalization_class" )
176
175
self .pred_scores_normalization_metrics = self ._get_instance (state_dict , "pred_scores_normalization_class" )
177
176
178
- self .normalization_metrics = MetricCollection (
179
- {
180
- "anomaly_maps" : self .anomaly_maps_normalization_metrics ,
181
- "box_scores" : self .box_scores_normalization_metrics ,
182
- "pred_scores" : self .pred_scores_normalization_metrics ,
183
- },
184
- )
177
+ self .normalization_metrics = MetricCollection (
178
+ {
179
+ "anomaly_maps" : self .anomaly_maps_normalization_metrics ,
180
+ "box_scores" : self .box_scores_normalization_metrics ,
181
+ "pred_scores" : self .pred_scores_normalization_metrics ,
182
+ },
183
+ )
185
184
# Used to load metrics if there is any related data in state_dict
186
185
self ._load_metrics (state_dict )
187
186
0 commit comments