-
Notifications
You must be signed in to change notification settings - Fork 256
Add SD3 fine-tuning scripts #1966
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
Signed-off-by: Daniel Socek <[email protected]> Co-authored-by: Deepak Gowda Doddbele Aswatha Narayana <[email protected]> Co-authored-by: Pavel Evsikov <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi Daniel,
Thanks for the work here. I just have some suggestion on this PR, and am getting an error on one of the examples.
Let me know what you think.
thanks.
import wandb | ||
|
||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. | ||
check_min_version("0.29.0") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
check_min_version("0.29.0") | |
check_min_version("0.32.0") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good catch
choices=["no", "fp32", "fp16", "bf16"], | ||
help=( | ||
"Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" | ||
" 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32." |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
need to revisit the nvidia Ampre GPU comment here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nah we should remove this option
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We need to leave the option but we removed ampere comment and code and default to bf16.
|
||
|
||
class PromptDataset(Dataset): | ||
"A simple dataset to prepare the prompts to generate class images on multiple GPUs." |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
"A simple dataset to prepare the prompts to generate class images on multiple GPUs." | |
"A simple dataset to prepare the prompts to generate class images on multiple HPUs." |
|
||
|
||
class PromptDataset(Dataset): | ||
"A simple dataset to prepare the prompts to generate class images on multiple GPUs." |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
"A simple dataset to prepare the prompts to generate class images on multiple GPUs." | |
"A simple dataset to prepare the prompts to generate class images on multiple HPUs." |
@imangohari1 thanks for thorough review! We fixed the issue you saw in multi-card training, we forgot to pass training mode properly in validation. Could you help do your final review? |
Signed-off-by: Daniel Socek <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I just left a few comments;
Can you also update the table in the README to tick the training column for SD3? Here:
Line 299 in f08e27a
| Stable Diffusion 3 | | <li>Single card</li> | <li>[text-to-image generation](https://github.com/huggingface/optimum-habana/tree/main/examples/stable-diffusion#stable-diffusion-3-and-35-sd3)</li> | |
Signed-off-by: Daniel Socek <[email protected]>
@regisss thanks for review! Added new test for full sd3 training and fixed READMEs as per your suggestions. $ python -m pytest tests/test_diffusers.py -v -s -k "test_dreambooth_sd3"
...
PASSED
======================= 1 passed, 152 deselected in 125.27s (0:02:05) Test takes ~2mins on G2, I left it in fast tests category, LMK if we should move to slow |
What does this PR do?
Adds Gaudi optimized Stable Diffusion 3 and 3.5 (SD3) fine-tuning/training scripts.
Features:
Example of SD3 Training:
Train images:

Inference images after training SD3 LoRA with 1000 train steps on Gaudi

prompt="A picture of sks dog in a bucket"