Skip to content

Commit 018c836

Browse files
author
bghira
committed
sd2x: fix validation error referencing unset pooled embeds
1 parent 59bf509 commit 018c836

File tree

1 file changed

+16
-16
lines changed

1 file changed

+16
-16
lines changed

helpers/legacy/validation.py

+16-16
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,12 @@ def log_validations(
280280
)
281281
for text_encoder in prompt_handler.text_encoders:
282282
text_encoder.to("cpu")
283+
current_validation_pooled_embeds.to(
284+
device=accelerator.device, dtype=weight_dtype
285+
)
286+
validation_negative_pooled_embeds.to(
287+
device=accelerator.device, dtype=weight_dtype
288+
)
283289
elif StateTracker.get_model_type() == "legacy":
284290
validation_negative_pooled_embeds = None
285291
current_validation_pooled_embeds = None
@@ -307,6 +313,12 @@ def log_validations(
307313
for text_encoder in prompt_handler.text_encoders:
308314
if text_encoder:
309315
text_encoder.to(accelerator.device)
316+
current_validation_prompt_embeds.to(
317+
device=accelerator.device, dtype=weight_dtype
318+
)
319+
validation_negative_prompt_embeds.to(
320+
device=accelerator.device, dtype=weight_dtype
321+
)
310322

311323
# logger.debug(
312324
# f"Generating validation image: {validation_prompt}"
@@ -342,22 +354,10 @@ def log_validations(
342354
# )
343355
validation_images.extend(
344356
pipeline(
345-
prompt_embeds=current_validation_prompt_embeds.to(
346-
device=accelerator.device,
347-
dtype=prompt_handler.text_encoders[0].dtype,
348-
),
349-
pooled_prompt_embeds=current_validation_pooled_embeds.to(
350-
device=accelerator.device,
351-
dtype=prompt_handler.text_encoders[0].dtype,
352-
),
353-
negative_prompt_embeds=validation_negative_prompt_embeds.to(
354-
device=accelerator.device,
355-
dtype=prompt_handler.text_encoders[0].dtype,
356-
),
357-
negative_pooled_prompt_embeds=validation_negative_pooled_embeds.to(
358-
device=accelerator.device,
359-
dtype=prompt_handler.text_encoders[0].dtype,
360-
),
357+
prompt_embeds=current_validation_prompt_embeds,
358+
pooled_prompt_embeds=current_validation_pooled_embeds,
359+
negative_prompt_embeds=validation_negative_prompt_embeds,
360+
negative_pooled_prompt_embeds=validation_negative_pooled_embeds,
361361
num_images_per_prompt=args.num_validation_images,
362362
num_inference_steps=args.validation_num_inference_steps,
363363
guidance_scale=args.validation_guidance,

0 commit comments

Comments
 (0)