@@ -483,7 +483,7 @@ def convert_vae_decoder(pipe, args):
483
483
args .latent_w or pipe .unet .config .sample_size , # w
484
484
)
485
485
486
- if args .xl_version :
486
+ if args .custom_vae_version is None and args . xl_version :
487
487
inputs_dtype = torch .float32
488
488
compute_precision = ct .precision .FLOAT32
489
489
# FIXME: Hardcoding to CPU_AND_GPU since ANE doesn't support FLOAT32
@@ -1277,11 +1277,22 @@ def get_pipeline(args):
1277
1277
model_version = args .model_version
1278
1278
1279
1279
logger .info (f"Initializing DiffusionPipeline with { model_version } .." )
1280
- pipe = DiffusionPipeline .from_pretrained (model_version ,
1280
+ if args .custom_vae_version :
1281
+ from diffusers import AutoencoderKL
1282
+ vae = AutoencoderKL .from_pretrained (args .custom_vae_version , torch_dtype = torch .float16 )
1283
+ pipe = DiffusionPipeline .from_pretrained (model_version ,
1284
+ torch_dtype = torch .float16 ,
1285
+ variant = "fp16" ,
1286
+ use_safetensors = True ,
1287
+ vae = vae ,
1288
+ use_auth_token = True )
1289
+ else :
1290
+ pipe = DiffusionPipeline .from_pretrained (model_version ,
1281
1291
torch_dtype = torch .float16 ,
1282
1292
variant = "fp16" ,
1283
1293
use_safetensors = True ,
1284
1294
use_auth_token = True )
1295
+
1285
1296
logger .info (f"Done. Pipeline in effect: { pipe .__class__ .__name__ } " )
1286
1297
1287
1298
return pipe
@@ -1395,6 +1406,15 @@ def parser_spec():
1395
1406
"If you would like to convert a refiner model on it's own, use the --model-version argument instead."
1396
1407
"For available versions: https://huggingface.co/models?sort=trending&search=stable-diffusion+refiner"
1397
1408
))
1409
+ parser .add_argument (
1410
+ "--custom-vae-version" ,
1411
+ type = str ,
1412
+ default = None ,
1413
+ help =
1414
+ ("Custom VAE checkpoint to override the pipeline's built-in VAE. "
1415
+ "If specified, the specified VAE will be converted instead of the one associated to the `--model-version` checkpoint. "
1416
+ "No precision override is applied when using a custom VAE."
1417
+ ))
1398
1418
parser .add_argument ("--compute-unit" ,
1399
1419
choices = tuple (cu
1400
1420
for cu in ct .ComputeUnit ._member_names_ ),
0 commit comments