Skip to content

Commit d975834

Browse files
authored
Merge pull request #332 from bghira/main
sd2x: num steps remaining fix | vaecache: exit with problematic data backend id
2 parents 5f3dda7 + b5616a7 commit d975834

File tree

2 files changed

+6
-5
lines changed

2 files changed

+6
-5
lines changed

helpers/caching/vae.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -392,7 +392,7 @@ def encode_images(self, images, filepaths, load_from_cache=True):
392392
if len(uncached_image_indices) > 0 and load_from_cache:
393393
# We wanted only uncached images. Something went wrong.
394394
raise Exception(
395-
f"Some images were not correctly cached during the VAE Cache operations. Ensure --skip_file_discovery=vae is not set.\nProblematic images: {uncached_image_paths}"
395+
f"(id={self.id}) Some images were not correctly cached during the VAE Cache operations. Ensure --skip_file_discovery=vae is not set.\nProblematic images: {uncached_image_paths}"
396396
)
397397

398398
if load_from_cache:

train_sd21.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -209,13 +209,13 @@ def main():
209209

210210
# Enable TF32 for faster training on Ampere GPUs,
211211
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
212-
if args.allow_tf32:
212+
if args.allow_tf32 and not torch.backends.mps.is_available():
213213
logger.info(
214214
"Enabling tf32 precision boost for NVIDIA devices due to --allow_tf32."
215215
)
216216
torch.backends.cuda.matmul.allow_tf32 = True
217217
torch.backends.cudnn.allow_tf32 = True
218-
else:
218+
elif torch.backends.cuda.is_available():
219219
logger.warning(
220220
"If using an Ada or Ampere NVIDIA device, --allow_tf32 could add a bit more performance."
221221
)
@@ -870,8 +870,9 @@ def main():
870870
total_steps_remaining_at_start = args.max_train_steps
871871
# We store the number of dataset resets that have occurred inside the checkpoint.
872872
if first_epoch > 1:
873-
steps_to_remove = first_epoch * num_update_steps_per_epoch
874-
total_steps_remaining_at_start -= steps_to_remove
873+
total_steps_remaining_at_start = (
874+
total_steps_remaining_at_start - resume_global_step
875+
)
875876
logger.debug(
876877
f"Resuming from epoch {first_epoch}, which leaves us with {total_steps_remaining_at_start}."
877878
)

0 commit comments

Comments (0)