Skip to content

Commit a5f0280

Browse files
committed
Allow a custom VAE to be converted.
1 parent efda893 commit a5f0280

File tree

1 file changed

+22
-2
lines changed

1 file changed

+22
-2
lines changed

python_coreml_stable_diffusion/torch2coreml.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -483,7 +483,7 @@ def convert_vae_decoder(pipe, args):
483483
args.latent_w or pipe.unet.config.sample_size, # w
484484
)
485485

486-
if args.xl_version:
486+
if args.custom_vae_version is None and args.xl_version:
487487
inputs_dtype = torch.float32
488488
compute_precision = ct.precision.FLOAT32
489489
# FIXME: Hardcoding to CPU_AND_GPU since ANE doesn't support FLOAT32
@@ -1276,11 +1276,22 @@ def get_pipeline(args):
12761276
model_version = args.model_version
12771277

12781278
logger.info(f"Initializing DiffusionPipeline with {model_version}..")
1279-
pipe = DiffusionPipeline.from_pretrained(model_version,
1279+
if args.custom_vae_version:
1280+
from diffusers import AutoencoderKL
1281+
vae = AutoencoderKL.from_pretrained(args.custom_vae_version, torch_dtype=torch.float16)
1282+
pipe = DiffusionPipeline.from_pretrained(model_version,
1283+
torch_dtype=torch.float16,
1284+
variant="fp16",
1285+
use_safetensors=True,
1286+
vae=vae,
1287+
use_auth_token=True)
1288+
else:
1289+
pipe = DiffusionPipeline.from_pretrained(model_version,
12801290
torch_dtype=torch.float16,
12811291
variant="fp16",
12821292
use_safetensors=True,
12831293
use_auth_token=True)
1294+
12841295
logger.info(f"Done. Pipeline in effect: {pipe.__class__.__name__}")
12851296
return pipe
12861297

@@ -1392,6 +1403,15 @@ def parser_spec():
13921403
"If you would like to convert a refiner model on it's own, use the --model-version argument instead."
13931404
"For available versions: https://huggingface.co/models?sort=trending&search=stable-diffusion+refiner"
13941405
))
1406+
parser.add_argument(
1407+
"--custom-vae-version",
1408+
type=str,
1409+
default=None,
1410+
help=
1411+
("Custom VAE checkpoint to override the pipeline's built-in VAE. "
1412+
"If specified, the specified VAE will be converted instead of the one associated to the `--model-version` checkpoint. "
1413+
"No precision override is applied when using a custom VAE."
1414+
))
13951415
parser.add_argument("--compute-unit",
13961416
choices=tuple(cu
13971417
for cu in ct.ComputeUnit._member_names_),

0 commit comments

Comments
 (0)