Skip to content

Be more clear about Spandrel model nomenclature and types #14477

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Dec 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions extensions-builtin/SwinIR/scripts/swinir_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,18 +71,18 @@ def load_model(self, path, scale=4):
else:
filename = path

model = modelloader.load_spandrel_model(
model_descriptor = modelloader.load_spandrel_model(
filename,
device=self._get_device(),
dtype=devices.dtype,
expected_architecture="SwinIR",
)
if getattr(opts, 'SWIN_torch_compile', False):
try:
model = torch.compile(model)
model_descriptor.model.compile()
except Exception:
logger.warning("Failed to compile SwinIR model, fallback to JIT", exc_info=True)
return model
return model_descriptor

def _get_device(self):
return devices.get_device_for('swinir')
Expand Down
10 changes: 6 additions & 4 deletions modules/gfpgan_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import logging
import os

import torch

from modules import (
devices,
errors,
Expand All @@ -25,7 +27,7 @@ def name(self):
def get_device(self):
return devices.device_gfpgan

def load_net(self) -> None:
def load_net(self) -> torch.Module:
for model_path in modelloader.load_models(
model_path=self.model_path,
model_url=model_url,
Expand All @@ -34,13 +36,13 @@ def load_net(self) -> None:
ext_filter=['.pth'],
):
if 'GFPGAN' in os.path.basename(model_path):
net = modelloader.load_spandrel_model(
model = modelloader.load_spandrel_model(
model_path,
device=self.get_device(),
expected_architecture='GFPGAN',
).model
net.different_w = True # see https://github.com/chaiNNer-org/spandrel/pull/81
return net
model.different_w = True # see https://github.com/chaiNNer-org/spandrel/pull/81
return model
raise ValueError("No GFPGAN model found")

def restore(self, np_image):
Expand Down
25 changes: 14 additions & 11 deletions modules/modelloader.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
from __future__ import annotations

import importlib
import logging
import os
import importlib
from typing import TYPE_CHECKING
from urllib.parse import urlparse

import torch

from modules import shared
from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, UpscalerNone

if TYPE_CHECKING:
import spandrel

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -140,19 +143,19 @@ def load_spandrel_model(
*,
device: str | torch.device | None,
half: bool = False,
dtype: str | None = None,
dtype: str | torch.dtype | None = None,
expected_architecture: str | None = None,
):
) -> spandrel.ModelDescriptor:
import spandrel
model = spandrel.ModelLoader(device=device).load_from_file(path)
if expected_architecture and model.architecture != expected_architecture:
model_descriptor = spandrel.ModelLoader(device=device).load_from_file(path)
if expected_architecture and model_descriptor.architecture != expected_architecture:
logger.warning(
f"Model {path!r} is not a {expected_architecture!r} model (got {model.architecture!r})",
f"Model {path!r} is not a {expected_architecture!r} model (got {model_descriptor.architecture!r})",
)
if half:
model = model.model.half()
model_descriptor.model.half()
if dtype:
model = model.model.to(dtype=dtype)
model.eval()
logger.debug("Loaded %s from %s (device=%s, half=%s, dtype=%s)", model, path, device, half, dtype)
return model
model_descriptor.model.to(dtype=dtype)
model_descriptor.model.eval()
logger.debug("Loaded %s from %s (device=%s, half=%s, dtype=%s)", model_descriptor, path, device, half, dtype)
return model_descriptor
4 changes: 2 additions & 2 deletions modules/realesrgan_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,14 @@ def do_upscale(self, img, path):
errors.report(f"Unable to load RealESRGAN model {path}", exc_info=True)
return img

mod = modelloader.load_spandrel_model(
model_descriptor = modelloader.load_spandrel_model(
info.local_data_path,
device=self.device,
half=(not cmd_opts.no_half and not cmd_opts.upcast_sampling),
expected_architecture="ESRGAN", # "RealESRGAN" isn't a specific thing for Spandrel
)
return upscale_with_model(
mod,
model_descriptor,
img,
tile_size=opts.ESRGAN_tile,
tile_overlap=opts.ESRGAN_tile_overlap,
Expand Down
2 changes: 1 addition & 1 deletion modules/upscaler_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import tqdm
from PIL import Image

from modules import devices, images
from modules import images

logger = logging.getLogger(__name__)

Expand Down