@@ -280,6 +280,12 @@ def log_validations(
280
280
)
281
281
for text_encoder in prompt_handler .text_encoders :
282
282
text_encoder .to ("cpu" )
283
+ current_validation_pooled_embeds .to (
284
+ device = accelerator .device , dtype = weight_dtype
285
+ )
286
+ validation_negative_pooled_embeds .to (
287
+ device = accelerator .device , dtype = weight_dtype
288
+ )
283
289
elif StateTracker .get_model_type () == "legacy" :
284
290
validation_negative_pooled_embeds = None
285
291
current_validation_pooled_embeds = None
@@ -307,6 +313,12 @@ def log_validations(
307
313
for text_encoder in prompt_handler .text_encoders :
308
314
if text_encoder :
309
315
text_encoder .to (accelerator .device )
316
+ current_validation_prompt_embeds .to (
317
+ device = accelerator .device , dtype = weight_dtype
318
+ )
319
+ validation_negative_prompt_embeds .to (
320
+ device = accelerator .device , dtype = weight_dtype
321
+ )
310
322
311
323
# logger.debug(
312
324
# f"Generating validation image: {validation_prompt}"
@@ -342,22 +354,10 @@ def log_validations(
342
354
# )
343
355
validation_images .extend (
344
356
pipeline (
345
- prompt_embeds = current_validation_prompt_embeds .to (
346
- device = accelerator .device ,
347
- dtype = prompt_handler .text_encoders [0 ].dtype ,
348
- ),
349
- pooled_prompt_embeds = current_validation_pooled_embeds .to (
350
- device = accelerator .device ,
351
- dtype = prompt_handler .text_encoders [0 ].dtype ,
352
- ),
353
- negative_prompt_embeds = validation_negative_prompt_embeds .to (
354
- device = accelerator .device ,
355
- dtype = prompt_handler .text_encoders [0 ].dtype ,
356
- ),
357
- negative_pooled_prompt_embeds = validation_negative_pooled_embeds .to (
358
- device = accelerator .device ,
359
- dtype = prompt_handler .text_encoders [0 ].dtype ,
360
- ),
357
+ prompt_embeds = current_validation_prompt_embeds ,
358
+ pooled_prompt_embeds = current_validation_pooled_embeds ,
359
+ negative_prompt_embeds = validation_negative_prompt_embeds ,
360
+ negative_pooled_prompt_embeds = validation_negative_pooled_embeds ,
361
361
num_images_per_prompt = args .num_validation_images ,
362
362
num_inference_steps = args .validation_num_inference_steps ,
363
363
guidance_scale = args .validation_guidance ,
0 commit comments