Add LLaVA OneVision model support #7693

Merged · 13 commits · Mar 18, 2025
2 changes: 2 additions & 0 deletions invokeai/app/invocations/fields.py
@@ -59,6 +59,7 @@ class UIType(str, Enum, metaclass=MetaEnum):
ControlLoRAModel = "ControlLoRAModelField"
SigLipModel = "SigLipModelField"
FluxReduxModel = "FluxReduxModelField"
LlavaOnevisionModel = "LLaVAModelField"
# endregion

# region Misc Field Types
@@ -205,6 +206,7 @@ class FieldDescriptions:
freeu_b2 = "Scaling factor for stage 2 to amplify the contributions of backbone features."
instantx_control_mode = "The control mode for InstantX ControlNet union models. Ignored for other ControlNet models. The standard mapping is: canny (0), tile (1), depth (2), blur (3), pose (4), gray (5), low quality (6). Negative values will be treated as 'None'."
flux_redux_conditioning = "FLUX Redux conditioning tensor"
vllm_model = "The VLLM model to use"


class ImageField(BaseModel):
60 changes: 60 additions & 0 deletions invokeai/app/invocations/llava_onevision_vllm.py
@@ -0,0 +1,60 @@
from typing import Any

import torch
from PIL.Image import Image
from pydantic import field_validator

from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, InputField, UIComponent, UIType
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.invocations.primitives import StringOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.llava_onevision_model import LlavaOnevisionModel
from invokeai.backend.util.devices import TorchDevice


@invocation("llava_onevision_vllm", title="LLaVA OneVision VLLM", tags=["vllm"], category="vllm", version="1.0.0")
class LlavaOnevisionVllmInvocation(BaseInvocation):
"""Run a LLaVA OneVision VLLM model."""

images: list[ImageField] | ImageField | None = InputField(default=None, max_length=3, description="Input image.")
prompt: str = InputField(
default="",
description="Input text prompt.",
ui_component=UIComponent.Textarea,
)
vllm_model: ModelIdentifierField = InputField(
title="LLaVA Model Type",
description=FieldDescriptions.vllm_model,
ui_type=UIType.LlavaOnevisionModel,
)

@field_validator("images", mode="before")
def listify_images(cls, v: Any) -> list:
if v is None:
return v
if not isinstance(v, list):
return [v]
return v

def _get_images(self, context: InvocationContext) -> list[Image]:
if self.images is None:
return []

image_fields = self.images if isinstance(self.images, list) else [self.images]
return [context.images.get_pil(image_field.image_name, "RGB") for image_field in image_fields]

@torch.no_grad()
def invoke(self, context: InvocationContext) -> StringOutput:
images = self._get_images(context)

with context.models.load(self.vllm_model) as vllm_model:
assert isinstance(vllm_model, LlavaOnevisionModel)
output = vllm_model.run(
prompt=self.prompt,
images=images,
device=TorchDevice.choose_torch_device(),
dtype=TorchDevice.choose_torch_dtype(),
)

return StringOutput(value=output)
49 changes: 49 additions & 0 deletions invokeai/backend/llava_onevision_model.py
@@ -0,0 +1,49 @@
from pathlib import Path
from typing import Optional

import torch
from PIL.Image import Image
from transformers import AutoProcessor, LlavaOnevisionForConditionalGeneration, LlavaOnevisionProcessor

from invokeai.backend.raw_model import RawModel


class LlavaOnevisionModel(RawModel):
def __init__(self, vllm_model: LlavaOnevisionForConditionalGeneration, processor: LlavaOnevisionProcessor):
self._vllm_model = vllm_model
self._processor = processor

@classmethod
def load_from_path(cls, path: str | Path):
vllm_model = LlavaOnevisionForConditionalGeneration.from_pretrained(path, local_files_only=True)
assert isinstance(vllm_model, LlavaOnevisionForConditionalGeneration)
processor = AutoProcessor.from_pretrained(path, local_files_only=True)
assert isinstance(processor, LlavaOnevisionProcessor)
return cls(vllm_model, processor)

def run(self, prompt: str, images: list[Image], device: torch.device, dtype: torch.dtype) -> str:
# TODO(ryand): Tune the max number of images that are useful for the model.
if len(images) > 3:
raise ValueError(
f"{len(images)} images were provided as input to the LLaVA OneVision model. "
"Pass <=3 images for good performance."
)

# Define a chat history and use `apply_chat_template` to get correctly formatted prompt.
# "content" is a list of dicts with types "text" or "image".
content = [{"type": "text", "text": prompt}]
# Add the correct number of images.
for _ in images:
content.append({"type": "image"})

conversation = [{"role": "user", "content": content}]
prompt = self._processor.apply_chat_template(conversation, add_generation_prompt=True)
inputs = self._processor(images=images or None, text=prompt, return_tensors="pt").to(device=device, dtype=dtype)
output = self._vllm_model.generate(**inputs, max_new_tokens=400, do_sample=False)
output_str: str = self._processor.decode(output[0][2:], skip_special_tokens=True)
# The output_str will include the prompt, so we extract the response.
response = output_str.split("assistant\n", 1)[1].strip()
return response

def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
self._vllm_model.to(device=device, dtype=dtype)
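
For context, a minimal usage sketch of the new wrapper (not part of the diff). The local path is hypothetical and stands in for a downloaded snapshot of the `llava-hf/llava-onevision-qwen2-0.5b-ov-hf` starter model; CUDA and bfloat16 are just example device/dtype choices:

```python
from pathlib import Path

import torch
from PIL import Image

from invokeai.backend.llava_onevision_model import LlavaOnevisionModel

# Hypothetical local snapshot of llava-hf/llava-onevision-qwen2-0.5b-ov-hf.
model_path = Path("/models/llava-onevision-qwen2-0.5b-ov-hf")

model = LlavaOnevisionModel.load_from_path(model_path)
model.to(device=torch.device("cuda"), dtype=torch.bfloat16)

image = Image.open("photo.png").convert("RGB")

# run() accepts at most 3 images and returns only the assistant's response text.
caption = model.run(
    prompt="Describe this image in one sentence.",
    images=[image],
    device=torch.device("cuda"),
    dtype=torch.bfloat16,
)
print(caption)
```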
13 changes: 13 additions & 0 deletions invokeai/backend/model_manager/config.py
Expand Up @@ -78,6 +78,7 @@ class ModelType(str, Enum):
SpandrelImageToImage = "spandrel_image_to_image"
SigLIP = "siglip"
FluxRedux = "flux_redux"
LlavaOnevision = "llava_onevision"


class SubModelType(str, Enum):
@@ -552,6 +553,17 @@ def get_tag() -> Tag:
return Tag(f"{ModelType.FluxRedux.value}.{ModelFormat.Checkpoint.value}")


class LlavaOnevisionConfig(DiffusersConfigBase):
"""Model config for Llava Onevision models."""

type: Literal[ModelType.LlavaOnevision] = ModelType.LlavaOnevision
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers

@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.LlavaOnevision.value}.{ModelFormat.Diffusers.value}")


def get_model_discriminator_value(v: Any) -> str:
"""
Computes the discriminator value for a model config.
@@ -601,6 +613,7 @@ def get_model_discriminator_value(v: Any) -> str:
Annotated[CLIPGEmbedDiffusersConfig, CLIPGEmbedDiffusersConfig.get_tag()],
Annotated[SigLIPConfig, SigLIPConfig.get_tag()],
Annotated[FluxReduxConfig, FluxReduxConfig.get_tag()],
Annotated[LlavaOnevisionConfig, LlavaOnevisionConfig.get_tag()],
],
Discriminator(get_model_discriminator_value),
]
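
For reference, a small sketch (not part of the diff) of the discriminator value this registration produces, assuming `ModelFormat.Diffusers.value` is `"diffusers"` as for the other diffusers-format configs:

```python
from invokeai.backend.model_manager.config import ModelFormat, ModelType

# get_tag() combines model type and format, so LLaVA OneVision configs are
# discriminated by the tag "llava_onevision.diffusers".
tag = f"{ModelType.LlavaOnevision.value}.{ModelFormat.Diffusers.value}"
assert tag == "llava_onevision.diffusers"
```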
@@ -0,0 +1,32 @@
from pathlib import Path
from typing import Optional

from invokeai.backend.llava_onevision_model import LlavaOnevisionModel
from invokeai.backend.model_manager.config import (
AnyModel,
AnyModelConfig,
BaseModelType,
ModelFormat,
ModelType,
SubModelType,
)
from invokeai.backend.model_manager.load.load_default import ModelLoader
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry


@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.LlavaOnevision, format=ModelFormat.Diffusers)
class LlavaOnevisionModelLoader(ModelLoader):
"""Class for loading LLaVA Onevision VLLM models."""

def _load_model(
self,
config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if submodel_type is not None:
raise ValueError("Unexpected submodel requested for LLaVA OneVision model.")

model_path = Path(config.path)
model = LlavaOnevisionModel.load_from_path(model_path)
model.to(dtype=self._torch_dtype)
return model
13 changes: 13 additions & 0 deletions invokeai/backend/model_manager/probe.py
@@ -141,6 +141,7 @@ class ModelProbe(object):
"SD3Transformer2DModel": ModelType.Main,
"CLIPTextModelWithProjection": ModelType.CLIPEmbed,
"SiglipModel": ModelType.SigLIP,
"LlavaOnevisionForConditionalGeneration": ModelType.LlavaOnevision,
}

TYPE2VARIANT: Dict[ModelType, Callable[[str], Optional[AnyVariant]]] = {ModelType.CLIPEmbed: get_clip_variant_type}
@@ -767,6 +768,11 @@ def get_base_type(self) -> BaseModelType:
return BaseModelType.Flux


class LlavaOnevisionCheckpointProbe(CheckpointProbeBase):
def get_base_type(self) -> BaseModelType:
raise NotImplementedError()


########################################################
# classes for probing folders
#######################################################
@@ -1047,6 +1053,11 @@ def get_base_type(self) -> BaseModelType:
raise NotImplementedError()


class LlavaOnevisionFolderProbe(FolderProbeBase):
def get_base_type(self) -> BaseModelType:
return BaseModelType.Any


class T2IAdapterFolderProbe(FolderProbeBase):
def get_base_type(self) -> BaseModelType:
config_file = self.model_path / "config.json"
@@ -1082,6 +1093,7 @@ def get_base_type(self) -> BaseModelType:
ModelProbe.register_probe("diffusers", ModelType.SpandrelImageToImage, SpandrelImageToImageFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.SigLIP, SigLIPFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.FluxRedux, FluxReduxFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.LlavaOnevision, LlaveOnevisionFolderProbe)

ModelProbe.register_probe("checkpoint", ModelType.Main, PipelineCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.VAE, VaeCheckpointProbe)
@@ -1095,5 +1107,6 @@ def get_base_type(self) -> BaseModelType:
ModelProbe.register_probe("checkpoint", ModelType.SpandrelImageToImage, SpandrelImageToImageCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.SigLIP, SigLIPCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.FluxRedux, FluxReduxCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.LlavaOnevision, LlavaOnevisionCheckpointProbe)

ModelProbe.register_probe("onnx", ModelType.ONNX, ONNXFolderProbe)
11 changes: 11 additions & 0 deletions invokeai/backend/model_manager/starter_models.py
@@ -614,6 +614,16 @@ class StarterModelBundles(BaseModel):
)
# endregion

# region LlavaOnevisionModel
llava_onevision = StarterModel(
name="LLaVA Onevision Qwen2 0.5B",
base=BaseModelType.Any,
source="llava-hf/llava-onevision-qwen2-0.5b-ov-hf",
description="LLaVA Onevision VLLM model",
type=ModelType.LlavaOnevision,
)
# endregion

# List of starter models, displayed on the frontend.
# The order/sort of this list is not changed by the frontend - set it how you want it here.
STARTER_MODELS: list[StarterModel] = [
Expand Down Expand Up @@ -683,6 +693,7 @@ class StarterModelBundles(BaseModel):
clip_l_encoder,
siglip,
flux_redux,
llava_onevision,
]

sd1_bundle: list[StarterModel] = [
1 change: 1 addition & 0 deletions invokeai/frontend/web/public/locales/en.json
@@ -846,6 +846,7 @@
"starterModels": "Starter Models",
"starterModelsInModelManager": "Starter Models can be found in Model Manager",
"controlLora": "Control LoRA",
"llavaOnevision": "LLaVA OneVision",
"syncModels": "Sync Models",
"textualInversions": "Textual Inversions",
"triggerPhrases": "Trigger Phrases",
@@ -16,6 +16,7 @@ import {
useEmbeddingModels,
useFluxReduxModels,
useIPAdapterModels,
useLLaVAModels,
useLoRAModels,
useMainModels,
useRefinerModels,
@@ -126,6 +127,12 @@ const ModelList = () => {
[fluxReduxModels, searchTerm, filteredModelType]
);

const [llavaOneVisionModels, { isLoading: isLoadingLlavaOneVisionModels }] = useLLaVAModels();
const filteredLlavaOneVisionModels = useMemo(
() => modelsFilter(llavaOneVisionModels, searchTerm, filteredModelType),
[llavaOneVisionModels, searchTerm, filteredModelType]
);

const totalFilteredModels = useMemo(() => {
return (
filteredMainModels.length +
@@ -236,6 +243,17 @@ const ModelList = () => {
{!isLoadingClipEmbedModels && filteredClipEmbedModels.length > 0 && (
<ModelListWrapper title={t('modelManager.clipEmbed')} modelList={filteredClipEmbedModels} key="clip-embed" />
)}

{/* LLaVA OneVision List */}
{isLoadingLlavaOneVisionModels && <FetchingModelsLoader loadingMessage="Loading LLaVA OneVision Models..." />}
{!isLoadingLlavaOneVisionModels && filteredLlavaOneVisionModels.length > 0 && (
<ModelListWrapper
title={t('modelManager.llavaOnevision')}
modelList={filteredLlavaOneVisionModels}
key="llava-onevision"
/>
)}

{/* Spandrel Image to Image List */}
{isLoadingSpandrelImageToImageModels && (
<FetchingModelsLoader loadingMessage="Loading Image-to-Image Models..." />
@@ -27,6 +27,7 @@ export const ModelTypeFilter = memo(() => {
control_lora: t('modelManager.controlLora'),
siglip: t('modelManager.siglip'),
flux_redux: t('modelManager.fluxRedux'),
llava_onevision: t('modelManager.llavaOnevision'),
}),
[t]
);
@@ -62,6 +62,8 @@ import {
isIntegerGeneratorFieldInputTemplate,
isIPAdapterModelFieldInputInstance,
isIPAdapterModelFieldInputTemplate,
isLLaVAModelFieldInputInstance,
isLLaVAModelFieldInputTemplate,
isLoRAModelFieldInputInstance,
isLoRAModelFieldInputTemplate,
isMainModelFieldInputInstance,
@@ -112,6 +114,7 @@ import FluxReduxModelFieldInputComponent from './inputs/FluxReduxModelFieldInput
import FluxVAEModelFieldInputComponent from './inputs/FluxVAEModelFieldInputComponent';
import ImageFieldInputComponent from './inputs/ImageFieldInputComponent';
import IPAdapterModelFieldInputComponent from './inputs/IPAdapterModelFieldInputComponent';
import LLaVAModelFieldInputComponent from './inputs/LLaVAModelFieldInputComponent';
import LoRAModelFieldInputComponent from './inputs/LoRAModelFieldInputComponent';
import MainModelFieldInputComponent from './inputs/MainModelFieldInputComponent';
import RefinerModelFieldInputComponent from './inputs/RefinerModelFieldInputComponent';
@@ -322,6 +325,13 @@ export const InputFieldRenderer = memo(({ nodeId, fieldName, settings }: Props)
return <ControlLoRAModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
}

if (isLLaVAModelFieldInputTemplate(template)) {
if (!isLLaVAModelFieldInputInstance(field)) {
return null;
}
return <LLaVAModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
}

if (isFluxVAEModelFieldInputTemplate(template)) {
if (!isFluxVAEModelFieldInputInstance(field)) {
return null;