
Commit 7a4a126

PromeAIpro, Your Name, and sayakpaul authored
fix issue that training flux controlnet was unstable and validation r… (#11373)
* fix issue that training flux controlnet was unstable and validation results were unstable

* del unused code pieces, fix grammar

---------

Co-authored-by: Your Name <[email protected]>
Co-authored-by: Sayak Paul <[email protected]>
1 parent 0dec414 commit 7a4a126

2 files changed (+18 -6 lines)

examples/controlnet/README_flux.md (+15 -2)
@@ -6,7 +6,19 @@ Training script provided by LibAI, which is an institution dedicated to the prog
 > [!NOTE]
 > **Memory consumption**
 >
-> Flux can be quite expensive to run on consumer hardware devices and as a result, ControlNet training of it comes with higher memory requirements than usual.
+> Flux can be quite expensive to run on consumer hardware devices and, as a result, ControlNet training of it comes with higher memory requirements than usual.
+
+Here is the GPU memory consumption for reference, tested on a single 80 GB A100.
+
+| stage | GPU memory |
+| - | - |
+| load as float32 | ~70 GB |
+| move transformer and VAE to bf16 | ~48 GB |
+| precompute text embeddings | ~62 GB |
+| **offload text encoders to CPU** | ~30 GB |
+| training | ~58 GB |
+| validation | ~71 GB |
+
 
 > **Gated access**
 >
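The savings in the table above come from precomputing the prompt embeddings once and then moving both text encoders off the GPU. Below is a minimal sketch of that pattern, assuming `black-forest-labs/FLUX.1-dev` and `InstantX/FLUX.1-dev-Controlnet-Canny` as stand-in checkpoints; neither model ID comes from this commit, and the variable names only loosely mirror the training script.

```python
import torch
from diffusers import FluxControlNetModel, FluxControlNetPipeline

# Load everything directly in bf16 (skipping the float32 stage from the table);
# both model IDs are illustrative stand-ins.
controlnet = FluxControlNetModel.from_pretrained(
    "InstantX/FLUX.1-dev-Controlnet-Canny", torch_dtype=torch.bfloat16
)
pipe = FluxControlNetPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", controlnet=controlnet, torch_dtype=torch.bfloat16
).to("cuda")

# Precompute the prompt embeddings once ("precompute text embeddings" row) ...
prompt = "red circle with blue background"
with torch.no_grad():
    prompt_embeds, pooled_prompt_embeds, text_ids = pipe.encode_prompt(prompt, prompt_2=prompt)

# ... then park both text encoders on the CPU ("offload text encoders to CPU" row),
# which releases their GPU memory while keeping them available for later reuse.
pipe.text_encoder.to("cpu")
pipe.text_encoder_2.to("cpu")
torch.cuda.empty_cache()
```

The same move, applied to the training script's `text_encoder_one`/`text_encoder_two`, is what the `.to("cpu")` change further down in this commit does.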
@@ -98,8 +110,9 @@ accelerate launch train_controlnet_flux.py \
     --validation_image "./conditioning_image_1.png" "./conditioning_image_2.png" \
     --validation_prompt "red circle with blue background" "cyan circle with brown floral background" \
     --train_batch_size=1 \
-    --gradient_accumulation_steps=4 \
+    --gradient_accumulation_steps=16 \
     --report_to="wandb" \
+    --lr_scheduler="cosine" \
     --num_double_layers=4 \
     --num_single_layers=0 \
     --seed=42 \

examples/controlnet/train_controlnet_flux.py (+3 -4)
@@ -148,7 +148,7 @@ def log_validation(
                 pooled_prompt_embeds=pooled_prompt_embeds,
                 control_image=validation_image,
                 num_inference_steps=28,
-                controlnet_conditioning_scale=0.7,
+                controlnet_conditioning_scale=1,
                 guidance_scale=3.5,
                 generator=generator,
             ).images[0]
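`controlnet_conditioning_scale` multiplies the ControlNet residuals before they are added to the transformer's hidden states, so raising it from 0.7 to 1 applies the trained control signal at full strength during validation, which better reflects what the ControlNet is actually learning.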
@@ -1085,8 +1085,6 @@ def compute_embeddings(batch, proportion_empty_prompts, flux_controlnet_pipeline
         return {"prompt_embeds": prompt_embeds, "pooled_prompt_embeds": pooled_prompt_embeds, "text_ids": text_ids}
 
     train_dataset = get_train_dataset(args, accelerator)
-    text_encoders = [text_encoder_one, text_encoder_two]
-    tokenizers = [tokenizer_one, tokenizer_two]
     compute_embeddings_fn = functools.partial(
         compute_embeddings,
         flux_controlnet_pipeline=flux_controlnet_pipeline,
@@ -1103,7 +1101,8 @@ def compute_embeddings(batch, proportion_empty_prompts, flux_controlnet_pipeline
         compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint, batch_size=50
     )
 
-    del text_encoders, tokenizers, text_encoder_one, text_encoder_two, tokenizer_one, tokenizer_two
+    text_encoder_one.to("cpu")
+    text_encoder_two.to("cpu")
     free_memory()
 
     # Then get the training dataset ready to be passed to the dataloader.
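Replacing the `del` with `.to("cpu")` keeps both text encoders alive on the CPU instead of destroying them after the one-time embedding precomputation, while `free_memory()` still clears the GPU memory they occupied; this corresponds to the "offload text encoders to CPU" (~30 GB) row in the README table added above.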
