Commit 485d213

Merge pull request #1428 from bghira/bugfix/hidream_ddp
HiDream multi-GPU fixes; PEFT LoRA support
2 parents 10eaa08 + b3b459c commit 485d213

7 files changed: +48 -116 lines

configure.py

+3 -2

@@ -42,6 +42,7 @@
         "wan",
         "deepfloyd",
         "auraflow",
+        "hidream",
     ],
     "controlnet": ["sdxl", "sd1x", "sd2x"],
 }
@@ -61,7 +62,7 @@
     "hidream": "HiDream-ai/HiDream-I1-Full",
     "auraflow": "terminusresearch/auraflow-v0.3",
     "deepfloyd": "DeepFloyd/DeepFloyd-IF-I-XL-v1.0",
-    "omnigen": "Shitao/OmniGen-v1-diffusers",
+    "omnigen": "Shitao/OmniGen-v1-diffusers",
 }
 
 default_cfg = {
@@ -78,7 +79,7 @@
     "omnigen": 3.2,
     "deepfloyd": 6.0,
     "sd2x": 7.0,
-    "sd1x": 6.0,
+    "sd1x": 6.0,
 }
 
 model_labels = {
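Note: the first hunk registers "hidream" as a LoRA-capable model family in the configurator; the other two hunks only reformat existing entries. Below is a minimal sketch of how a configurator typically consumes tables like these. The variable and function names are illustrative assumptions, not the actual symbols in configure.py.

    # Illustrative sketch only; names are assumptions, not configure.py's real variables.
    supported_training = {
        "lora": ["wan", "deepfloyd", "auraflow", "hidream"],  # "hidream" is the new entry
        "controlnet": ["sdxl", "sd1x", "sd2x"],
    }
    default_model_paths = {
        "hidream": "HiDream-ai/HiDream-I1-Full",
        "omnigen": "Shitao/OmniGen-v1-diffusers",
    }

    def default_checkpoint_for(family: str, training_type: str) -> str:
        """Return the default Hugging Face repo for a family, if the training type supports it."""
        if family not in supported_training.get(training_type, []):
            raise ValueError(f"{family!r} does not support {training_type!r} training")
        return default_model_paths[family]

    assert default_checkpoint_for("hidream", "lora") == "HiDream-ai/HiDream-I1-Full"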

helpers/models/hidream/model.py

+16 -7

@@ -46,6 +46,7 @@ class HiDream(ImageModelFoundation):
     MODEL_TYPE = ModelTypes.TRANSFORMER
     AUTOENCODER_CLASS = AutoencoderKL
     LATENT_CHANNEL_COUNT = 16
+    DEFAULT_NOISE_SCHEDULER = "flow_unipc"
     # The safe diffusers default value for LoRA training targets.
     DEFAULT_LORA_TARGET = ["to_k", "to_q", "to_v", "to_out.0"]
     # Only training the Attention blocks by default seems to help more with HiDream.
@@ -123,7 +124,11 @@ def _load_pipeline(
         """
         active_pipelines = getattr(self, "pipelines", {})
         if pipeline_type in active_pipelines:
-            setattr(active_pipelines[pipeline_type], self.MODEL_TYPE.value, self.unwrap_model())
+            setattr(
+                active_pipelines[pipeline_type],
+                self.MODEL_TYPE.value,
+                self.unwrap_model(),
+            )
             return active_pipelines[pipeline_type]
         pipeline_kwargs = {
             "pretrained_model_name_or_path": self._model_config_path(),
@@ -187,7 +192,6 @@ def _load_pipeline(
 
         return self.pipelines[pipeline_type]
 
-
     def _format_text_embedding(self, text_embedding: torch.Tensor):
         """
         Models can optionally format the stored text embedding, eg. in a dict, or
@@ -308,16 +312,16 @@ def model_predict(self, prepared_batch):
         ):
             B, C, H, W = prepared_batch["noisy_latents"].shape
             pH, pW = (
-                H // self.model.config.patch_size,
-                W // self.model.config.patch_size,
+                H // self.unwrap_model(model=self.model).config.patch_size,
+                W // self.unwrap_model(model=self.model).config.patch_size,
             )
 
             img_sizes = torch.tensor([pH, pW], dtype=torch.int64).reshape(-1)
             img_ids = torch.zeros(pH, pW, 3)
             img_ids[..., 1] = img_ids[..., 1] + torch.arange(pH)[:, None]
             img_ids[..., 2] = img_ids[..., 2] + torch.arange(pW)[None, :]
             img_ids = img_ids.reshape(pH * pW, -1)
-            img_ids_pad = torch.zeros(self.model.max_seq, 3)
+            img_ids_pad = torch.zeros(self.unwrap_model(model=self.model).max_seq, 3)
             img_ids_pad[: pH * pW, :] = img_ids
 
             img_sizes = img_sizes.unsqueeze(0).to(
@@ -334,10 +338,15 @@ def model_predict(self, prepared_batch):
             latent_model_input = prepared_batch["noisy_latents"]
             if latent_model_input.shape[-2] != latent_model_input.shape[-1]:
                 B, C, H, W = latent_model_input.shape
-                patch_size = self.model.config.patch_size
+                patch_size = self.unwrap_model(model=self.model).config.patch_size
                 pH, pW = H // patch_size, W // patch_size
                 out = torch.zeros(
-                    (B, C, self.model.max_seq, patch_size * patch_size),
+                    (
+                        B,
+                        C,
+                        self.unwrap_model(model=self.model).max_seq,
+                        patch_size * patch_size,
+                    ),
                     dtype=latent_model_input.dtype,
                     device=latent_model_input.device,
                 )
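Note: most of these model.py changes swap direct `self.model.<attr>` access for `self.unwrap_model(model=self.model).<attr>`. Under multi-GPU training the transformer is wrapped (for example by `DistributedDataParallel`), and the wrapper does not proxy arbitrary attributes such as `config` or `max_seq`, so those lookups must go through the unwrapped module. A minimal sketch of the pattern; the helper below illustrates the idea and is not the project's actual `unwrap_model`:

    import torch
    from torch.nn.parallel import DistributedDataParallel

    def unwrap_model(model: torch.nn.Module) -> torch.nn.Module:
        """Peel off DDP (and torch.compile) wrappers to reach the underlying module."""
        if isinstance(model, DistributedDataParallel):
            model = model.module
        # torch.compile wraps the module and keeps the original as _orig_mod.
        return getattr(model, "_orig_mod", model)

    # `wrapped.config.patch_size` raises AttributeError when `wrapped` is a DDP wrapper;
    # `unwrap_model(wrapped).config.patch_size` works for both wrapped and bare modules.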

helpers/models/hidream/pipeline.py

+4 -2

@@ -13,7 +13,7 @@
 )
 
 from diffusers.image_processor import VaeImageProcessor
-from diffusers.loaders import FromSingleFileMixin
+from diffusers.loaders import FromSingleFileMixin, HiDreamImageLoraLoaderMixin
 from diffusers.models.autoencoders import AutoencoderKL
 from diffusers.schedulers import FlowMatchEulerDiscreteScheduler, UniPCMultistepScheduler
 from diffusers.utils import (
@@ -136,7 +136,9 @@ class HiDreamImagePipelineOutput(BaseOutput):
     images: Union[List[PIL.Image.Image], np.ndarray]
 
 
-class HiDreamImagePipeline(DiffusionPipeline, FromSingleFileMixin):
+class HiDreamImagePipeline(
+    DiffusionPipeline, FromSingleFileMixin, HiDreamImageLoraLoaderMixin
+):
     model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->text_encoder_4->image_encoder->transformer->vae"
     _optional_components = ["image_encoder", "feature_extractor"]
     _callback_tensor_inputs = ["latents", "prompt_embeds"]
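Note: mixing `HiDreamImageLoraLoaderMixin` into the pipeline class is what exposes the PEFT-backed LoRA entry points such as `load_lora_weights`, matching the "PEFT LoRA support" half of the commit title. A hedged usage sketch, assuming a diffusers release that ships the mixin and using placeholder paths:

    import torch
    from helpers.models.hidream.pipeline import HiDreamImagePipeline

    # Repo ID and LoRA path are placeholders for illustration; additional components
    # (text encoders, tokenizers) may need to be supplied depending on the checkpoint.
    pipe = HiDreamImagePipeline.from_pretrained(
        "HiDream-ai/HiDream-I1-Full", torch_dtype=torch.bfloat16
    )
    pipe.load_lora_weights("output/my-hidream-lora")  # provided by HiDreamImageLoraLoaderMixin
    pipe.to("cuda")
    image = pipe("a watercolor fox in a forest", num_inference_steps=30).images[0]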

helpers/models/hidream/schedule.py

+3 -0

@@ -103,6 +103,9 @@ def __init__(
             raise NotImplementedError(
                 f"{solver_type} is not implemented for {self.__class__}"
             )
+        if prediction_type is None:
+            prediction_type = "flow_prediction"
+        self.config.prediction_type = prediction_type
 
         self.predict_x0 = predict_x0
         # setable values
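Note: the three added lines give the scheduler a prediction-type fallback: if the constructor receives `None`, the config is registered as "flow_prediction", which is also what validation.py passes explicitly. A tiny sketch of the same default-if-None idiom with an illustrative class name:

    # Illustrative only; this is not the real FlowUniPCMultistepScheduler signature.
    class ExampleScheduler:
        def __init__(self, prediction_type: str | None = None):
            if prediction_type is None:
                prediction_type = "flow_prediction"
            self.prediction_type = prediction_type

    assert ExampleScheduler().prediction_type == "flow_prediction"
    assert ExampleScheduler("epsilon").prediction_type == "epsilon"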

helpers/training/validation.py

+3 -1

@@ -26,6 +26,7 @@
     DDIMScheduler,
     DDPMScheduler,
 )
+from helpers.models.hidream.schedule import FlowUniPCMultistepScheduler
 from diffusers.utils.torch_utils import is_compiled_module
 from helpers.multiaspect.image import MultiaspectImage
 from helpers.image_manipulation.brightness import calculate_luminance
@@ -45,6 +46,7 @@
     "euler-a": EulerAncestralDiscreteScheduler,
     "flow_matching": FlowMatchEulerDiscreteScheduler,
     "unipc": UniPCMultistepScheduler,
+    "flow_unipc": FlowUniPCMultistepScheduler,
     "ddim": DDIMScheduler,
     "ddpm": DDPMScheduler,
     "dpm++": DPMSolverMultistepScheduler,
@@ -793,7 +795,7 @@ def setup_scheduler(self):
             # The Beta schedule looks WAY better...
             scheduler_args["use_beta_sigmas"] = True
             scheduler_args["shift"] = self.config.flow_schedule_shift
-        if self.config.validation_noise_scheduler == "unipc":
+        if self.config.validation_noise_scheduler in ["flow_unipc", "unipc"]:
             scheduler_args["prediction_type"] = "flow_prediction"
             scheduler_args["use_flow_sigmas"] = True
             scheduler_args["num_train_timesteps"] = 1000
