Stripped models #7797


Merged · 16 commits · Mar 18, 2025
2 changes: 2 additions & 0 deletions invokeai/backend/model_manager/__init__.py
@@ -5,6 +5,7 @@
AnyModelConfig,
BaseModelType,
InvalidModelConfigException,
ModelConfigBase,
ModelConfigFactory,
ModelFormat,
ModelRepoVariant,
@@ -32,4 +33,5 @@
"ModelVariantType",
"SchedulerPredictionType",
"SubModelType",
"ModelConfigBase",
]
38 changes: 33 additions & 5 deletions invokeai/backend/model_manager/config.py
@@ -25,23 +25,26 @@
import time
from abc import ABC, abstractmethod
from enum import Enum
from functools import cached_property
from inspect import isabstract
from pathlib import Path
from typing import ClassVar, Literal, Optional, TypeAlias, Union

import diffusers
import onnxruntime as ort
import safetensors.torch
import torch
from diffusers.models.modeling_utils import ModelMixin
from picklescan.scanner import scan_file_path
from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, TypeAdapter
from typing_extensions import Annotated, Any, Dict

from invokeai.app.util.misc import uuid_string
from invokeai.backend.model_hash.hash_validator import validate_hash
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS, ModelHash
from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader
from invokeai.backend.raw_model import RawModel
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
from invokeai.backend.util.silence_warnings import SilenceWarnings

logger = logging.getLogger(__name__)

@@ -215,12 +218,37 @@ def __init__(self, path: Path, hash_algo: HASHING_ALGORITHMS = "blake3_single"):
self.name = path.name
self.hash_algo = hash_algo

@cached_property
def hash(self):
return ModelHash(algorithm=self.hash_algo).hash(self.path)

def lazy_load_state_dict(self) -> dict[str, torch.Tensor]:
raise NotImplementedError()
def size(self):
if self.format_type == ModelFormat.Checkpoint:
return self.path.stat().st_size
return sum(file.stat().st_size for file in self.path.rglob("*"))

def component_paths(self):
if self.format_type == ModelFormat.Checkpoint:
return {self.path}
extensions = {".safetensors", ".pt", ".pth", ".ckpt", ".bin", ".gguf"}
return {f for f in self.path.rglob("*") if f.suffix in extensions}

@staticmethod
def load_state_dict(path: Path):
with SilenceWarnings():
if path.suffix.endswith((".ckpt", ".pt", ".pth", ".bin")):
scan_result = scan_file_path(path)
if scan_result.infected_files != 0 or scan_result.scan_err:
raise RuntimeError(f"The model {path.stem} is potentially infected by malware. Aborting import.")
checkpoint = torch.load(path, map_location="cpu")
elif path.suffix.endswith(".gguf"):
checkpoint = gguf_sd_loader(path, compute_dtype=torch.float32)
elif path.suffix.endswith(".safetensors"):
checkpoint = safetensors.torch.load_file(path)
else:
raise ValueError(f"Unrecognized model extension: {path.suffix}")

state_dict = checkpoint.get("state_dict", checkpoint)
return state_dict


class MatchSpeed(int, Enum):
Expand Down Expand Up @@ -343,7 +371,7 @@ def from_model_on_disk(cls, mod: ModelOnDisk, **overrides):
fields["source"] = fields.get("source") or fields["path"]
fields["source_type"] = fields.get("source_type") or ModelSourceType.Path
fields["name"] = mod.name
fields["hash"] = fields.get("hash") or mod.hash
fields["hash"] = fields.get("hash") or mod.hash()

fields.update(overrides)
return cls(**fields)
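
With these changes, ModelOnDisk exposes sizing, component discovery, and state-dict loading directly. A minimal sketch of using the new helpers follows; the model path is a placeholder assumption, not a file from this PR:

    from pathlib import Path

    from invokeai.backend.model_manager.config import ModelOnDisk

    mod = ModelOnDisk(Path("models/loras/my-lora.safetensors"))  # hypothetical path
    print(f"total size on disk: {mod.size()} bytes")
    for component in mod.component_paths():
        state_dict = ModelOnDisk.load_state_dict(component)
        print(component.name, len(state_dict), "entries")
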
4 changes: 2 additions & 2 deletions invokeai/backend/model_manager/legacy_probe.py
@@ -3,10 +3,10 @@
from pathlib import Path
from typing import Any, Callable, Dict, Literal, Optional, Union

import picklescan.scanner as pscan
import safetensors.torch
import spandrel
import torch
from picklescan.scanner import scan_file_path

import invokeai.backend.util.logging as logger
from invokeai.app.util.misc import uuid_string
@@ -483,7 +483,7 @@ def _scan_model(cls, model_name: str, checkpoint: Path) -> None:
and option to exit if an infected file is identified.
"""
# scan model
scan_result = scan_file_path(checkpoint)
scan_result = pscan.scan_file_path(checkpoint)
if scan_result.infected_files != 0:
raise Exception(f"The model {model_name} is potentially infected by malware. Aborting import.")
if scan_result.scan_err:
4 changes: 2 additions & 2 deletions invokeai/backend/model_manager/util/model_util.py
@@ -4,9 +4,9 @@
from pathlib import Path
from typing import Dict, Optional, Union

import picklescan.scanner as pscan
import safetensors
import torch
from picklescan.scanner import scan_file_path

from invokeai.backend.model_manager.config import ClipVariantType
from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader
@@ -57,7 +57,7 @@ def read_checkpoint_meta(path: Union[str, Path], scan: bool = True) -> Dict[str,
checkpoint = gguf_sd_loader(Path(path), compute_dtype=torch.float32)
else:
if scan:
scan_result = scan_file_path(path)
scan_result = pscan.scan_file_path(path)
if scan_result.infected_files != 0:
raise Exception(f"The model at {path} is potentially infected by malware. Aborting import.")
if scan_result.scan_err:
2 changes: 2 additions & 0 deletions pyproject.toml
@@ -95,6 +95,7 @@ dependencies = [
"semver~=3.0.1",
"test-tube",
"windows-curses; sys_platform=='win32'",
"humanize==4.12.1",
]

[project.optional-dependencies]
@@ -103,6 +104,7 @@ dependencies = [
"xformers>=0.0.28.post1; sys_platform!='darwin'",
# torch 2.4+cu carries its own triton dependency
]

"onnx" = ["onnxruntime"]
"onnx-cuda" = ["onnxruntime-gpu"]
"onnx-directml" = ["onnxruntime-directml"]
18 changes: 13 additions & 5 deletions scripts/probe-model.py → scripts/classify-model.py
@@ -7,7 +7,7 @@
from typing import get_args

from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS
from invokeai.backend.model_manager import InvalidModelConfigException, ModelProbe
from invokeai.backend.model_manager import InvalidModelConfigException, ModelConfigBase, ModelProbe

algos = ", ".join(set(get_args(HASHING_ALGORITHMS)))

@@ -25,9 +25,17 @@
)
args = parser.parse_args()


def classify_with_fallback(path: Path, hash_algo: HASHING_ALGORITHMS):
try:
return ModelConfigBase.classify(path, hash_algo)
except InvalidModelConfigException:
return ModelProbe.probe(path, hash_algo=hash_algo)


for path in args.model_path:
try:
info = ModelProbe.probe(path, hash_algo=args.hash_algo)
print(f"{path}:{info.model_dump_json(indent=4)}")
except InvalidModelConfigException as exc:
print(exc)
config = classify_with_fallback(path, args.hash_algo)
print(f"{path}:{config.model_dump_json(indent=4)}")
except InvalidModelConfigException as e:
print(e)
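
The renamed script now tries the new classification API first and falls back to the legacy probe. A minimal sketch of calling that API directly from Python; the model path is a placeholder assumption:

    from pathlib import Path

    from invokeai.backend.model_manager import InvalidModelConfigException, ModelConfigBase

    model_path = Path("models/checkpoints/my-model.safetensors")  # hypothetical path
    try:
        config = ModelConfigBase.classify(model_path, "blake3_single")
        print(config.model_dump_json(indent=4))
    except InvalidModelConfigException as e:
        print(f"Could not classify {model_path}: {e}")
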
115 changes: 115 additions & 0 deletions scripts/strip_models.py
@@ -0,0 +1,115 @@
"""
Usage:
strip_models.py <models_input_dir> <stripped_output_dir>

Strips tensor data from model state_dicts while preserving metadata.
Used to create lightweight models for testing model classification.

Parameters:
<models_input_dir> Directory containing original models.
<stripped_output_dir> Directory where stripped models will be saved.

Options:
-h, --help Show this help message and exit
"""

import argparse
import json
import shutil
import sys
from pathlib import Path

import humanize
import torch

from invokeai.backend.model_manager.config import ModelFormat, ModelOnDisk
from invokeai.backend.model_manager.search import ModelSearch


def strip(v):
match v:
case torch.Tensor():
return {"shape": v.shape, "dtype": str(v.dtype), "fakeTensor": True}
case dict():
return {k: strip(v) for k, v in v.items()}
case list() | tuple():
return [strip(x) for x in v]
case _:
return v


STR_TO_DTYPE = {str(dtype): dtype for dtype in torch.__dict__.values() if isinstance(dtype, torch.dtype)}


def dress(v):
match v:
case {"shape": shape, "dtype": dtype_str, "fakeTensor": True}:
dtype = STR_TO_DTYPE[dtype_str]
return torch.empty(shape, dtype=dtype)
case dict():
return {k: dress(v) for k, v in v.items()}
case list() | tuple():
return [dress(x) for x in v]
case _:
return v


def load_stripped_model(path: Path, *args, **kwargs):
with open(path, "r") as f:
contents = json.load(f)
return dress(contents)


def create_stripped_model(original_model_path: Path, stripped_model_path: Path) -> ModelOnDisk:
original = ModelOnDisk(original_model_path)
if original.format_type == ModelFormat.Checkpoint:
shutil.copy2(original.path, stripped_model_path)
else:
shutil.copytree(original.path, stripped_model_path, dirs_exist_ok=True)
stripped = ModelOnDisk(stripped_model_path)
print(f"Created clone of {original.name} at {stripped.path}")

for component_path in stripped.component_paths():
original_state_dict = ModelOnDisk.load_state_dict(component_path)
stripped_state_dict = strip(original_state_dict) # type: ignore
with open(component_path, "w") as f:
json.dump(stripped_state_dict, f, indent=4)

before_size = humanize.naturalsize(original.size())
after_size = humanize.naturalsize(stripped.size())
print(f"{original.name} before: {before_size}, after: {after_size}")

return stripped


def parse_arguments():
class Parser(argparse.ArgumentParser):
def error(self, reason):
raise ValueError(reason)

parser = Parser()
parser.add_argument("models_input_dir", type=Path)
parser.add_argument("stripped_output_dir", type=Path)

try:
args = parser.parse_args()
except ValueError as e:
print(f"Error: {e}", file=sys.stderr)
print(__doc__, file=sys.stderr)
sys.exit(2)

if not args.models_input_dir.exists():
parser.error(f"Error: Input models directory '{args.models_input_dir}' does not exist.")
if not args.models_input_dir.is_dir():
parser.error(f"Error: '{args.input_models_dir}' is not a directory.")

return args


if __name__ == "__main__":
args = parse_arguments()
model_paths = sorted(ModelSearch().search(args.models_input_dir))

for path in model_paths:
stripped_path = args.stripped_output_dir / path.name
create_stripped_model(path, stripped_path)
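
To show the round trip that strip() and dress() perform, here is a small self-contained sketch; the tensor name, shape, and dtype are made up for illustration, and it assumes scripts/ is importable, as tests/conftest.py below does:

    import json

    import torch

    from scripts.strip_models import dress, strip

    # A toy state dict standing in for a real checkpoint.
    state_dict = {"weight": torch.zeros(4, 8, dtype=torch.float16), "step": 100}

    stripped = strip(state_dict)  # tensors become {"shape": ..., "dtype": ..., "fakeTensor": True}
    serialized = json.dumps(stripped)  # tiny compared to the original weights

    restored = dress(json.loads(serialized))
    assert restored["weight"].shape == (4, 8)
    assert restored["weight"].dtype == torch.float16
    assert restored["step"] == 100
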
26 changes: 26 additions & 0 deletions tests/conftest.py
@@ -7,9 +7,14 @@
import logging
import shutil
from pathlib import Path
from types import SimpleNamespace

import picklescan.scanner
import pytest
import safetensors.torch
import torch

import invokeai.backend.quantization.gguf.loaders as gguf_loaders
from invokeai.app.services.board_image_records.board_image_records_sqlite import SqliteBoardImageRecordStorage
from invokeai.app.services.board_records.board_records_sqlite import SqliteBoardRecordStorage
from invokeai.app.services.bulk_download.bulk_download_default import BulkDownloadService
@@ -20,6 +25,7 @@
from invokeai.app.services.invocation_stats.invocation_stats_default import InvocationStatsService
from invokeai.app.services.invoker import Invoker
from invokeai.backend.util.logging import InvokeAILogger
from scripts.strip_models import load_stripped_model
from tests.backend.model_manager.model_manager_fixtures import * # noqa: F403
from tests.fixtures.sqlite_database import create_mock_sqlite_database # noqa: F401
from tests.test_nodes import TestEventService
@@ -73,3 +79,23 @@ def invokeai_root_dir(tmp_path_factory) -> Path:
temp_dir: Path = tmp_path_factory.mktemp("data") / "invokeai_root"
shutil.copytree(root_template, temp_dir)
return temp_dir


@pytest.fixture(scope="function")
def override_model_loading(monkeypatch):
"""The legacy model probe directly calls model loading functions (e.g. torch.load) and also performs file scanning
via picklescan.scanner.scan_file_path. This fixture replaces these functions with test-friendly versions for
model files that have been 'stripped' to reduce their size (see scripts/strip_models.py).

Ideally, model loading would be injected as a dependency (i.e. ModelOnDisk) - but to avoid modifying the legacy probe,
we monkeypatch as a temporary workaround until the legacy probe is fully deprecated.
"""
monkeypatch.setattr(torch, "load", load_stripped_model)
monkeypatch.setattr(safetensors.torch, "load", load_stripped_model)
monkeypatch.setattr(safetensors.torch, "load_file", load_stripped_model)
monkeypatch.setattr(gguf_loaders, "gguf_sd_loader", load_stripped_model)

def fake_scan(*args, **kwargs):
return SimpleNamespace(infected_files=0, scan_err=None)

monkeypatch.setattr(picklescan.scanner, "scan_file_path", fake_scan)
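
A hypothetical test using the fixture might look like the sketch below; the stripped model path and the assertion are illustrative assumptions, not part of this PR:

    from pathlib import Path

    from invokeai.backend.model_manager import ModelProbe

    def test_probe_stripped_model(override_model_loading):
        # With the fixture active, torch.load, the safetensors/gguf loaders, and the malware
        # scan are all patched, so a JSON "stripped" file can stand in for real weights.
        stripped_path = Path("tests/assets/stripped/example.safetensors")  # hypothetical path
        config = ModelProbe.probe(stripped_path, hash_algo="blake3_single")
        assert config is not None
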