Skip to content

Commit 324e181

Browse files
committed
Resolve conflicts
Signed-off-by: Samet Akcay <[email protected]>
2 parents 244f50b + bcc0b43 commit 324e181

31 files changed

+3806
-49
lines changed

notebooks/700_metrics/701b_aupimo_advanced_i.ipynb

+255-30
Large diffs are not rendered by default.

notebooks/700_metrics/701c_aupimo_advanced_ii.ipynb

+134-18
Large diffs are not rendered by default.

src/anomalib/models/components/base/export_mixin.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ def to_onnx(
125125
dynamic_axes = (
126126
{"input": {0: "batch_size"}, "output": {0: "batch_size"}}
127127
if input_size
128-
else {"input": {0: "batch_size", 2: "height", 3: "width"}, "output": {0: "batch_size"}}
128+
else {"input": {0: "batch_size", 2: "height", 3: "weight"}, "output": {0: "batch_size"}}
129129
)
130130
onnx_path = export_root / "model.onnx"
131131
# apply pass through the model to get the output names
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
"""Tiled ensemble pipelines."""
2+
3+
# Copyright (C) 2024 Intel Corporation
4+
# SPDX-License-Identifier: Apache-2.0
5+
6+
from .test_pipeline import EvalTiledEnsemble
7+
from .train_pipeline import TrainTiledEnsemble
8+
9+
__all__ = [
10+
"TrainTiledEnsemble",
11+
"EvalTiledEnsemble",
12+
]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
"""Tiled ensemble pipeline components."""
2+
3+
# Copyright (C) 2024 Intel Corporation
4+
# SPDX-License-Identifier: Apache-2.0
5+
6+
from .merging import MergeJobGenerator
7+
from .metrics_calculation import MetricsCalculationJobGenerator
8+
from .model_training import TrainModelJobGenerator
9+
from .normalization import NormalizationJobGenerator
10+
from .prediction import PredictJobGenerator
11+
from .smoothing import SmoothingJobGenerator
12+
from .stats_calculation import StatisticsJobGenerator
13+
from .thresholding import ThresholdingJobGenerator
14+
from .utils import NormalizationStage, PredictData, ThresholdStage
15+
from .visualization import VisualizationJobGenerator
16+
17+
__all__ = [
18+
"NormalizationStage",
19+
"ThresholdStage",
20+
"PredictData",
21+
"TrainModelJobGenerator",
22+
"PredictJobGenerator",
23+
"MergeJobGenerator",
24+
"SmoothingJobGenerator",
25+
"StatisticsJobGenerator",
26+
"NormalizationJobGenerator",
27+
"ThresholdingJobGenerator",
28+
"VisualizationJobGenerator",
29+
"MetricsCalculationJobGenerator",
30+
]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
"""Tiled ensemble - prediction merging job."""
2+
3+
# Copyright (C) 2024 Intel Corporation
4+
# SPDX-License-Identifier: Apache-2.0
5+
6+
import logging
7+
from collections.abc import Generator
8+
from typing import Any
9+
10+
from tqdm import tqdm
11+
12+
from anomalib.pipelines.components import Job, JobGenerator
13+
from anomalib.pipelines.types import GATHERED_RESULTS, RUN_RESULTS
14+
15+
from .utils.ensemble_tiling import EnsembleTiler
16+
from .utils.helper_functions import get_ensemble_tiler
17+
from .utils.prediction_data import EnsemblePredictions
18+
from .utils.prediction_merging import PredictionMergingMechanism
19+
20+
logger = logging.getLogger(__name__)
21+
22+
23+
class MergeJob(Job):
24+
"""Job for merging tile-level predictions into image-level predictions.
25+
26+
Args:
27+
predictions (EnsemblePredictions): Object containing ensemble predictions.
28+
tiler (EnsembleTiler): Ensemble tiler used for untiling.
29+
"""
30+
31+
name = "Merge"
32+
33+
def __init__(self, predictions: EnsemblePredictions, tiler: EnsembleTiler) -> None:
34+
super().__init__()
35+
self.predictions = predictions
36+
self.tiler = tiler
37+
38+
def run(self, task_id: int | None = None) -> list[Any]:
39+
"""Run merging job that merges all batches of tile-level predictions into image-level predictions.
40+
41+
Args:
42+
task_id: Not used in this case.
43+
44+
Returns:
45+
list[Any]: List of merged predictions.
46+
"""
47+
del task_id # not needed here
48+
49+
merger = PredictionMergingMechanism(self.predictions, self.tiler)
50+
51+
logger.info("Merging predictions.")
52+
53+
# merge all batches
54+
merged_predictions = [
55+
merger.merge_tile_predictions(batch_idx)
56+
for batch_idx in tqdm(range(merger.num_batches), desc="Prediction merging")
57+
]
58+
59+
return merged_predictions # noqa: RET504
60+
61+
@staticmethod
62+
def collect(results: list[RUN_RESULTS]) -> GATHERED_RESULTS:
63+
"""Nothing to collect in this job.
64+
65+
Returns:
66+
list[Any]: List of predictions.
67+
"""
68+
# take the first element as result is list of lists here
69+
return results[0]
70+
71+
@staticmethod
72+
def save(results: GATHERED_RESULTS) -> None:
73+
"""Nothing to save in this job."""
74+
75+
76+
class MergeJobGenerator(JobGenerator):
77+
"""Generate MergeJob."""
78+
79+
def __init__(self, tiling_args: dict, data_args: dict) -> None:
80+
super().__init__()
81+
self.tiling_args = tiling_args
82+
self.data_args = data_args
83+
84+
@property
85+
def job_class(self) -> type:
86+
"""Return the job class."""
87+
return MergeJob
88+
89+
def generate_jobs(
90+
self,
91+
args: dict | None = None,
92+
prev_stage_result: EnsemblePredictions | None = None,
93+
) -> Generator[MergeJob, None, None]:
94+
"""Return a generator producing a single merging job.
95+
96+
Args:
97+
args (dict): Tiled ensemble pipeline args.
98+
prev_stage_result (EnsemblePredictions): Ensemble predictions from predict step.
99+
100+
Returns:
101+
Generator[MergeJob, None, None]: MergeJob generator
102+
"""
103+
del args # args not used here
104+
105+
tiler = get_ensemble_tiler(self.tiling_args, self.data_args)
106+
if prev_stage_result is not None:
107+
yield MergeJob(prev_stage_result, tiler)
108+
else:
109+
msg = "Merging job requires tile level predictions from previous step."
110+
raise ValueError(msg)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,217 @@
1+
"""Tiled ensemble - metrics calculation job."""
2+
3+
# Copyright (C) 2024 Intel Corporation
4+
# SPDX-License-Identifier: Apache-2.0
5+
6+
import logging
7+
from collections.abc import Generator
8+
from pathlib import Path
9+
from typing import Any
10+
11+
import pandas as pd
12+
from tqdm import tqdm
13+
14+
from anomalib import TaskType
15+
from anomalib.metrics import AnomalibMetricCollection, create_metric_collection
16+
from anomalib.pipelines.components import Job, JobGenerator
17+
from anomalib.pipelines.types import GATHERED_RESULTS, PREV_STAGE_RESULT, RUN_RESULTS
18+
19+
from .utils import NormalizationStage
20+
from .utils.helper_functions import get_threshold_values
21+
22+
logger = logging.getLogger(__name__)
23+
24+
25+
class MetricsCalculationJob(Job):
26+
"""Job for image and pixel metrics calculation.
27+
28+
Args:
29+
accelerator (str): Accelerator (device) to use.
30+
predictions (list[Any]): List of batch predictions.
31+
root_dir (Path): Root directory to save checkpoints, stats and images.
32+
image_metrics (AnomalibMetricCollection): Collection of all image-level metrics.
33+
pixel_metrics (AnomalibMetricCollection): Collection of all pixel-level metrics.
34+
"""
35+
36+
name = "Metrics"
37+
38+
def __init__(
39+
self,
40+
accelerator: str,
41+
predictions: list[Any] | None,
42+
root_dir: Path,
43+
image_metrics: AnomalibMetricCollection,
44+
pixel_metrics: AnomalibMetricCollection,
45+
) -> None:
46+
super().__init__()
47+
self.accelerator = accelerator
48+
self.predictions = predictions
49+
self.root_dir = root_dir
50+
self.image_metrics = image_metrics
51+
self.pixel_metrics = pixel_metrics
52+
53+
def run(self, task_id: int | None = None) -> dict:
54+
"""Run a job that calculates image and pixel level metrics.
55+
56+
Args:
57+
task_id: Not used in this case.
58+
59+
Returns:
60+
dict[str, float]: Dictionary containing calculated metric values.
61+
"""
62+
del task_id # not needed here
63+
64+
logger.info("Starting metrics calculation.")
65+
66+
# add predicted data to metrics
67+
for data in tqdm(self.predictions, desc="Calculating metrics"):
68+
self.image_metrics.update(data["pred_scores"], data["label"].int())
69+
if "mask" in data and "anomaly_maps" in data:
70+
self.pixel_metrics.update(data["anomaly_maps"], data["mask"].int())
71+
72+
# compute all metrics on specified accelerator
73+
metrics_dict = {}
74+
for name, metric in self.image_metrics.items():
75+
metric.to(self.accelerator)
76+
metrics_dict[name] = metric.compute().item()
77+
metric.cpu()
78+
79+
if self.pixel_metrics.update_called:
80+
for name, metric in self.pixel_metrics.items():
81+
metric.to(self.accelerator)
82+
metrics_dict[name] = metric.compute().item()
83+
metric.cpu()
84+
85+
for name, value in metrics_dict.items():
86+
print(f"{name}: {value:.4f}")
87+
88+
# save path used in `save` method
89+
metrics_dict["save_path"] = self.root_dir / "metric_results.csv"
90+
91+
return metrics_dict
92+
93+
@staticmethod
94+
def collect(results: list[RUN_RESULTS]) -> GATHERED_RESULTS:
95+
"""Nothing to collect in this job.
96+
97+
Returns:
98+
list[Any]: list of predictions.
99+
"""
100+
# take the first element as result is list of dict here
101+
return results[0]
102+
103+
@staticmethod
104+
def save(results: GATHERED_RESULTS) -> None:
105+
"""Save metrics values to csv."""
106+
logger.info("Saving metrics to csv.")
107+
108+
# get and remove path from stats dict
109+
results_path: Path = results.pop("save_path")
110+
results_path.parent.mkdir(parents=True, exist_ok=True)
111+
112+
df_dict = {k: [v] for k, v in results.items()}
113+
metrics_df = pd.DataFrame(df_dict)
114+
metrics_df.to_csv(results_path, index=False)
115+
116+
117+
class MetricsCalculationJobGenerator(JobGenerator):
118+
"""Generate MetricsCalculationJob.
119+
120+
Args:
121+
root_dir (Path): Root directory to save checkpoints, stats and images.
122+
"""
123+
124+
def __init__(
125+
self,
126+
accelerator: str,
127+
root_dir: Path,
128+
task: TaskType,
129+
metrics: dict,
130+
normalization_stage: NormalizationStage,
131+
) -> None:
132+
self.accelerator = accelerator
133+
self.root_dir = root_dir
134+
self.task = task
135+
self.metrics = metrics
136+
self.normalization_stage = normalization_stage
137+
138+
@property
139+
def job_class(self) -> type:
140+
"""Return the job class."""
141+
return MetricsCalculationJob
142+
143+
def configure_ensemble_metrics(
144+
self,
145+
image_metrics: list[str] | dict[str, dict[str, Any]] | None = None,
146+
pixel_metrics: list[str] | dict[str, dict[str, Any]] | None = None,
147+
) -> tuple[AnomalibMetricCollection, AnomalibMetricCollection]:
148+
"""Configure image and pixel metrics and put them into a collection.
149+
150+
Args:
151+
image_metrics (list[str] | None): List of image-level metric names.
152+
pixel_metrics (list[str] | None): List of pixel-level metric names.
153+
154+
Returns:
155+
tuple[AnomalibMetricCollection, AnomalibMetricCollection]:
156+
Image-metrics collection and pixel-metrics collection
157+
"""
158+
image_metrics = [] if image_metrics is None else image_metrics
159+
160+
if pixel_metrics is None:
161+
pixel_metrics = []
162+
elif self.task == TaskType.CLASSIFICATION:
163+
pixel_metrics = []
164+
logger.warning(
165+
"Cannot perform pixel-level evaluation when task type is classification. "
166+
"Ignoring the following pixel-level metrics: %s",
167+
pixel_metrics,
168+
)
169+
170+
# if a single metric is passed, transform to list to fit the creation function
171+
if isinstance(image_metrics, str):
172+
image_metrics = [image_metrics]
173+
if isinstance(pixel_metrics, str):
174+
pixel_metrics = [pixel_metrics]
175+
176+
image_metrics_collection = create_metric_collection(image_metrics, "image_")
177+
pixel_metrics_collection = create_metric_collection(pixel_metrics, "pixel_")
178+
179+
return image_metrics_collection, pixel_metrics_collection
180+
181+
def generate_jobs(
182+
self,
183+
args: dict | None = None,
184+
prev_stage_result: PREV_STAGE_RESULT = None,
185+
) -> Generator[MetricsCalculationJob, None, None]:
186+
"""Make a generator that yields a single metrics calculation job.
187+
188+
Args:
189+
args: ensemble run config.
190+
prev_stage_result: ensemble predictions from previous step.
191+
192+
Returns:
193+
Generator[MetricsCalculationJob, None, None]: MetricsCalculationJob generator
194+
"""
195+
del args # args not used here
196+
197+
image_metrics_config = self.metrics.get("image", None)
198+
pixel_metrics_config = self.metrics.get("pixel", None)
199+
200+
image_threshold, pixel_threshold = get_threshold_values(self.normalization_stage, self.root_dir)
201+
202+
image_metrics, pixel_metrics = self.configure_ensemble_metrics(
203+
image_metrics=image_metrics_config,
204+
pixel_metrics=pixel_metrics_config,
205+
)
206+
207+
# set thresholds for metrics that need it
208+
image_metrics.set_threshold(image_threshold)
209+
pixel_metrics.set_threshold(pixel_threshold)
210+
211+
yield MetricsCalculationJob(
212+
accelerator=self.accelerator,
213+
predictions=prev_stage_result,
214+
root_dir=self.root_dir,
215+
image_metrics=image_metrics,
216+
pixel_metrics=pixel_metrics,
217+
)

0 commit comments

Comments
 (0)