@@ -46,6 +46,7 @@ class HiDream(ImageModelFoundation):
     MODEL_TYPE = ModelTypes.TRANSFORMER
     AUTOENCODER_CLASS = AutoencoderKL
     LATENT_CHANNEL_COUNT = 16
+    DEFAULT_NOISE_SCHEDULER = "flow_unipc"
     # The safe diffusers default value for LoRA training targets.
     DEFAULT_LORA_TARGET = ["to_k", "to_q", "to_v", "to_out.0"]
     # Only training the Attention blocks by default seems to help more with HiDream.
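Note on the defaults above: the target list covers only the attention projections. A minimal sketch of how such a list is typically consumed, assuming peft-style LoRA injection (the toy module, rank, and alpha below are illustrative, not this repo's actual wiring):

import torch.nn as nn
from peft import LoraConfig, get_peft_model

# Toy stand-in for a diffusers-style attention block exposing the
# targeted projections; dim, rank, and alpha are placeholder values.
class ToyAttention(nn.Module):
    def __init__(self, dim=32):
        super().__init__()
        self.to_q = nn.Linear(dim, dim)
        self.to_k = nn.Linear(dim, dim)
        self.to_v = nn.Linear(dim, dim)
        self.to_out = nn.ModuleList([nn.Linear(dim, dim)])

model = get_peft_model(
    ToyAttention(),
    LoraConfig(r=16, lora_alpha=16, target_modules=["to_k", "to_q", "to_v", "to_out.0"]),
)
model.print_trainable_parameters()  # only the LoRA adapters remain trainable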
@@ -123,7 +124,11 @@ def _load_pipeline(
         """
         active_pipelines = getattr(self, "pipelines", {})
         if pipeline_type in active_pipelines:
-            setattr(active_pipelines[pipeline_type], self.MODEL_TYPE.value, self.unwrap_model())
+            setattr(
+                active_pipelines[pipeline_type],
+                self.MODEL_TYPE.value,
+                self.unwrap_model(),
+            )
             return active_pipelines[pipeline_type]
         pipeline_kwargs = {
             "pretrained_model_name_or_path": self._model_config_path(),
@@ -187,7 +192,6 @@ def _load_pipeline(
 
         return self.pipelines[pipeline_type]
 
-
     def _format_text_embedding(self, text_embedding: torch.Tensor):
         """
         Models can optionally format the stored text embedding, eg. in a dict, or
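The docstring hints that subclasses may repack the stored embedding. One plausible shape for a model carrying several encoder outputs, purely as a hedged illustration (the tuple layout and dict keys are assumptions, not the repo's actual contract):

# Sketch: repacking a stored text embedding into a dict; the input
# layout and output keys are illustrative assumptions.
def _format_text_embedding(self, text_embedding):
    prompt_embeds, pooled_prompt_embeds = text_embedding
    return {
        "prompt_embeds": prompt_embeds,
        "pooled_prompt_embeds": pooled_prompt_embeds,
    }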
@@ -308,16 +312,16 @@ def model_predict(self, prepared_batch):
         ):
             B, C, H, W = prepared_batch["noisy_latents"].shape
             pH, pW = (
-                H // self.model.config.patch_size,
-                W // self.model.config.patch_size,
+                H // self.unwrap_model(model=self.model).config.patch_size,
+                W // self.unwrap_model(model=self.model).config.patch_size,
             )
 
             img_sizes = torch.tensor([pH, pW], dtype=torch.int64).reshape(-1)
             img_ids = torch.zeros(pH, pW, 3)
             img_ids[..., 1] = img_ids[..., 1] + torch.arange(pH)[:, None]
             img_ids[..., 2] = img_ids[..., 2] + torch.arange(pW)[None, :]
             img_ids = img_ids.reshape(pH * pW, -1)
-            img_ids_pad = torch.zeros(self.model.max_seq, 3)
+            img_ids_pad = torch.zeros(self.unwrap_model(model=self.model).max_seq, 3)
             img_ids_pad[: pH * pW, :] = img_ids
 
             img_sizes = img_sizes.unsqueeze(0).to(
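For a concrete feel of the id grid built in this hunk: channel 1 enumerates patch rows, channel 2 patch columns, and the result is zero-padded out to max_seq. A standalone repro with toy sizes (the max_seq value is illustrative):

import torch

pH, pW, max_seq = 2, 2, 4096  # toy grid; max_seq is illustrative

img_ids = torch.zeros(pH, pW, 3)
img_ids[..., 1] += torch.arange(pH)[:, None]  # channel 1: patch row index
img_ids[..., 2] += torch.arange(pW)[None, :]  # channel 2: patch column index
img_ids = img_ids.reshape(pH * pW, -1)
# rows are now [0,0,0], [0,0,1], [0,1,0], [0,1,1]

img_ids_pad = torch.zeros(max_seq, 3)  # all-zero rows act as padding ids
img_ids_pad[: pH * pW, :] = img_ids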
@@ -334,10 +338,15 @@ def model_predict(self, prepared_batch):
         latent_model_input = prepared_batch["noisy_latents"]
         if latent_model_input.shape[-2] != latent_model_input.shape[-1]:
             B, C, H, W = latent_model_input.shape
-            patch_size = self.model.config.patch_size
+            patch_size = self.unwrap_model(model=self.model).config.patch_size
             pH, pW = H // patch_size, W // patch_size
             out = torch.zeros(
-                (B, C, self.model.max_seq, patch_size * patch_size),
+                (
+                    B,
+                    C,
+                    self.unwrap_model(model=self.model).max_seq,
+                    patch_size * patch_size,
+                ),
                 dtype=latent_model_input.dtype,
                 device=latent_model_input.device,
             )
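The zero buffer above receives the patchified non-square latent; every sequence slot past pH * pW stays zero-padded up to max_seq. A minimal standalone sketch of that packing step under assumed shapes (the reshape/permute mirrors the usual patchify layout, but is an illustration, not the repo's exact code):

import torch

B, C, H, W = 1, 16, 64, 32     # non-square latent (H != W)
patch_size, max_seq = 2, 4096  # max_seq is illustrative
pH, pW = H // patch_size, W // patch_size

latents = torch.randn(B, C, H, W)
out = torch.zeros(B, C, max_seq, patch_size * patch_size)

# (B, C, H, W) -> (B, C, pH*pW, patch_size**2): one row per spatial patch.
patches = (
    latents.reshape(B, C, pH, patch_size, pW, patch_size)
    .permute(0, 1, 2, 4, 3, 5)
    .reshape(B, C, pH * pW, patch_size * patch_size)
)
out[:, :, : pH * pW] = patches  # slots beyond pH*pW remain zero padding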