Commit 4c6844f

Cleanup model load methods (#333)
1 parent 127d969 commit 4c6844f

5 files changed: +83 additions, −192 deletions

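The refactor applied to each file below is the same: every loader method now builds a `common_kwargs` dict once and unpacks it into each `from_pretrained` call, instead of repeating `revision=self.revision, cache_dir=self.cache_dir` at every call site. A minimal standalone sketch of the pattern (the `load_tokenizer` function and its signature are illustrative, not from the repo; only the pattern itself mirrors the diffs):

```python
# Sketch of the kwargs-consolidation pattern used throughout this commit.
# `load_tokenizer` is a hypothetical free function, not finetrainers code.
from typing import Optional

from transformers import AutoTokenizer


def load_tokenizer(tokenizer_id: str, revision: Optional[str] = None, cache_dir: Optional[str] = None):
    common_kwargs = {"revision": revision, "cache_dir": cache_dir}
    # Identical to: AutoTokenizer.from_pretrained(tokenizer_id, revision=revision, cache_dir=cache_dir)
    return AutoTokenizer.from_pretrained(tokenizer_id, **common_kwargs)
```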

finetrainers/models/cogvideox/base_specification.py

Lines changed: 17 additions & 38 deletions
@@ -17,7 +17,7 @@
 from ...logging import get_logger
 from ...processors import ProcessorMixin, T5Processor
 from ...typing import ArtifactType, SchedulerType
-from ...utils import get_non_null_items
+from ...utils import _enable_vae_memory_optimizations, get_non_null_items
 from ..modeling_utils import ModelSpecification
 from ..utils import DiagonalGaussianDistribution
 from .utils import prepare_rotary_positional_embeddings
@@ -117,74 +117,58 @@ def _resolution_dim_keys(self):
         return {"latents": (1, 3, 4)}

     def load_condition_models(self) -> Dict[str, torch.nn.Module]:
+        common_kwargs = {"revision": self.revision, "cache_dir": self.cache_dir}
+
         if self.tokenizer_id is not None:
-            tokenizer = AutoTokenizer.from_pretrained(
-                self.tokenizer_id, revision=self.revision, cache_dir=self.cache_dir
-            )
+            tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_id, **common_kwargs)
         else:
             tokenizer = T5Tokenizer.from_pretrained(
-                self.pretrained_model_name_or_path,
-                subfolder="tokenizer",
-                revision=self.revision,
-                cache_dir=self.cache_dir,
+                self.pretrained_model_name_or_path, subfolder="tokenizer", **common_kwargs
             )

         if self.text_encoder_id is not None:
             text_encoder = AutoModel.from_pretrained(
-                self.text_encoder_id,
-                torch_dtype=self.text_encoder_dtype,
-                revision=self.revision,
-                cache_dir=self.cache_dir,
+                self.text_encoder_id, torch_dtype=self.text_encoder_dtype, **common_kwargs
             )
         else:
             text_encoder = T5EncoderModel.from_pretrained(
                 self.pretrained_model_name_or_path,
                 subfolder="text_encoder",
                 torch_dtype=self.text_encoder_dtype,
-                revision=self.revision,
-                cache_dir=self.cache_dir,
+                **common_kwargs,
             )

         return {"tokenizer": tokenizer, "text_encoder": text_encoder}

     def load_latent_models(self) -> Dict[str, torch.nn.Module]:
+        common_kwargs = {"revision": self.revision, "cache_dir": self.cache_dir}
+
         if self.vae_id is not None:
-            vae = AutoencoderKLCogVideoX.from_pretrained(
-                self.vae_id,
-                torch_dtype=self.vae_dtype,
-                revision=self.revision,
-                cache_dir=self.cache_dir,
-            )
+            vae = AutoencoderKLCogVideoX.from_pretrained(self.vae_id, torch_dtype=self.vae_dtype, **common_kwargs)
         else:
             vae = AutoencoderKLCogVideoX.from_pretrained(
-                self.pretrained_model_name_or_path,
-                subfolder="vae",
-                torch_dtype=self.vae_dtype,
-                revision=self.revision,
-                cache_dir=self.cache_dir,
+                self.pretrained_model_name_or_path, subfolder="vae", torch_dtype=self.vae_dtype, **common_kwargs
             )

         return {"vae": vae}

     def load_diffusion_models(self) -> Dict[str, torch.nn.Module]:
+        common_kwargs = {"revision": self.revision, "cache_dir": self.cache_dir}
+
         if self.transformer_id is not None:
             transformer = CogVideoXTransformer3DModel.from_pretrained(
-                self.transformer_id,
-                torch_dtype=self.transformer_dtype,
-                revision=self.revision,
-                cache_dir=self.cache_dir,
+                self.transformer_id, torch_dtype=self.transformer_dtype, **common_kwargs
             )
         else:
             transformer = CogVideoXTransformer3DModel.from_pretrained(
                 self.pretrained_model_name_or_path,
                 subfolder="transformer",
                 torch_dtype=self.transformer_dtype,
-                revision=self.revision,
-                cache_dir=self.cache_dir,
+                **common_kwargs,
             )

         scheduler = CogVideoXDDIMScheduler.from_pretrained(
-            self.pretrained_model_name_or_path, subfolder="scheduler", revision=self.revision, cache_dir=self.cache_dir
+            self.pretrained_model_name_or_path, subfolder="scheduler", **common_kwargs
         )

         return {"transformer": transformer, "scheduler": scheduler}
@@ -217,16 +201,11 @@ def load_pipeline(
         pipe.text_encoder.to(self.text_encoder_dtype)
         pipe.vae.to(self.vae_dtype)

+        _enable_vae_memory_optimizations(pipe.vae, enable_slicing, enable_tiling)
         if not training:
             pipe.transformer.to(self.transformer_dtype)
-
-        if enable_slicing:
-            pipe.vae.enable_slicing()
-        if enable_tiling:
-            pipe.vae.enable_tiling()
         if enable_model_cpu_offload:
             pipe.enable_model_cpu_offload()
-
         return pipe

     @torch.no_grad()
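The `_enable_vae_memory_optimizations` helper imported and called above is defined in `finetrainers/utils`, one of the changed files not shown on this page. A hedged reconstruction of what it plausibly contains, inferred from its call sites and from the inline logic the diff removes:

```python
# Hypothetical body of the new helper in finetrainers/utils; the actual
# implementation is not visible in this diff. It centralizes the
# slicing/tiling toggles previously repeated in every load_pipeline method.
def _enable_vae_memory_optimizations(vae, enable_slicing: bool, enable_tiling: bool) -> None:
    if enable_slicing:
        vae.enable_slicing()  # decode latents one sample at a time
    if enable_tiling:
        vae.enable_tiling()  # decode in spatial tiles to bound peak memory
```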

finetrainers/models/cogview4/base_specification.py

Lines changed: 16 additions & 37 deletions
@@ -17,7 +17,7 @@
 from ...logging import get_logger
 from ...processors import CogView4GLMProcessor, ProcessorMixin
 from ...typing import ArtifactType, SchedulerType
-from ...utils import get_non_null_items
+from ...utils import _enable_vae_memory_optimizations, get_non_null_items
 from ..modeling_utils import ModelSpecification


@@ -136,70 +136,54 @@ def _resolution_dim_keys(self):
         return {"latents": (2, 3)}

     def load_condition_models(self) -> Dict[str, torch.nn.Module]:
+        common_kwargs = {"revision": self.revision, "cache_dir": self.cache_dir}
+
         if self.tokenizer_id is not None:
-            tokenizer = AutoTokenizer.from_pretrained(
-                self.tokenizer_id, revision=self.revision, cache_dir=self.cache_dir
-            )
+            tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_id, **common_kwargs)
         else:
             tokenizer = AutoTokenizer.from_pretrained(
-                self.pretrained_model_name_or_path,
-                subfolder="tokenizer",
-                revision=self.revision,
-                cache_dir=self.cache_dir,
+                self.pretrained_model_name_or_path, subfolder="tokenizer", **common_kwargs
             )

         if self.text_encoder_id is not None:
             text_encoder = GlmModel.from_pretrained(
-                self.text_encoder_id,
-                torch_dtype=self.text_encoder_dtype,
-                revision=self.revision,
-                cache_dir=self.cache_dir,
+                self.text_encoder_id, torch_dtype=self.text_encoder_dtype, **common_kwargs
             )
         else:
             text_encoder = GlmModel.from_pretrained(
                 self.pretrained_model_name_or_path,
                 subfolder="text_encoder",
                 torch_dtype=self.text_encoder_dtype,
-                revision=self.revision,
-                cache_dir=self.cache_dir,
+                **common_kwargs,
             )

         return {"tokenizer": tokenizer, "text_encoder": text_encoder}

     def load_latent_models(self) -> Dict[str, torch.nn.Module]:
+        common_kwargs = {"revision": self.revision, "cache_dir": self.cache_dir}
+
         if self.vae_id is not None:
-            vae = AutoencoderKL.from_pretrained(
-                self.vae_id,
-                torch_dtype=self.vae_dtype,
-                revision=self.revision,
-                cache_dir=self.cache_dir,
-            )
+            vae = AutoencoderKL.from_pretrained(self.vae_id, torch_dtype=self.vae_dtype, **common_kwargs)
         else:
             vae = AutoencoderKL.from_pretrained(
-                self.pretrained_model_name_or_path,
-                subfolder="vae",
-                torch_dtype=self.vae_dtype,
-                revision=self.revision,
-                cache_dir=self.cache_dir,
+                self.pretrained_model_name_or_path, subfolder="vae", torch_dtype=self.vae_dtype, **common_kwargs
             )

         return {"vae": vae}

     def load_diffusion_models(self) -> Dict[str, torch.nn.Module]:
+        common_kwargs = {"revision": self.revision, "cache_dir": self.cache_dir}
+
         if self.transformer_id is not None:
             transformer = CogView4Transformer2DModel.from_pretrained(
-                self.transformer_id,
-                torch_dtype=self.transformer_dtype,
-                revision=self.revision,
-                cache_dir=self.cache_dir,
+                self.transformer_id, torch_dtype=self.transformer_dtype, **common_kwargs
             )
         else:
             transformer = CogView4Transformer2DModel.from_pretrained(
                 self.pretrained_model_name_or_path,
                 subfolder="transformer",
                 torch_dtype=self.transformer_dtype,
-                revision=self.revision,
-                cache_dir=self.cache_dir,
+                **common_kwargs,
             )

         scheduler = FlowMatchEulerDiscreteScheduler()
@@ -235,16 +219,11 @@ def load_pipeline(
         pipe.text_encoder.to(self.text_encoder_dtype)
         pipe.vae.to(self.vae_dtype)

+        _enable_vae_memory_optimizations(pipe.vae, enable_slicing, enable_tiling)
         if not training:
             pipe.transformer.to(self.transformer_dtype)
-
-        if enable_slicing:
-            pipe.vae.enable_slicing()
-        if enable_tiling:
-            pipe.vae.enable_tiling()
         if enable_model_cpu_offload:
             pipe.enable_model_cpu_offload()
-
         return pipe

     @torch.no_grad()

finetrainers/models/hunyuan_video/base_specification.py

Lines changed: 20 additions & 50 deletions
@@ -17,7 +17,7 @@
 from ...logging import get_logger
 from ...processors import CLIPPooledProcessor, LlamaProcessor, ProcessorMixin
 from ...typing import ArtifactType, SchedulerType
-from ...utils import get_non_null_items
+from ...utils import _enable_vae_memory_optimizations, get_non_null_items
 from ..modeling_utils import ModelSpecification


@@ -120,60 +120,44 @@ def _resolution_dim_keys(self):
         return {"latents": (2, 3, 4)}

     def load_condition_models(self) -> Dict[str, torch.nn.Module]:
+        common_kwargs = {"revision": self.revision, "cache_dir": self.cache_dir}
+
         if self.tokenizer_id is not None:
-            tokenizer = AutoTokenizer.from_pretrained(
-                self.tokenizer_id, revision=self.revision, cache_dir=self.cache_dir
-            )
+            tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_id, **common_kwargs)
         else:
             tokenizer = AutoTokenizer.from_pretrained(
-                self.pretrained_model_name_or_path,
-                subfolder="tokenizer",
-                revision=self.revision,
-                cache_dir=self.cache_dir,
+                self.pretrained_model_name_or_path, subfolder="tokenizer", **common_kwargs
             )

         if self.tokenizer_2_id is not None:
-            tokenizer_2 = CLIPTokenizer.from_pretrained(
-                self.tokenizer_2_id, revision=self.revision, cache_dir=self.cache_dir
-            )
+            tokenizer_2 = AutoTokenizer.from_pretrained(self.tokenizer_2_id, **common_kwargs)
         else:
             tokenizer_2 = CLIPTokenizer.from_pretrained(
-                self.pretrained_model_name_or_path,
-                subfolder="tokenizer_2",
-                revision=self.revision,
-                cache_dir=self.cache_dir,
+                self.pretrained_model_name_or_path, subfolder="tokenizer_2", **common_kwargs
             )

         if self.text_encoder_id is not None:
             text_encoder = LlamaModel.from_pretrained(
-                self.text_encoder_id,
-                torch_dtype=self.text_encoder_dtype,
-                revision=self.revision,
-                cache_dir=self.cache_dir,
+                self.text_encoder_id, torch_dtype=self.text_encoder_dtype, **common_kwargs
             )
         else:
             text_encoder = LlamaModel.from_pretrained(
                 self.pretrained_model_name_or_path,
                 subfolder="text_encoder",
                 torch_dtype=self.text_encoder_dtype,
-                revision=self.revision,
-                cache_dir=self.cache_dir,
+                **common_kwargs,
             )

         if self.text_encoder_2_id is not None:
             text_encoder_2 = CLIPTextModel.from_pretrained(
-                self.text_encoder_2_id,
-                torch_dtype=self.text_encoder_2_dtype,
-                revision=self.revision,
-                cache_dir=self.cache_dir,
+                self.text_encoder_2_id, torch_dtype=self.text_encoder_2_dtype, **common_kwargs
             )
         else:
             text_encoder_2 = CLIPTextModel.from_pretrained(
                 self.pretrained_model_name_or_path,
                 subfolder="text_encoder_2",
                 torch_dtype=self.text_encoder_2_dtype,
-                revision=self.revision,
-                cache_dir=self.cache_dir,
+                **common_kwargs,
             )

         return {
@@ -184,39 +168,30 @@ def load_condition_models(self) -> Dict[str, torch.nn.Module]:
         }

     def load_latent_models(self) -> Dict[str, torch.nn.Module]:
+        common_kwargs = {"revision": self.revision, "cache_dir": self.cache_dir}
+
         if self.vae_id is not None:
-            vae = AutoencoderKLHunyuanVideo.from_pretrained(
-                self.vae_id,
-                torch_dtype=self.vae_dtype,
-                revision=self.revision,
-                cache_dir=self.cache_dir,
-            )
+            vae = AutoencoderKLHunyuanVideo.from_pretrained(self.vae_id, torch_dtype=self.vae_dtype, **common_kwargs)
         else:
             vae = AutoencoderKLHunyuanVideo.from_pretrained(
-                self.pretrained_model_name_or_path,
-                subfolder="vae",
-                torch_dtype=self.vae_dtype,
-                revision=self.revision,
-                cache_dir=self.cache_dir,
+                self.pretrained_model_name_or_path, subfolder="vae", torch_dtype=self.vae_dtype, **common_kwargs
             )

         return {"vae": vae}

     def load_diffusion_models(self) -> Dict[str, torch.nn.Module]:
+        common_kwargs = {"revision": self.revision, "cache_dir": self.cache_dir}
+
         if self.transformer_id is not None:
             transformer = HunyuanVideoTransformer3DModel.from_pretrained(
-                self.transformer_id,
-                torch_dtype=self.transformer_dtype,
-                revision=self.revision,
-                cache_dir=self.cache_dir,
+                self.transformer_id, torch_dtype=self.transformer_dtype, **common_kwargs
             )
         else:
             transformer = HunyuanVideoTransformer3DModel.from_pretrained(
                 self.pretrained_model_name_or_path,
                 subfolder="transformer",
                 torch_dtype=self.transformer_dtype,
-                revision=self.revision,
-                cache_dir=self.cache_dir,
+                **common_kwargs,
             )

         scheduler = FlowMatchEulerDiscreteScheduler()
@@ -256,16 +231,11 @@ def load_pipeline(
         pipe.text_encoder_2.to(self.text_encoder_2_dtype)
         pipe.vae.to(self.vae_dtype)

+        _enable_vae_memory_optimizations(pipe.vae, enable_slicing, enable_tiling)
         if not training:
             pipe.transformer.to(self.transformer_dtype)
-
-        if enable_slicing:
-            pipe.vae.enable_slicing()
-        if enable_tiling:
-            pipe.vae.enable_tiling()
         if enable_model_cpu_offload:
             pipe.enable_model_cpu_offload()
-
         return pipe

     @torch.no_grad()
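The caller-facing interface is unchanged by this cleanup; only where the slicing/tiling toggles are applied moves. A hedged usage sketch (the specification class name, constructor argument, and model id are assumptions, not from this diff; the `load_pipeline` keyword names come from the diffs above):

```python
# Hypothetical caller; only the load_pipeline keyword names are taken from the diff.
spec = HunyuanVideoModelSpecification(
    pretrained_model_name_or_path="hunyuanvideo-community/HunyuanVideo"
)
pipe = spec.load_pipeline(
    training=False,
    enable_slicing=True,  # now routed through _enable_vae_memory_optimizations
    enable_tiling=True,
    enable_model_cpu_offload=False,
)
```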

0 commit comments
