Skip to content

Commit afdfbcd

Browse files
Merge pull request #2 from pcuenca/custom-vae-version
Allow a custom VAE to be converted.
2 parents 68c39b3 + a5f0280 commit afdfbcd

File tree

1 file changed

+22
-2
lines changed

1 file changed

+22
-2
lines changed

python_coreml_stable_diffusion/torch2coreml.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -483,7 +483,7 @@ def convert_vae_decoder(pipe, args):
483483
args.latent_w or pipe.unet.config.sample_size, # w
484484
)
485485

486-
if args.xl_version:
486+
if args.custom_vae_version is None and args.xl_version:
487487
inputs_dtype = torch.float32
488488
compute_precision = ct.precision.FLOAT32
489489
# FIXME: Hardcoding to CPU_AND_GPU since ANE doesn't support FLOAT32
@@ -1277,11 +1277,22 @@ def get_pipeline(args):
12771277
model_version = args.model_version
12781278

12791279
logger.info(f"Initializing DiffusionPipeline with {model_version}..")
1280-
pipe = DiffusionPipeline.from_pretrained(model_version,
1280+
if args.custom_vae_version:
1281+
from diffusers import AutoencoderKL
1282+
vae = AutoencoderKL.from_pretrained(args.custom_vae_version, torch_dtype=torch.float16)
1283+
pipe = DiffusionPipeline.from_pretrained(model_version,
1284+
torch_dtype=torch.float16,
1285+
variant="fp16",
1286+
use_safetensors=True,
1287+
vae=vae,
1288+
use_auth_token=True)
1289+
else:
1290+
pipe = DiffusionPipeline.from_pretrained(model_version,
12811291
torch_dtype=torch.float16,
12821292
variant="fp16",
12831293
use_safetensors=True,
12841294
use_auth_token=True)
1295+
12851296
logger.info(f"Done. Pipeline in effect: {pipe.__class__.__name__}")
12861297

12871298
return pipe
@@ -1395,6 +1406,15 @@ def parser_spec():
13951406
"If you would like to convert a refiner model on it's own, use the --model-version argument instead."
13961407
"For available versions: https://huggingface.co/models?sort=trending&search=stable-diffusion+refiner"
13971408
))
1409+
parser.add_argument(
1410+
"--custom-vae-version",
1411+
type=str,
1412+
default=None,
1413+
help=
1414+
("Custom VAE checkpoint to override the pipeline's built-in VAE. "
1415+
"If specified, the specified VAE will be converted instead of the one associated to the `--model-version` checkpoint. "
1416+
"No precision override is applied when using a custom VAE."
1417+
))
13981418
parser.add_argument("--compute-unit",
13991419
choices=tuple(cu
14001420
for cu in ct.ComputeUnit._member_names_),

0 commit comments

Comments
 (0)