Skip to content

Port LLaVA to new API #7817

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Mar 24, 2025
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 65 additions & 3 deletions invokeai/backend/model_manager/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
"""

# pyright: reportIncompatibleVariableOverride=false
import json
import logging
import time
from abc import ABC, abstractmethod
Expand Down Expand Up @@ -232,6 +233,23 @@ def component_paths(self):
extensions = {".safetensors", ".pt", ".pth", ".ckpt", ".bin", ".gguf"}
return {f for f in self.path.rglob("*") if f.suffix in extensions}

def repo_variant(self):
    """Infer the repo variant from the weight files below the model path.

    Returns None when the model's format type is Checkpoint. Otherwise every
    ``.safetensors`` file, followed by every ``.bin`` file, found recursively
    under ``self.path`` is inspected in order; the first file carrying a
    recognised marker decides the result. ModelRepoVariant.Default is
    returned when no file matches.
    """
    if self.format_type == ModelFormat.Checkpoint:
        return None

    candidates = [
        *self.path.glob("**/*.safetensors"),
        *self.path.glob("**/*.bin"),
    ]
    for weight_file in candidates:
        file_name = weight_file.name
        if ".fp16" in weight_file.suffixes:
            return ModelRepoVariant.FP16
        if "openvino_model" in file_name:
            return ModelRepoVariant.OpenVINO
        if "flax_model" in file_name:
            return ModelRepoVariant.Flax
        # NOTE(review): only .safetensors/.bin files are collected above, so
        # this suffix check does not look reachable - confirm the intent.
        if weight_file.suffix == ".onnx":
            return ModelRepoVariant.ONNX

    return ModelRepoVariant.Default

@staticmethod
def load_state_dict(path: Path):
with SilenceWarnings():
Expand Down Expand Up @@ -359,21 +377,43 @@ def matches(cls, mod: ModelOnDisk) -> bool:
This doesn't need to be a perfect test - the aim is to eliminate unlikely matches quickly before parsing."""
pass

@staticmethod
def cast_overrides(overrides: dict[str, Any]):
    """Convert user-supplied override values to their Enum types, in place.

    Only the keys "type", "format", "base" and "source_type" are touched,
    and only when present; every other entry is left as it is. The dict is
    mutated rather than copied, and nothing is returned.
    """
    enum_for_field = {
        "type": ModelType,
        "format": ModelFormat,
        "base": BaseModelType,
        "source_type": ModelSourceType,
    }
    for field_name, enum_cls in enum_for_field.items():
        if field_name in overrides:
            overrides[field_name] = enum_cls(overrides[field_name])

@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, **overrides):
    """Creates an instance of this config or raises InvalidModelConfigException.

    Args:
        mod: The on-disk model to build a config for.
        **overrides: User-supplied field values. The "type", "format", "base"
            and "source_type" entries are cast to their Enum types, and all
            overrides take precedence over the fields produced by `cls.parse`.

    Returns:
        An instance of `cls` built from the parsed fields, the overrides and
        the fallbacks filled in below.

    Raises:
        InvalidModelConfigException: If `cls.matches(mod)` is falsy.
    """
    if not cls.matches(mod):
        raise InvalidModelConfigException(f"Path {mod.path} does not match {cls.__name__} format")

    fields = cls.parse(mod)
    # Merge overrides up front so that each `or` fallback below only fires for
    # fields that neither the parser nor the caller supplied.
    cls.cast_overrides(overrides)
    fields.update(overrides)

    # Named model_type / model_base to avoid shadowing the `type` builtin.
    model_type = fields.get("type") or cls.model_fields["type"].default
    model_base = fields.get("base") or cls.model_fields["base"].default

    # The path always comes from the model on disk.
    fields["path"] = mod.path.as_posix()
    fields["source"] = fields.get("source") or fields["path"]
    fields["source_type"] = fields.get("source_type") or ModelSourceType.Path
    fields["name"] = name = fields.get("name") or mod.name
    fields["hash"] = fields.get("hash") or mod.hash()
    fields["key"] = fields.get("key") or uuid_string()
    fields["description"] = (
        fields.get("description") or f"{model_base.value} {model_type.value} model {name}"
    )
    fields["repo_variant"] = fields.get("repo_variant") or mod.repo_variant()

    return cls(**fields)


Expand Down Expand Up @@ -625,12 +665,34 @@ class FluxReduxConfig(LegacyProbeMixin, ModelConfigBase):
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint


class LlavaOnevisionConfig(DiffusersConfigBase, ModelConfigBase):
    """Model config for Llava Onevision models."""

    type: Literal[ModelType.LlavaOnevision] = ModelType.LlavaOnevision
    format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers

    @classmethod
    def matches(cls, mod: ModelOnDisk) -> bool:
        """Return True when the model's config.json names the Llava Onevision architecture.

        Checkpoint-format models never match. For anything else,
        ``config.json`` directly under ``mod.path`` is read; a missing file
        means no match. The model matches only when the first entry of the
        "architectures" list is "LlavaOnevisionForConditionalGeneration".
        """
        if mod.format_type == ModelFormat.Checkpoint:
            return False

        config_path = mod.path / "config.json"
        try:
            # Read as UTF-8 explicitly instead of relying on the platform's
            # default locale encoding.
            with open(config_path, "r", encoding="utf-8") as file:
                config = json.load(file)
        except FileNotFoundError:
            return False

        architectures = config.get("architectures")
        # Coerce to a real bool: a bare `architectures and ...` would return
        # None or an empty list when the key is missing or empty.
        return bool(architectures) and architectures[0] == "LlavaOnevisionForConditionalGeneration"

    @classmethod
    def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
        """Return the fixed fields for this config; `mod` is not inspected."""
        return {
            "base": BaseModelType.Any,
            "variant": ModelVariantType.Normal,
        }


def get_model_discriminator_value(v: Any) -> str:
"""
Expand Down
8 changes: 5 additions & 3 deletions tests/test_model_probe.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,22 +148,24 @@ def test_regression_against_model_probe(datadir: Path, override_model_loading):
configs_with_tests = set()
model_paths = ModelSearch().search(datadir / "stripped_models")
fake_hash = "abcdefgh" # skip hashing to make test quicker
fake_key = "123" # fixed uuid for comparison

for path in model_paths:
legacy_config = new_config = None

try:
legacy_config = ModelProbe.probe(path, {"hash": fake_hash})
legacy_config = ModelProbe.probe(path, {"hash": fake_hash, "key": fake_key})
except InvalidModelConfigException:
pass

try:
new_config = ModelConfigBase.classify(path, hash=fake_hash)
new_config = ModelConfigBase.classify(path, hash=fake_hash, key=fake_key)
except InvalidModelConfigException:
pass

if legacy_config and new_config:
assert legacy_config == new_config
assert type(legacy_config) is type(new_config)
assert legacy_config.model_dump_json() == new_config.model_dump_json()

elif legacy_config:
assert type(legacy_config) in ModelConfigBase._USING_LEGACY_PROBE
Expand Down
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Loading