
Commit 98c1475

Author: bghira

fixes for ckpt-only validations when EMA is enabled, and a fix for token length limit application for sd3/auraflow/omnigen
1 parent 8871d79 commit 98c1475

File tree

helpers/models/auraflow/model.py
helpers/models/auraflow/pipeline.py
helpers/models/hidream/model.py
helpers/training/validation.py

4 files changed: +25 -16 lines changed

helpers/models/auraflow/model.py

Lines changed: 7 additions & 5 deletions
@@ -179,13 +179,15 @@ def check_user_config(self):
                 f"{self.NAME} does not support fp8-quanto. Please use fp8-torchao or int8 precision level instead."
             )
         t5_max_length = 120
-        if (
-            self.config.tokenizer_max_length is None
-            or int(self.config.tokenizer_max_length) > t5_max_length
-        ):
+        if self.config.tokenizer_max_length is None or self.config.tokenizer_max_length == 0:
+            logger.warning(
+                f"Setting T5 XXL tokeniser max length to {t5_max_length} for {self.NAME}."
+            )
+            self.config.tokenizer_max_length = t5_max_length
+        if int(self.config.tokenizer_max_length) > t5_max_length:
             if not self.config.i_know_what_i_am_doing:
                 logger.warning(
-                    f"Updating T5 XXL tokeniser max length to {t5_max_length} for {self.NAME}."
+                    f"Overriding T5 XXL tokeniser max length to {t5_max_length} for {self.NAME} because `--i_know_what_i_am_doing` has not been set."
                 )
                 self.config.tokenizer_max_length = t5_max_length
             else:
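
For readability, here is a minimal standalone sketch of the token-limit handling this hunk introduces, assuming a config object that exposes tokenizer_max_length and i_know_what_i_am_doing; the helper name apply_t5_token_limit is illustrative and not part of the repository:

import logging

logger = logging.getLogger(__name__)


def apply_t5_token_limit(config, model_name: str, t5_max_length: int = 120) -> None:
    # Illustrative sketch of the check_user_config() change above.
    # Unset or zero values now default to the model's T5 XXL limit (120 for AuraFlow).
    if config.tokenizer_max_length is None or config.tokenizer_max_length == 0:
        logger.warning(
            f"Setting T5 XXL tokeniser max length to {t5_max_length} for {model_name}."
        )
        config.tokenizer_max_length = t5_max_length
    # Oversized values are clamped unless the user has explicitly opted out.
    if int(config.tokenizer_max_length) > t5_max_length:
        if not config.i_know_what_i_am_doing:
            logger.warning(
                f"Overriding T5 XXL tokeniser max length to {t5_max_length} for {model_name}."
            )
            config.tokenizer_max_length = t5_max_length

Splitting the old combined condition means an unset or zero value is always populated with the default, while an explicitly oversized value is only clamped when --i_know_what_i_am_doing is absent.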

helpers/models/auraflow/pipeline.py

Lines changed: 1 addition & 1 deletion
@@ -658,7 +658,7 @@ def encode_prompt(
         negative_prompt_embeds: Optional[torch.Tensor] = None,
         prompt_attention_mask: Optional[torch.Tensor] = None,
         negative_prompt_attention_mask: Optional[torch.Tensor] = None,
-        max_sequence_length: int = 256,
+        max_sequence_length: int = 120,
         lora_scale: Optional[float] = None,
     ):
         r"""

helpers/models/hidream/model.py

Lines changed: 7 additions & 6 deletions
@@ -306,14 +306,15 @@ def check_user_config(self):
             raise ValueError(
                 f"{self.NAME} does not support fp8-quanto. Please use fp8-torchao or int8 precision level instead."
             )
-        t5_max_length = 128
-        if (
-            self.config.tokenizer_max_length is None
-            or int(self.config.tokenizer_max_length) > t5_max_length
-        ):
+        if self.config.tokenizer_max_length is None or self.config.tokenizer_max_length == 0:
+            logger.warning(
+                f"Setting T5 XXL tokeniser max length to {t5_max_length} for {self.NAME}."
+            )
+            self.config.tokenizer_max_length = t5_max_length
+        if int(self.config.tokenizer_max_length) > t5_max_length:
             if not self.config.i_know_what_i_am_doing:
                 logger.warning(
-                    f"Updating T5 XXL tokeniser max length to {t5_max_length} for {self.NAME}."
+                    f"Overriding T5 XXL tokeniser max length to {t5_max_length} for {self.NAME} because `--i_know_what_i_am_doing` has not been set."
                 )
                 self.config.tokenizer_max_length = t5_max_length
             else:

helpers/training/validation.py

Lines changed: 10 additions & 4 deletions
@@ -944,7 +944,7 @@ def validate_prompt(
     ):
         """Generate validation images for a single prompt."""
         # Placeholder for actual image generation and logging
-        logger.debug(f"Validating prompt: {prompt}")
+        logger.debug(f"Validating ({validation_shortname}) prompt: {prompt}")
         # benchmarked / stitched validation images
         stitched_validation_images = {}
         # untouched / un-stitched validation images
@@ -1179,9 +1179,15 @@ def validate_prompt(
                     validation_image_results
                 )
                 if self.config.use_ema:
-                    ema_validation_images[validation_shortname].extend(
-                        ema_image_results
-                    )
+                    if validation_shortname in ema_validation_images and ema_image_results is not None:
+                        if ema_validation_images[validation_shortname] is None:
+                            # init the value
+                            ema_validation_images[validation_shortname] = []
+                        if isinstance(ema_validation_images[validation_shortname], list):
+                            # if we have a list of images, we can stitch them.
+                            ema_validation_images[validation_shortname].extend(
+                                ema_image_results
+                            )

         except Exception as e:
             import traceback
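
The guard above is the ckpt-only/EMA interaction named in the commit message: when only checkpoint (non-EMA) validations run, the EMA bucket for a prompt may be missing or None, and the old unconditional .extend() would fail (for example with a KeyError on a missing shortname). A minimal sketch of the same guard as a standalone helper; the function name accumulate_ema_results is an illustrative assumption:

def accumulate_ema_results(
    ema_validation_images: dict,
    validation_shortname: str,
    ema_image_results,
) -> None:
    # Skip entirely if the shortname was never registered or EMA produced no images.
    if validation_shortname in ema_validation_images and ema_image_results is not None:
        if ema_validation_images[validation_shortname] is None:
            # Initialise the bucket on first use.
            ema_validation_images[validation_shortname] = []
        if isinstance(ema_validation_images[validation_shortname], list):
            # Only a list can be extended and later stitched into the comparison image.
            ema_validation_images[validation_shortname].extend(ema_image_results)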
