Commit 660acf1

🔨 Scaffold for refactor (#2340)
* initial scaffold
* Apply PR comments
* rename dir

Signed-off-by: Ashwin Vaidya <[email protected]>
1 parent: 21287ee

File tree: 9 files changed, +277 −88 lines

pyproject.toml

Lines changed: 2 additions & 1 deletion

@@ -56,6 +56,7 @@ core = [
     "open-clip-torch>=2.23.0,<2.26.1",
 ]
 openvino = ["openvino>=2024.0", "nncf>=2.10.0", "onnx>=1.16.0"]
+vlm = ["ollama", "transformers"]
 loggers = [
     "comet-ml>=3.31.7",
     "gradio>=4",
@@ -84,7 +85,7 @@ test = [
     "coverage[toml]",
     "tox",
 ]
-full = ["anomalib[core,openvino,loggers,notebooks]"]
+full = ["anomalib[core,openvino,loggers,notebooks, vlm]"]
 dev = ["anomalib[full,docs,test]"]

 [project.scripts]
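With the new extra in place, the VLM dependencies can be installed on their own (pip install "anomalib[vlm]", using standard pip extras syntax) or via the full bundle, which now pulls them in as well.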

src/anomalib/models/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -34,6 +34,7 @@
     Rkde,
     Stfpm,
     Uflow,
+    VlmAd,
     WinClip,
 )
 from .video import AiVad
@@ -62,6 +63,7 @@ class UnknownModelError(ModuleNotFoundError):
     "Stfpm",
     "Uflow",
     "AiVad",
+    "VlmAd",
     "WinClip",
     "Llm",
     "Llmollama",

src/anomalib/models/image/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -24,6 +24,7 @@
 from .rkde import Rkde
 from .stfpm import Stfpm
 from .uflow import Uflow
+from .vlm import VlmAd
 from .winclip import WinClip

 __all__ = [
@@ -44,6 +45,7 @@
     "Rkde",
     "Stfpm",
     "Uflow",
+    "VlmAd",
     "WinClip",
     "Llm",
     "Llmollama",
src/anomalib/models/image/vlm/__init__.py

Lines changed: 8 additions & 0 deletions

@@ -0,0 +1,8 @@
+"""Visual Anomaly Model."""
+
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+from .lightning_model import VlmAd
+
+__all__ = ["VlmAd"]
src/anomalib/models/image/vlm/backends/__init__.py

Lines changed: 9 additions & 0 deletions

@@ -0,0 +1,9 @@
+"""VLM backends."""
+
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+from .base import Backend
+from .ollama import Ollama
+
+__all__ = ["Backend", "Ollama"]
src/anomalib/models/image/vlm/backends/base.py

Lines changed: 23 additions & 0 deletions

@@ -0,0 +1,23 @@
+"""Base backend."""
+
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+from abc import ABC, abstractmethod
+from pathlib import Path
+
+
+class Backend(ABC):
+    """Base backend."""
+
+    @abstractmethod
+    def __init__(self, api_key: str | None = None) -> None:
+        """Initialize the backend."""
+
+    @abstractmethod
+    def add_reference_images(self, image: str | Path) -> None:
+        """Add reference images for k-shot."""
+
+    @abstractmethod
+    def predict(self, image: str | Path) -> str:
+        """Predict the anomaly label."""
src/anomalib/models/image/vlm/backends/ollama.py

Lines changed: 89 additions & 0 deletions

@@ -0,0 +1,89 @@
+"""Ollama backend.
+
+Assumes that the Ollama service is running in the background.
+See: https://github.com/ollama/ollama
+Ensure that ollama is running. On linux: `ollama serve`
+"""
+
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import logging
+from dataclasses import dataclass
+from pathlib import Path
+
+from anomalib.utils.exceptions import try_import
+
+from .base import Backend
+
+if try_import("ollama"):
+    from ollama import chat
+    from ollama._client import _encode_image
+else:
+    chat = None
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class Prompt:
+    """Ollama prompt."""
+
+    few_shot: str
+    predict: str
+
+
+class Ollama(Backend):
+    """Ollama backend."""
+
+    def __init__(self, api_key: str | None = None, model_name: str = "llava") -> None:
+        """Initialize the Ollama backend."""
+        if api_key:
+            logger.warning("API key is not required for Ollama backend.")
+        self.model_name: str = model_name
+        self._ref_images_encoded: list[str] = []
+
+    def add_reference_images(self, image: str | Path) -> None:
+        """Encode the image to base64."""
+        self._ref_images_encoded.append(_encode_image(image))
+
+    @property
+    def prompt(self) -> Prompt:
+        """Get the Ollama prompt."""
+        return Prompt(
+            predict=(
+                "You are given an image. It is either normal or anomalous."
+                "First say 'YES' if the image is anomalous, or 'NO' if it is normal.\n"
+                "Then give the reason for your decision.\n"
+                "For example, 'YES: The image has a crack on the wall.'"
+            ),
+            few_shot=(
+                "These are a few examples of normal picture without any anomalies."
+                " You have to use these to determine if the image I provide in the next"
+                " chat is normal or anomalous."
+            ),
+        )
+
+    def predict(self, image: str | Path) -> str:
+        """Predict the anomaly label."""
+        if not chat:
+            msg = "Ollama is not installed. Please install it using `pip install ollama`."
+            raise ImportError(msg)
+        image_encoded = _encode_image(image)
+        messages = []
+
+        # few-shot
+        if len(self._ref_images_encoded) > 0:
+            messages.append({
+                "role": "user",
+                "images": self._ref_images_encoded,
+                "content": self.prompt.few_shot,
+            })
+
+        messages.append({"role": "user", "images": [image_encoded], "content": self.prompt.predict})
+
+        response = chat(
+            model=self.model_name,
+            messages=messages,
+        )
+        return response["message"]["content"].strip()
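The backend can also be exercised on its own. A usage sketch, assuming the ollama package is installed, a local Ollama server is running (ollama serve), and the llava model has been pulled (ollama pull llava); the image paths are placeholders:

from anomalib.models.image.vlm.backends import Ollama  # module path inferred from this diff

backend = Ollama(model_name="llava")

# Optional few-shot context: encode a handful of known-good images first.
backend.add_reference_images("path/to/normal/000.png")
backend.add_reference_images("path/to/normal/001.png")

# Returns the raw model text, e.g. "YES: The image has a crack on the wall."
print(backend.predict("path/to/query.png"))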
src/anomalib/models/image/vlm/lightning_model.py

Lines changed: 88 additions & 0 deletions

@@ -0,0 +1,88 @@
+"""Visual Anomaly Model for Zero/Few-Shot Anomaly Classification."""
+
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import logging
+from enum import Enum
+
+import torch
+from torch.utils.data import DataLoader
+
+from anomalib import LearningType
+from anomalib.models import AnomalyModule
+
+from .backends import Backend, Ollama
+
+logger = logging.getLogger(__name__)
+
+
+class VlmAdBackend(Enum):
+    """Supported VLM backends."""
+
+    OLLAMA = "ollama"
+
+
+class VlmAd(AnomalyModule):
+    """Visual anomaly model."""
+
+    def __init__(
+        self,
+        backend: VlmAdBackend | str = VlmAdBackend.OLLAMA,
+        api_key: str | None = None,
+        k_shot: int = 3,
+    ) -> None:
+        super().__init__()
+        self.k_shot = k_shot
+        backend = VlmAdBackend(backend)
+        self.vlm_backend: Backend = self._setup_vlm(backend, api_key)
+
+    @staticmethod
+    def _setup_vlm(backend: VlmAdBackend, api_key: str | None) -> Backend:
+        match backend:
+            case VlmAdBackend.OLLAMA:
+                return Ollama()
+            case _:
+                msg = f"Unsupported VLM backend: {backend}"
+                raise ValueError(msg)
+
+    def _setup(self) -> None:
+        if self.k_shot:
+            logger.info("Collecting reference images from training dataset.")
+            dataloader = self.trainer.datamodule.train_dataloader()
+            self.collect_reference_images(dataloader)
+
+    def collect_reference_images(self, dataloader: DataLoader) -> None:
+        """Collect reference images for few-shot inference."""
+        count = 0
+        for batch in dataloader:
+            for img_path in batch["image_path"]:
+                self.vlm_backend.add_reference_images(img_path)
+                count += 1
+                if count == self.k_shot:
+                    return
+
+    def validation_step(self, batch: dict[str, str | torch.Tensor], *args, **kwargs) -> dict:
+        """Validation step."""
+        del args, kwargs  # These variables are not used.
+        responses = [(self.vlm_backend.predict(img_path)) for img_path in batch["image_path"]]
+
+        batch["str_output"] = responses
+        batch["pred_scores"] = torch.tensor([1.0 if r.startswith("Y") else 0.0 for r in responses], device=self.device)
+        return batch
+
+    @property
+    def learning_type(self) -> LearningType:
+        """The learning type of the model."""
+        return LearningType.ZERO_SHOT if self.k_shot == 0 else LearningType.FEW_SHOT
+
+    @property
+    def trainer_arguments(self) -> dict[str, int | float]:
+        """Doesn't need training."""
+        return {}
+
+    @staticmethod
+    def configure_transforms(image_size: tuple[int, int] | None = None) -> None:
+        """This model does not require any transforms."""
+        if image_size is not None:
+            logger.warning("Ignoring image_size argument as each backend has its own transforms.")
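Taken together, the scaffold wires the new model into the usual anomalib workflow: there is no training step, _setup() collects k_shot reference images from the train dataloader, and each validation image is sent to the VLM and scored 1.0 when the reply starts with "Y". A hypothetical end-to-end sketch, not part of this commit; it assumes anomalib's standard Engine and MVTec datamodule APIs and the same running Ollama server as above:

from anomalib.data import MVTec
from anomalib.engine import Engine
from anomalib.models import VlmAd

datamodule = MVTec(category="bottle")
model = VlmAd(backend="ollama", k_shot=3)  # k_shot=0 would select zero-shot mode

# Validation drives the whole pipeline, since the model needs no fit() loop.
Engine().validate(model=model, datamodule=datamodule)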
