
Add SD3 fine-tuning scripts #1966


Open
dsocek wants to merge 3 commits into main
Conversation

Contributor

@dsocek dsocek commented May 7, 2025

What does this PR do?

Adds Gaudi-optimized Stable Diffusion 3 and 3.5 (SD3) fine-tuning/training scripts.

Features:

  • Training with Gaudi-optimized attention using the Fused SDPA kernel (see the sketch after this list)
  • Embeddings padded to the Gaudi TPC-optimal size
  • Both LoRA and full-model fine-tuning enabled
  • Updated README with tested examples
  • Fast SD3 training CI test added
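
To illustrate the first item: a minimal sketch of routing attention through Habana's Fused SDPA kernel, which mirrors torch.nn.functional.scaled_dot_product_attention. The helper name is hypothetical and the exact wiring in the PR's scripts may differ.

import torch.nn.functional as F

try:
    # Gaudi-only fused kernel; available via habana_frameworks on HPU systems
    from habana_frameworks.torch.hpex.kernels import FusedSDPA
except ImportError:
    FusedSDPA = None  # not on a Gaudi system; fall back to stock PyTorch

def sdpa(query, key, value, attn_mask=None, dropout_p=0.0):
    # Hypothetical helper: dispatch to the fused Gaudi kernel when available.
    if FusedSDPA is not None:
        # Positional args mirror F.scaled_dot_product_attention:
        # (attn_mask, dropout_p, is_causal)
        return FusedSDPA.apply(query, key, value, attn_mask, dropout_p, False)
    return F.scaled_dot_product_attention(
        query, key, value, attn_mask=attn_mask, dropout_p=dropout_p
    )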

Example of SD3 Training:

Train images:
[Image: dog1_training_images]

Inference images after training SD3 LoRA for 1000 training steps on Gaudi:
prompt="A picture of sks dog in a bucket"
[Image: dog1_inference_g2]
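
For reference, a minimal inference sketch along these lines using optimum-habana's SD3 pipeline. The base model ID, Gaudi config name, and LoRA output path are assumptions, not taken from this PR.

from optimum.habana.diffusers import GaudiStableDiffusion3Pipeline

# Assumed base model and Gaudi config; the LoRA path is the hypothetical
# output directory of a training run like the one above.
pipe = GaudiStableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    use_habana=True,
    use_hpu_graphs=True,
    gaudi_config="Habana/stable-diffusion",
)
pipe.load_lora_weights("sd3-dreambooth-lora-output")  # hypothetical path
image = pipe(prompt="A picture of sks dog in a bucket").images[0]
image.save("sks_dog_in_bucket.png")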

Signed-off-by: Daniel Socek <[email protected]>
Co-authored-by: Deepak Gowda Doddbele Aswatha Narayana <[email protected]>
Co-authored-by: Pavel Evsikov <[email protected]>
@dsocek dsocek requested a review from regisss as a code owner May 7, 2025 00:22
Contributor

@imangohari1 imangohari1 left a comment


Hi Daniel,
Thanks for the work here. I just have some suggestions on this PR, and I am getting an error on one of the examples.
Let me know what you think.
Thanks.

import wandb

# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.29.0")
Contributor


Suggested change
check_min_version("0.29.0")
check_min_version("0.32.0")

Contributor Author


Good catch

choices=["no", "fp32", "fp16", "bf16"],
help=(
"Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
" 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32."
Contributor


We need to revisit the Nvidia Ampere GPU comment here.

Contributor Author


Nah, we should remove this option.

Contributor Author


We need to keep the option, but we removed the Ampere comment and code, and now default to bf16.
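
A sketch of how the resolved argument could look after this change; the flag name follows the upstream dreambooth scripts, and the final wording in the PR may differ.

import argparse

parser = argparse.ArgumentParser()
# Ampere-specific note dropped; default resolved to bf16 per the discussion above.
parser.add_argument(
    "--prior_generation_precision",
    type=str,
    default="bf16",
    choices=["no", "fp32", "fp16", "bf16"],
    help="Choose prior generation precision between fp32, fp16 and bf16 (bfloat16).",
)
args = parser.parse_args()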



class PromptDataset(Dataset):
"A simple dataset to prepare the prompts to generate class images on multiple GPUs."
Contributor


Suggested change
"A simple dataset to prepare the prompts to generate class images on multiple GPUs."
"A simple dataset to prepare the prompts to generate class images on multiple HPUs."



@dsocek dsocek requested a review from imangohari1 May 8, 2025 19:58
Contributor Author

dsocek commented May 8, 2025

@imangohari1 thanks for the thorough review! We fixed the issue you saw in multi-card training: we forgot to properly pass the training mode during validation. Could you do a final review?

Contributor

@imangohari1 imangohari1 left a comment


@dsocek LGTM but I think test_dreambooth_lora_sd3 should be @slow. @regisss WDYT?

Collaborator

@regisss regisss left a comment


I just left a few comments. Can you also update the table in the README to tick the training column for SD3? Here:

| Stable Diffusion 3 | | <li>Single card</li> | <li>[text-to-image generation](https://github.com/huggingface/optimum-habana/tree/main/examples/stable-diffusion#stable-diffusion-3-and-35-sd3)</li> |

Signed-off-by: Daniel Socek <[email protected]>
Contributor Author

dsocek commented May 12, 2025

@regisss thanks for the review!

Added a new test for full SD3 training and fixed the READMEs per your suggestions.

$ python -m pytest tests/test_diffusers.py -v -s -k "test_dreambooth_sd3"
...
PASSED
======================= 1 passed, 152 deselected in 125.27s (0:02:05) 

The test takes ~2 min on Gaudi2 (G2), so I left it in the fast tests category. LMK if we should move it to slow.

@dsocek dsocek requested a review from regisss May 12, 2025 17:46