πŸ”¨ Minor Refactor #2345

Merged
4 changes: 0 additions & 4 deletions src/anomalib/models/__init__.py
@@ -24,8 +24,6 @@
     Fastflow,
     Fre,
     Ganomaly,
-    Llava,
-    Llavanext,
     Padim,
     Patchcore,
     ReverseDistillation,
@@ -63,8 +61,6 @@ class UnknownModelError(ModuleNotFoundError):
"AiVad",
"VlmAd",
"WinClip",
"Llava",
"Llavanext",
]

logger = logging.getLogger(__name__)
4 changes: 0 additions & 4 deletions src/anomalib/models/image/__init__.py
@@ -14,8 +14,6 @@
 from .fastflow import Fastflow
 from .fre import Fre
 from .ganomaly import Ganomaly
-from .llava import Llava
-from .llava_next import Llavanext
 from .padim import Padim
 from .patchcore import Patchcore
 from .reverse_distillation import ReverseDistillation
@@ -45,6 +43,4 @@
"Uflow",
"VlmAd",
"WinClip",
"Llava",
"Llavanext",
]
6 changes: 4 additions & 2 deletions src/anomalib/models/image/vlm_ad/backends/base.py
@@ -6,18 +6,20 @@
 from abc import ABC, abstractmethod
 from pathlib import Path
 
+from anomalib.models.image.vlm_ad.utils import Prompt
+
 
 class Backend(ABC):
     """Base backend."""
 
     @abstractmethod
-    def __init__(self, model_name: str, api_key: str | None = None) -> None:
+    def __init__(self, model_name: str) -> None:
         """Initialize the backend."""
 
     @abstractmethod
     def add_reference_images(self, image: str | Path) -> None:
         """Add reference images for k-shot."""
 
     @abstractmethod
-    def predict(self, image: str | Path) -> str:
+    def predict(self, image: str | Path, prompt: Prompt) -> str:
         """Predict the anomaly label."""
42 changes: 11 additions & 31 deletions src/anomalib/models/image/vlm_ad/backends/chat_gpt.py
@@ -8,10 +8,10 @@
 from pathlib import Path
 from typing import TYPE_CHECKING
 
+from anomalib.models.image.vlm_ad.utils import Prompt
 from anomalib.utils.exceptions import try_import
 
 from .base import Backend
-from .dataclasses import Prompt
 
 if try_import("openai"):
     from openai import OpenAI
@@ -27,11 +27,8 @@
 class ChatGPT(Backend):
     """ChatGPT backend."""
 
-    def __init__(self, api_key: str | None = None, model_name: str = "gpt-4o-mini") -> None:
+    def __init__(self, api_key: str, model_name: str) -> None:
         """Initialize the ChatGPT backend."""
-        if api_key is None:
-            msg = "API key is required for ChatGPT backend."
-            raise ValueError(msg)
         self.api_key = api_key
         self._ref_images_encoded: list[str] = []
         self.model_name: str = model_name
@@ -51,30 +48,30 @@ def add_reference_images(self, image: str | Path) -> None:
"""Add reference images for k-shot."""
self._ref_images_encoded.append(self._encode_image_to_url(image))

def predict(self, image: str | Path) -> str:
def predict(self, image: str | Path, prompt: Prompt) -> str:
"""Predict the anomaly label."""
image_encoded = self._encode_image_to_url(image)
messages = []

# few-shot
if len(self._ref_images_encoded) > 0:
messages.append(self._generate_message(content=self.prompt.few_shot, images=self._ref_images_encoded))
messages.append(self._generate_message(content=prompt.few_shot, images=self._ref_images_encoded))

messages.append(self._generate_message(content=self.prompt.predict, images=[image_encoded]))
messages.append(self._generate_message(content=prompt.predict, images=[image_encoded]))

response: ChatCompletion = self.client.chat.completions.create(messages=messages, model=self.model_name)
return response.choices[0].message.content

@staticmethod
def _generate_message(content: str, images: list[str] | None) -> dict:
"""Generate a message."""
message = {"role": "user"}
if images is None:
message["content"] = content
message: dict[str, list[dict] | str] = {"role": "user"}
if images is not None:
_content: list[dict[str, str | dict]] = [{"type": "text", "text": content}]
_content.extend([{"type": "image_url", "image_url": {"url": image}} for image in images])
message["content"] = _content
else:
message["content"] = [{"type": "text", "text": content}]
for image in images:
message["content"].append({"type": "image_url", "image_url": {"url": image}})
message["content"] = content
return message

def _encode_image_to_url(self, image: str | Path) -> str:
@@ -89,20 +86,3 @@ def _encode_image_to_base_64(image: str | Path) -> str:
"""Encode the image to base64."""
image = Path(image)
return base64.b64encode(image.read_bytes()).decode("utf-8")

@property
def prompt(self) -> Prompt:
"""Get the Ollama prompt."""
return Prompt(
predict=(
"You are given an image. It is either normal or anomalous."
"First say 'YES' if the image is anomalous, or 'NO' if it is normal.\n"
"Then give the reason for your decision.\n"
"For example, 'YES: The image has a crack on the wall.'"
),
few_shot=(
"These are a few examples of normal picture without any anomalies."
" You have to use these to determine if the image I provide in the next"
" chat is normal or anomalous."
),
)
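
Taken together, these changes move prompt ownership out of the backend: the caller builds a `Prompt` and passes it to `predict()`, and the `api_key`/`model_name` defaults are gone. A hedged usage sketch — the key, model name, and file paths below are placeholders, and the prompt text is only adapted from the deleted property:

```python
# Assumed usage of the refactored ChatGPT backend; "sk-..." and
# "gpt-4o-mini" are illustrative values, not library defaults.
from anomalib.models.image.vlm_ad.backends.chat_gpt import ChatGPT
from anomalib.models.image.vlm_ad.utils import Prompt

backend = ChatGPT(api_key="sk-...", model_name="gpt-4o-mini")
backend.add_reference_images("good/000.png")  # optional k-shot reference

prompt = Prompt(
    few_shot="These are examples of normal images without any anomalies.",
    predict="Is the next image anomalous? Answer 'YES: <reason>' or 'NO: <reason>'.",
)
answer = backend.predict("query.png", prompt)  # e.g. "YES: the wall has a crack."
```
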
14 changes: 0 additions & 14 deletions src/anomalib/models/image/vlm_ad/backends/dataclasses.py

This file was deleted.
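(Its `Prompt` dataclass now lives in `anomalib.models.image.vlm_ad.utils`, as the updated imports in the surrounding files show; see the sketch after the `base.py` diff above.)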

59 changes: 17 additions & 42 deletions src/anomalib/models/image/vlm_ad/backends/huggingface.py
@@ -4,16 +4,15 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import logging
-from enum import Enum
 from pathlib import Path
 
 from PIL import Image
 from transformers.modeling_utils import PreTrainedModel
 
+from anomalib.models.image.vlm_ad.utils import Prompt
 from anomalib.utils.exceptions import try_import
 
 from .base import Backend
-from .dataclasses import Prompt
 
 if try_import("transformers"):
     import transformers
@@ -26,49 +25,25 @@
 logger = logging.getLogger(__name__)
 
 
-class LlavaNextModels(Enum):
-    """Available models."""
-
-    VICUNA_7B = "llava-hf/llava-v1.6-vicuna-7b-hf"
-    VICUNA_13B = "llava-hf/llava-v1.6-vicuna-13b-hf"
-    MISTRAL_7B = "llava-hf/llava-v1.6-mistral-7b-hf"
-
-
 class Huggingface(Backend):
     """Huggingface backend."""
 
     def __init__(
         self,
+        model_name: str,
-        api_key: str | None = None,
-        model_name: str | LlavaNextModels = LlavaNextModels.VICUNA_7B,
     ) -> None:
         """Initialize the Huggingface backend."""
-        if api_key:
-            logger.warning("API key is not required for Huggingface backend.")
-        self.model_name: str = LlavaNextModels(model_name).value
+        self.model_name: str = model_name
         self._ref_images: list[str] = []
         self._processor: ProcessorMixin | None = None
         self._model: PreTrainedModel | None = None
 
-    @property
-    def prompt(self) -> Prompt:
-        """Get the Ollama prompt."""
-        return Prompt(
-            predict=(
-                "You are given an image. It is either normal or anomalous."
-                " First say 'YES' if the image is anomalous, or 'NO' if it is normal.\n"
-                "Then give the reason for your decision.\n"
-                "For example, 'YES: The image has a crack on the wall.'"
-            ),
-            few_shot=(
-                "These are a few examples of normal picture without any anomalies."
-                " You have to use these to determine if the image I provide in the next"
-                " chat is normal or anomalous."
-            ),
-        )
-
     @property
     def processor(self) -> ProcessorMixin:
         """Get the Huggingface processor."""
         if self._processor is None:
             if transformers is None:
                 msg = "transformers is not installed."
@@ -78,41 +53,41 @@ def processor(self) -> ProcessorMixin:

     @property
     def model(self) -> PreTrainedModel:
         """Get the Huggingface model."""
         if self._model is None:
             if transformers is None:
                 msg = "transformers is not installed."
                 raise ValueError(msg)
-            self._model: PreTrainedModel = transformers.LlavaNextForConditionalGeneration.from_pretrained(
-                self.model_name,
-            )
+            self._model = transformers.LlavaNextForConditionalGeneration.from_pretrained(self.model_name)
         return self._model
 
     @staticmethod
     def _generate_message(content: str, images: list[str] | None) -> dict:
         """Generate a message."""
-        message = {"role": "user"}
-        message["content"] = [{"type": "text", "text": content}]
+        message: dict[str, str | list[dict]] = {"role": "user"}
+        _content: list[dict[str, str]] = [{"type": "text", "text": content}]
         if images is not None:
-            for _ in images:
-                message["content"].append({"type": "image"})
+            _content.extend([{"type": "image"} for _ in images])
+        message["content"] = _content
         return message
 
     def add_reference_images(self, image: str | Path) -> None:
         """Add reference images for k-shot."""
         self._ref_images.append(Image.open(image))
 
-    def predict(self, image_path: str | Path) -> str:
+    def predict(self, image_path: str | Path, prompt: Prompt) -> str:
         """Predict the anomaly label."""
         image = Image.open(image_path)
-        messages = []
+        messages: list[dict] = []
 
         if len(self._ref_images) > 0:
-            messages.append(self._generate_message(content=self.prompt.few_shot, images=self._ref_images))
+            messages.append(self._generate_message(content=prompt.few_shot, images=self._ref_images))
 
-        messages.append(self._generate_message(content=self.prompt.predict, images=[image]))
-        prompt = [self.processor.apply_chat_template(messages, add_generation_prompt=True)]
+        messages.append(self._generate_message(content=prompt.predict, images=[image]))
+        processed_prompt = [self.processor.apply_chat_template(messages, add_generation_prompt=True)]
 
         images = [*self._ref_images, image]
-        inputs = self.processor(images, prompt, return_tensors="pt", padding=True).to(self.model.device)
+        inputs = self.processor(images, processed_prompt, return_tensors="pt", padding=True).to(self.model.device)
         outputs = self.model.generate(**inputs, max_new_tokens=100)
         result = self.processor.decode(outputs[0], skip_special_tokens=True)
         print(result)
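
For reference, the refactored `_generate_message` above builds the list-of-dicts chat format that `apply_chat_template` consumes, with one `{"type": "image"}` slot per image. A sketch of the resulting `messages` for one reference image plus one query image (structure read off the diff; the prompt strings are placeholders):

```python
# Shape of `messages` as assembled by the refactored predict();
# the text values are placeholders, not the library's prompts.
messages = [
    {   # few-shot turn: prompt text plus one image slot per reference image
        "role": "user",
        "content": [
            {"type": "text", "text": "These are examples of normal images ..."},
            {"type": "image"},
        ],
    },
    {   # prediction turn: prompt text plus one slot for the query image
        "role": "user",
        "content": [
            {"type": "text", "text": "Is the next image anomalous? ..."},
            {"type": "image"},
        ],
    },
]
```
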
31 changes: 6 additions & 25 deletions src/anomalib/models/image/vlm_ad/backends/ollama.py
@@ -11,10 +11,10 @@
 import logging
 from pathlib import Path
 
+from anomalib.models.image.vlm_ad.utils import Prompt
 from anomalib.utils.exceptions import try_import
 
 from .base import Backend
-from .dataclasses import Prompt
 
 if try_import("ollama"):
     from ollama import chat
@@ -28,43 +28,24 @@
 class Ollama(Backend):
     """Ollama backend."""
 
-    def __init__(self, api_key: str | None = None, model_name: str = "llava") -> None:
+    def __init__(self, model_name: str) -> None:
         """Initialize the Ollama backend."""
-        if api_key:
-            logger.warning("API key is not required for Ollama backend.")
         self.model_name: str = model_name
         self._ref_images_encoded: list[str] = []
 
     def add_reference_images(self, image: str | Path) -> None:
         """Encode the image to base64."""
         self._ref_images_encoded.append(_encode_image(image))
 
-    @property
-    def prompt(self) -> Prompt:
-        """Get the Ollama prompt."""
-        return Prompt(
-            predict=(
-                "You are given an image. It is either normal or anomalous."
-                "First say 'YES' if the image is anomalous, or 'NO' if it is normal.\n"
-                "Then give the reason for your decision.\n"
-                "For example, 'YES: The image has a crack on the wall.'"
-            ),
-            few_shot=(
-                "These are a few examples of normal picture without any anomalies."
-                " You have to use these to determine if the image I provide in the next"
-                " chat is normal or anomalous."
-            ),
-        )
-
     @staticmethod
     def _generate_message(content: str, images: list[str] | None) -> dict:
         """Generate a message."""
-        message = {"role": "user", "content": content}
+        message: dict[str, str | list[str]] = {"role": "user", "content": content}
         if images:
             message["images"] = images
         return message
 
-    def predict(self, image: str | Path) -> str:
+    def predict(self, image: str | Path, prompt: Prompt) -> str:
         """Predict the anomaly label."""
         if not chat:
             msg = "Ollama is not installed. Please install it using `pip install ollama`."
@@ -74,9 +55,9 @@ def predict(self, image: str | Path) -> str:

         # few-shot
         if len(self._ref_images_encoded) > 0:
-            messages.append(self._generate_message(content=self.prompt.few_shot, images=self._ref_images_encoded))
+            messages.append(self._generate_message(content=prompt.few_shot, images=self._ref_images_encoded))
 
-        messages.append(self._generate_message(content=self.prompt.predict, images=[image_encoded]))
+        messages.append(self._generate_message(content=prompt.predict, images=[image_encoded]))
 
         response = chat(
             model=self.model_name,
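
As with the other backends, Ollama now takes `model_name` explicitly and receives the prompt per call. A hedged usage sketch (`"llava"` was the removed default and appears here only as an illustrative tag; paths are placeholders):

```python
# Assumed usage of the refactored Ollama backend; requires a local
# Ollama server with the chosen model already pulled.
from anomalib.models.image.vlm_ad.backends.ollama import Ollama
from anomalib.models.image.vlm_ad.utils import Prompt

backend = Ollama(model_name="llava")
backend.add_reference_images("good/000.png")  # optional k-shot reference

prompt = Prompt(
    few_shot="These are examples of normal images without any anomalies.",
    predict="Is the next image anomalous? Answer 'YES: <reason>' or 'NO: <reason>'.",
)
print(backend.predict("query.png", prompt))
```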