
Commit 6131a93

DavyMorgan and yiyixuxu authored
support sd3.5 for controlnet example (#9860)
* support sd3.5 in controlnet

Co-authored-by: YiYi Xu <[email protected]>
1 parent 3cb7b86 commit 6131a93

File tree: 5 files changed (+86, -12 lines)


examples/controlnet/README_sd3.md (+27, -4)
````diff
@@ -1,6 +1,6 @@
-# ControlNet training example for Stable Diffusion 3 (SD3)
+# ControlNet training example for Stable Diffusion 3/3.5 (SD3/3.5)
 
-The `train_controlnet_sd3.py` script shows how to implement the ControlNet training procedure and adapt it for [Stable Diffusion 3](https://arxiv.org/abs/2403.03206).
+The `train_controlnet_sd3.py` script shows how to implement the ControlNet training procedure and adapt it for [Stable Diffusion 3](https://arxiv.org/abs/2403.03206) and [Stable Diffusion 3.5](https://stability.ai/news/introducing-stable-diffusion-3-5).
 
 ## Running locally with PyTorch
 
````

````diff
@@ -51,9 +51,9 @@ Please download the dataset and unzip it in the directory `fill50k` in the `exam
 
 ## Training
 
-First download the SD3 model from [Hugging Face Hub](https://huggingface.co/stabilityai/stable-diffusion-3-medium). We will use it as a base model for the ControlNet training.
+First download the SD3 model from [Hugging Face Hub](https://huggingface.co/stabilityai/stable-diffusion-3-medium-diffusers) or the SD3.5 model from [Hugging Face Hub](https://huggingface.co/stabilityai/stable-diffusion-3.5-medium). We will use it as a base model for the ControlNet training.
 > [!NOTE]
-> As the model is gated, before using it with diffusers you first need to go to the [Stable Diffusion 3 Medium Hugging Face page](https://huggingface.co/stabilityai/stable-diffusion-3-medium-diffusers), fill in the form and accept the gate. Once you are in, you need to log in so that your system knows you’ve accepted the gate. Use the command below to log in:
+> As the model is gated, before using it with diffusers you first need to go to the [Stable Diffusion 3 Medium Hugging Face page](https://huggingface.co/stabilityai/stable-diffusion-3-medium-diffusers) or the [Stable Diffusion 3.5 Medium Hugging Face page](https://huggingface.co/stabilityai/stable-diffusion-3.5-medium), fill in the form and accept the gate. Once you are in, you need to log in so that your system knows you’ve accepted the gate. Use the command below to log in:
 
 ```bash
 huggingface-cli login
````

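If you prefer authenticating from Python rather than the CLI, here is a minimal sketch using `huggingface_hub` (assuming the package is installed and you already have an access token):

```python
# Hypothetical Python alternative to `huggingface-cli login`.
from huggingface_hub import login

# Prompts for your Hugging Face token and caches it locally, so gated
# checkpoints such as SD3/SD3.5 can be downloaded in later runs.
login()
```
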
````diff
@@ -90,6 +90,8 @@ accelerate launch train_controlnet_sd3.py \
   --gradient_accumulation_steps=4
 ```
 
+To train a ControlNet model for Stable Diffusion 3.5, replace the `MODEL_DIR` with `stabilityai/stable-diffusion-3.5-medium`.
+
 To better track our training experiments, we're using flags `validation_image`, `validation_prompt`, and `validation_steps` to allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected.
 
 Our experiments were conducted on a single 40GB A100 GPU.
````

````diff
@@ -124,6 +126,8 @@ image = pipe(
 image.save("./output.png")
 ```
 
+Similarly, for SD3.5, replace the `base_model_path` with `stabilityai/stable-diffusion-3.5-medium` and the `controlnet_path` with `DavyMorgan/sd35-controlnet-out`.
+
 ## Notes
 
 ### GPU usage
````

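To make that swap concrete, here is a minimal SD3.5 inference sketch in the spirit of the README's example; the trained ControlNet checkpoint path and generation settings are illustrative assumptions, not fixed values:

```python
import torch
from diffusers import SD3ControlNetModel, StableDiffusion3ControlNetPipeline
from diffusers.utils import load_image

base_model_path = "stabilityai/stable-diffusion-3.5-medium"
controlnet_path = "DavyMorgan/sd35-controlnet-out"  # hypothetical trained checkpoint

# Load the trained ControlNet and plug it into the SD3.5 base pipeline.
controlnet = SD3ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)
pipe = StableDiffusion3ControlNetPipeline.from_pretrained(
    base_model_path, controlnet=controlnet, torch_dtype=torch.float16
)
pipe.to("cuda")

control_image = load_image(
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_1.png"
)
image = pipe(
    "pale golden rod circle with old lace background",
    control_image=control_image,
    num_inference_steps=28,
).images[0]
image.save("./output.png")
```
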
````diff
@@ -135,6 +139,8 @@ Make sure to use the right GPU when configuring the [accelerator](https://huggin
 
 ## Example results
 
+### SD3
+
 #### After 500 steps with batch size 8
 
 | | |
````

````diff
@@ -150,3 +156,20 @@ Make sure to use the right GPU when configuring the [accelerator](https://huggin
 || pale golden rod circle with old lace background |
 ![conditioning image](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_1.png) | ![pale golden rod circle with old lace background](https://huggingface.co/datasets/DavyMorgan/sd3-controlnet-results/resolve/main/step-6500.png) |
 
+### SD3.5
+
+#### After 500 steps with batch size 8
+
+| | |
+|-------------------|:---------------------------------------------------:|
+|| pale golden rod circle with old lace background |
+![conditioning image](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_1.png) | ![pale golden rod circle with old lace background](https://huggingface.co/datasets/DavyMorgan/sd3-controlnet-results/resolve/main/step-500-3.5.png) |
+
+
+#### After 3000 steps with batch size 8
+
+| | |
+|-------------------|:---------------------------------------------------:|
+|| pale golden rod circle with old lace background |
+![conditioning image](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_1.png) | ![pale golden rod circle with old lace background](https://huggingface.co/datasets/DavyMorgan/sd3-controlnet-results/resolve/main/step-3000-3.5.png) |
+
````

examples/controlnet/test_controlnet.py (+21)
````diff
@@ -138,6 +138,27 @@ def test_controlnet_sd3(self):
             self.assertTrue(os.path.isfile(os.path.join(tmpdir, "diffusion_pytorch_model.safetensors")))
 
 
+class ControlNetSD35(ExamplesTestsAccelerate):
+    def test_controlnet_sd3(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            test_args = f"""
+            examples/controlnet/train_controlnet_sd3.py
+            --pretrained_model_name_or_path=hf-internal-testing/tiny-sd35-pipe
+            --dataset_name=hf-internal-testing/fill10
+            --output_dir={tmpdir}
+            --resolution=64
+            --train_batch_size=1
+            --gradient_accumulation_steps=1
+            --controlnet_model_name_or_path=DavyMorgan/tiny-controlnet-sd35
+            --max_train_steps=4
+            --checkpointing_steps=2
+            """.split()
+
+            run_command(self._launch_args + test_args)
+
+            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "diffusion_pytorch_model.safetensors")))
+
+
 class ControlNetflux(ExamplesTestsAccelerate):
     def test_controlnet_flux(self):
         with tempfile.TemporaryDirectory() as tmpdir:
````

examples/controlnet/train_controlnet_sd3.py (+18, -2)
````diff
@@ -263,6 +263,12 @@ def parse_args(input_args=None):
         help="Path to pretrained controlnet model or model identifier from huggingface.co/models."
         " If not specified controlnet weights are initialized from unet.",
     )
+    parser.add_argument(
+        "--num_extra_conditioning_channels",
+        type=int,
+        default=0,
+        help="Number of extra conditioning channels for controlnet.",
+    )
     parser.add_argument(
         "--revision",
         type=str,
````

````diff
@@ -539,6 +545,9 @@ def parse_args(input_args=None):
         default=77,
         help="Maximum sequence length to use with the T5 text encoder",
     )
+    parser.add_argument(
+        "--dataset_preprocess_batch_size", type=int, default=1000, help="Batch size for preprocessing dataset."
+    )
     parser.add_argument(
         "--validation_prompt",
         type=str,
````

````diff
@@ -986,7 +995,9 @@ def main(args):
         controlnet = SD3ControlNetModel.from_pretrained(args.controlnet_model_name_or_path)
     else:
         logger.info("Initializing controlnet weights from transformer")
-        controlnet = SD3ControlNetModel.from_transformer(transformer)
+        controlnet = SD3ControlNetModel.from_transformer(
+            transformer, num_extra_conditioning_channels=args.num_extra_conditioning_channels
+        )
 
     transformer.requires_grad_(False)
     vae.requires_grad_(False)
````

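For context on the new flag: `num_extra_conditioning_channels` widens the ControlNet's conditioning input beyond the base latent channels (useful, for example, for inpainting-style masks). A minimal sketch, assuming an SD3-family transformer is available to copy from (the model path is illustrative):

```python
from diffusers import SD3ControlNetModel, SD3Transformer2DModel

# Load a base transformer to copy weights and configuration from.
transformer = SD3Transformer2DModel.from_pretrained(
    "stabilityai/stable-diffusion-3.5-medium", subfolder="transformer"
)

# 0 (the script's new default) keeps plain image conditioning; a positive
# value appends extra channels to the conditioning embedder.
controlnet = SD3ControlNetModel.from_transformer(
    transformer, num_extra_conditioning_channels=0
)
```
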
````diff
@@ -1123,7 +1134,12 @@ def compute_text_embeddings(batch, text_encoders, tokenizers):
         # fingerprint used by the cache for the other processes to load the result
         # details: https://github.com/huggingface/diffusers/pull/4038#discussion_r1266078401
         new_fingerprint = Hasher.hash(args)
-        train_dataset = train_dataset.map(compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint)
+        train_dataset = train_dataset.map(
+            compute_embeddings_fn,
+            batched=True,
+            batch_size=args.dataset_preprocess_batch_size,
+            new_fingerprint=new_fingerprint,
+        )
 
     del text_encoder_one, text_encoder_two, text_encoder_three
     del tokenizer_one, tokenizer_two, tokenizer_three
````

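As background on `--dataset_preprocess_batch_size`: `datasets.Dataset.map` with `batched=True` hands the mapped function whole column slices, and `batch_size` bounds each slice. A small self-contained sketch with a toy dataset and function (both stand-ins for the script's real embedding computation):

```python
from datasets import Dataset

# Toy dataset standing in for the fill50k training set.
ds = Dataset.from_dict({"text": [f"prompt {i}" for i in range(10)]})

def compute_lengths(batch):
    # `batch` is a dict of columns -> lists of up to `batch_size` items.
    return {"length": [len(t) for t in batch["text"]]}

# batch_size controls how many rows each call processes at once; the
# training script now exposes this knob as a CLI flag.
ds = ds.map(compute_lengths, batched=True, batch_size=4)
print(ds[0])  # {'text': 'prompt 0', 'length': 8}
```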

src/diffusers/models/transformers/transformer_sd3.py (-2)
````diff
@@ -11,8 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-
 from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
````

tests/pipelines/controlnet_sd3/test_controlnet_sd3.py (+20, -4)
````diff
@@ -60,7 +60,9 @@ class StableDiffusion3ControlNetPipelineFastTests(unittest.TestCase, PipelineTes
     )
     batch_params = frozenset(["prompt", "negative_prompt"])
 
-    def get_dummy_components(self, num_controlnet_layers: int = 3, qk_norm: Optional[str] = "rms_norm"):
+    def get_dummy_components(
+        self, num_controlnet_layers: int = 3, qk_norm: Optional[str] = "rms_norm", use_dual_attention=False
+    ):
         torch.manual_seed(0)
         transformer = SD3Transformer2DModel(
             sample_size=32,
````

````diff
@@ -74,6 +76,7 @@ def get_dummy_components(self, num_controlnet_layers: int = 3, qk_norm: Optional
             pooled_projection_dim=64,
             out_channels=8,
             qk_norm=qk_norm,
+            dual_attention_layers=() if not use_dual_attention else (0, 1),
         )
 
         torch.manual_seed(0)
````

````diff
@@ -88,7 +91,10 @@ def get_dummy_components(self, num_controlnet_layers: int = 3, qk_norm: Optional
             caption_projection_dim=32,
             pooled_projection_dim=64,
             out_channels=8,
+            qk_norm=qk_norm,
+            dual_attention_layers=() if not use_dual_attention else (0,),
         )
+
         clip_text_encoder_config = CLIPTextConfig(
             bos_token_id=0,
             eos_token_id=2,
````

````diff
@@ -173,8 +179,7 @@ def get_dummy_inputs(self, device, seed=0):
 
         return inputs
 
-    def test_controlnet_sd3(self):
-        components = self.get_dummy_components()
+    def run_pipe(self, components, use_sd35=False):
         sd_pipe = StableDiffusion3ControlNetPipeline(**components)
         sd_pipe = sd_pipe.to(torch_device, dtype=torch.float16)
         sd_pipe.set_progress_bar_config(disable=None)
````

````diff
@@ -187,12 +192,23 @@ def test_controlnet_sd3(self):
 
         assert image.shape == (1, 32, 32, 3)
 
-        expected_slice = np.array([0.5767, 0.7100, 0.5981, 0.5674, 0.5952, 0.4102, 0.5093, 0.5044, 0.6030])
+        if not use_sd35:
+            expected_slice = np.array([0.5767, 0.7100, 0.5981, 0.5674, 0.5952, 0.4102, 0.5093, 0.5044, 0.6030])
+        else:
+            expected_slice = np.array([1.0000, 0.9072, 0.4209, 0.2744, 0.5737, 0.3840, 0.6113, 0.6250, 0.6328])
 
         assert (
             np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
         ), f"Expected: {expected_slice}, got: {image_slice.flatten()}"
 
+    def test_controlnet_sd3(self):
+        components = self.get_dummy_components()
+        self.run_pipe(components)
+
+    def test_controlnet_sd35(self):
+        components = self.get_dummy_components(num_controlnet_layers=1, qk_norm="rms_norm", use_dual_attention=True)
+        self.run_pipe(components, use_sd35=True)
+
     @unittest.skip("xFormersAttnProcessor does not work with SD3 Joint Attention")
     def test_xformers_attention_forwardGenerator_pass(self):
         pass
````

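For readers unfamiliar with what `use_dual_attention` toggles: SD3.5 checkpoints add QK normalization and dual attention in the early transformer blocks, which is exactly what these fixtures exercise. A minimal sketch of a tiny SD3.5-flavoured ControlNet in the spirit of the test above (all sizes are toy values mirrored from the fixture, not from any released checkpoint):

```python
import torch
from diffusers import SD3ControlNetModel

torch.manual_seed(0)
# Dual attention in block 0 plus rms_norm QK normalization are the two
# architectural switches that distinguish SD3.5 from SD3 in these tests.
controlnet = SD3ControlNetModel(
    sample_size=32,
    patch_size=1,
    in_channels=8,
    num_layers=1,
    attention_head_dim=8,
    num_attention_heads=4,
    joint_attention_dim=32,
    caption_projection_dim=32,
    pooled_projection_dim=64,
    out_channels=8,
    qk_norm="rms_norm",
    dual_attention_layers=(0,),
)
```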
