Skip to content

Cleanup model load methods #333

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 17 additions & 38 deletions finetrainers/models/cogvideox/base_specification.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from ...logging import get_logger
from ...processors import ProcessorMixin, T5Processor
from ...typing import ArtifactType, SchedulerType
from ...utils import get_non_null_items
from ...utils import _enable_vae_memory_optimizations, get_non_null_items
from ..modeling_utils import ModelSpecification
from ..utils import DiagonalGaussianDistribution
from .utils import prepare_rotary_positional_embeddings
Expand Down Expand Up @@ -117,74 +117,58 @@ def _resolution_dim_keys(self):
return {"latents": (1, 3, 4)}

def load_condition_models(self) -> Dict[str, torch.nn.Module]:
common_kwargs = {"revision": self.revision, "cache_dir": self.cache_dir}

if self.tokenizer_id is not None:
tokenizer = AutoTokenizer.from_pretrained(
self.tokenizer_id, revision=self.revision, cache_dir=self.cache_dir
)
tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_id, **common_kwargs)
else:
tokenizer = T5Tokenizer.from_pretrained(
self.pretrained_model_name_or_path,
subfolder="tokenizer",
revision=self.revision,
cache_dir=self.cache_dir,
self.pretrained_model_name_or_path, subfolder="tokenizer", **common_kwargs
)

if self.text_encoder_id is not None:
text_encoder = AutoModel.from_pretrained(
self.text_encoder_id,
torch_dtype=self.text_encoder_dtype,
revision=self.revision,
cache_dir=self.cache_dir,
self.text_encoder_id, torch_dtype=self.text_encoder_dtype, **common_kwargs
)
else:
text_encoder = T5EncoderModel.from_pretrained(
self.pretrained_model_name_or_path,
subfolder="text_encoder",
torch_dtype=self.text_encoder_dtype,
revision=self.revision,
cache_dir=self.cache_dir,
**common_kwargs,
)

return {"tokenizer": tokenizer, "text_encoder": text_encoder}

def load_latent_models(self) -> Dict[str, torch.nn.Module]:
common_kwargs = {"revision": self.revision, "cache_dir": self.cache_dir}

if self.vae_id is not None:
vae = AutoencoderKLCogVideoX.from_pretrained(
self.vae_id,
torch_dtype=self.vae_dtype,
revision=self.revision,
cache_dir=self.cache_dir,
)
vae = AutoencoderKLCogVideoX.from_pretrained(self.vae_id, torch_dtype=self.vae_dtype, **common_kwargs)
else:
vae = AutoencoderKLCogVideoX.from_pretrained(
self.pretrained_model_name_or_path,
subfolder="vae",
torch_dtype=self.vae_dtype,
revision=self.revision,
cache_dir=self.cache_dir,
self.pretrained_model_name_or_path, subfolder="vae", torch_dtype=self.vae_dtype, **common_kwargs
)

return {"vae": vae}

def load_diffusion_models(self) -> Dict[str, torch.nn.Module]:
common_kwargs = {"revision": self.revision, "cache_dir": self.cache_dir}

if self.transformer_id is not None:
transformer = CogVideoXTransformer3DModel.from_pretrained(
self.transformer_id,
torch_dtype=self.transformer_dtype,
revision=self.revision,
cache_dir=self.cache_dir,
self.transformer_id, torch_dtype=self.transformer_dtype, **common_kwargs
)
else:
transformer = CogVideoXTransformer3DModel.from_pretrained(
self.pretrained_model_name_or_path,
subfolder="transformer",
torch_dtype=self.transformer_dtype,
revision=self.revision,
cache_dir=self.cache_dir,
**common_kwargs,
)

scheduler = CogVideoXDDIMScheduler.from_pretrained(
self.pretrained_model_name_or_path, subfolder="scheduler", revision=self.revision, cache_dir=self.cache_dir
self.pretrained_model_name_or_path, subfolder="scheduler", **common_kwargs
)

return {"transformer": transformer, "scheduler": scheduler}
Expand Down Expand Up @@ -217,16 +201,11 @@ def load_pipeline(
pipe.text_encoder.to(self.text_encoder_dtype)
pipe.vae.to(self.vae_dtype)

_enable_vae_memory_optimizations(pipe.vae, enable_slicing, enable_tiling)
if not training:
pipe.transformer.to(self.transformer_dtype)

if enable_slicing:
pipe.vae.enable_slicing()
if enable_tiling:
pipe.vae.enable_tiling()
if enable_model_cpu_offload:
pipe.enable_model_cpu_offload()

return pipe

@torch.no_grad()
Expand Down
53 changes: 16 additions & 37 deletions finetrainers/models/cogview4/base_specification.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from ...logging import get_logger
from ...processors import CogView4GLMProcessor, ProcessorMixin
from ...typing import ArtifactType, SchedulerType
from ...utils import get_non_null_items
from ...utils import _enable_vae_memory_optimizations, get_non_null_items
from ..modeling_utils import ModelSpecification


Expand Down Expand Up @@ -136,70 +136,54 @@ def _resolution_dim_keys(self):
return {"latents": (2, 3)}

def load_condition_models(self) -> Dict[str, torch.nn.Module]:
common_kwargs = {"revision": self.revision, "cache_dir": self.cache_dir}

if self.tokenizer_id is not None:
tokenizer = AutoTokenizer.from_pretrained(
self.tokenizer_id, revision=self.revision, cache_dir=self.cache_dir
)
tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_id, **common_kwargs)
else:
tokenizer = AutoTokenizer.from_pretrained(
self.pretrained_model_name_or_path,
subfolder="tokenizer",
revision=self.revision,
cache_dir=self.cache_dir,
self.pretrained_model_name_or_path, subfolder="tokenizer", **common_kwargs
)

if self.text_encoder_id is not None:
text_encoder = GlmModel.from_pretrained(
self.text_encoder_id,
torch_dtype=self.text_encoder_dtype,
revision=self.revision,
cache_dir=self.cache_dir,
self.text_encoder_id, torch_dtype=self.text_encoder_dtype, **common_kwargs
)
else:
text_encoder = GlmModel.from_pretrained(
self.pretrained_model_name_or_path,
subfolder="text_encoder",
torch_dtype=self.text_encoder_dtype,
revision=self.revision,
cache_dir=self.cache_dir,
**common_kwargs,
)

return {"tokenizer": tokenizer, "text_encoder": text_encoder}

def load_latent_models(self) -> Dict[str, torch.nn.Module]:
common_kwargs = {"revision": self.revision, "cache_dir": self.cache_dir}

if self.vae_id is not None:
vae = AutoencoderKL.from_pretrained(
self.vae_id,
torch_dtype=self.vae_dtype,
revision=self.revision,
cache_dir=self.cache_dir,
)
vae = AutoencoderKL.from_pretrained(self.vae_id, torch_dtype=self.vae_dtype, **common_kwargs)
else:
vae = AutoencoderKL.from_pretrained(
self.pretrained_model_name_or_path,
subfolder="vae",
torch_dtype=self.vae_dtype,
revision=self.revision,
cache_dir=self.cache_dir,
self.pretrained_model_name_or_path, subfolder="vae", torch_dtype=self.vae_dtype, **common_kwargs
)

return {"vae": vae}

def load_diffusion_models(self) -> Dict[str, torch.nn.Module]:
common_kwargs = {"revision": self.revision, "cache_dir": self.cache_dir}

if self.transformer_id is not None:
transformer = CogView4Transformer2DModel.from_pretrained(
self.transformer_id,
torch_dtype=self.transformer_dtype,
revision=self.revision,
cache_dir=self.cache_dir,
self.transformer_id, torch_dtype=self.transformer_dtype, **common_kwargs
)
else:
transformer = CogView4Transformer2DModel.from_pretrained(
self.pretrained_model_name_or_path,
subfolder="transformer",
torch_dtype=self.transformer_dtype,
revision=self.revision,
cache_dir=self.cache_dir,
**common_kwargs,
)

scheduler = FlowMatchEulerDiscreteScheduler()
Expand Down Expand Up @@ -235,16 +219,11 @@ def load_pipeline(
pipe.text_encoder.to(self.text_encoder_dtype)
pipe.vae.to(self.vae_dtype)

_enable_vae_memory_optimizations(pipe.vae, enable_slicing, enable_tiling)
if not training:
pipe.transformer.to(self.transformer_dtype)

if enable_slicing:
pipe.vae.enable_slicing()
if enable_tiling:
pipe.vae.enable_tiling()
if enable_model_cpu_offload:
pipe.enable_model_cpu_offload()

return pipe

@torch.no_grad()
Expand Down
70 changes: 20 additions & 50 deletions finetrainers/models/hunyuan_video/base_specification.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from ...logging import get_logger
from ...processors import CLIPPooledProcessor, LlamaProcessor, ProcessorMixin
from ...typing import ArtifactType, SchedulerType
from ...utils import get_non_null_items
from ...utils import _enable_vae_memory_optimizations, get_non_null_items
from ..modeling_utils import ModelSpecification


Expand Down Expand Up @@ -120,60 +120,44 @@ def _resolution_dim_keys(self):
return {"latents": (2, 3, 4)}

def load_condition_models(self) -> Dict[str, torch.nn.Module]:
common_kwargs = {"revision": self.revision, "cache_dir": self.cache_dir}

if self.tokenizer_id is not None:
tokenizer = AutoTokenizer.from_pretrained(
self.tokenizer_id, revision=self.revision, cache_dir=self.cache_dir
)
tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_id, **common_kwargs)
else:
tokenizer = AutoTokenizer.from_pretrained(
self.pretrained_model_name_or_path,
subfolder="tokenizer",
revision=self.revision,
cache_dir=self.cache_dir,
self.pretrained_model_name_or_path, subfolder="tokenizer", **common_kwargs
)

if self.tokenizer_2_id is not None:
tokenizer_2 = CLIPTokenizer.from_pretrained(
self.tokenizer_2_id, revision=self.revision, cache_dir=self.cache_dir
)
tokenizer_2 = AutoTokenizer.from_pretrained(self.tokenizer_2_id, **common_kwargs)
else:
tokenizer_2 = CLIPTokenizer.from_pretrained(
self.pretrained_model_name_or_path,
subfolder="tokenizer_2",
revision=self.revision,
cache_dir=self.cache_dir,
self.pretrained_model_name_or_path, subfolder="tokenizer_2" ** common_kwargs
)

if self.text_encoder_id is not None:
text_encoder = LlamaModel.from_pretrained(
self.text_encoder_id,
torch_dtype=self.text_encoder_dtype,
revision=self.revision,
cache_dir=self.cache_dir,
self.text_encoder_id, torch_dtype=self.text_encoder_dtype, **common_kwargs
)
else:
text_encoder = LlamaModel.from_pretrained(
self.pretrained_model_name_or_path,
subfolder="text_encoder",
torch_dtype=self.text_encoder_dtype,
revision=self.revision,
cache_dir=self.cache_dir,
**common_kwargs,
)

if self.text_encoder_2_id is not None:
text_encoder_2 = CLIPTextModel.from_pretrained(
self.text_encoder_2_id,
torch_dtype=self.text_encoder_2_dtype,
revision=self.revision,
cache_dir=self.cache_dir,
self.text_encoder_2_id, torch_dtype=self.text_encoder_2_dtype, **common_kwargs
)
else:
text_encoder_2 = CLIPTextModel.from_pretrained(
self.pretrained_model_name_or_path,
subfolder="text_encoder_2",
torch_dtype=self.text_encoder_2_dtype,
revision=self.revision,
cache_dir=self.cache_dir,
**common_kwargs,
)

return {
Expand All @@ -184,39 +168,30 @@ def load_condition_models(self) -> Dict[str, torch.nn.Module]:
}

def load_latent_models(self) -> Dict[str, torch.nn.Module]:
common_kwargs = {"revision": self.revision, "cache_dir": self.cache_dir}

if self.vae_id is not None:
vae = AutoencoderKLHunyuanVideo.from_pretrained(
self.vae_id,
torch_dtype=self.vae_dtype,
revision=self.revision,
cache_dir=self.cache_dir,
)
vae = AutoencoderKLHunyuanVideo.from_pretrained(self.vae_id, torch_dtype=self.vae_dtype, **common_kwargs)
else:
vae = AutoencoderKLHunyuanVideo.from_pretrained(
self.pretrained_model_name_or_path,
subfolder="vae",
torch_dtype=self.vae_dtype,
revision=self.revision,
cache_dir=self.cache_dir,
self.pretrained_model_name_or_path, subfolder="vae", torch_dtype=self.vae_dtype, **common_kwargs
)

return {"vae": vae}

def load_diffusion_models(self) -> Dict[str, torch.nn.Module]:
common_kwargs = {"revision": self.revision, "cache_dir": self.cache_dir}

if self.transformer_id is not None:
transformer = HunyuanVideoTransformer3DModel.from_pretrained(
self.transformer_id,
torch_dtype=self.transformer_dtype,
revision=self.revision,
cache_dir=self.cache_dir,
self.transformer_id, torch_dtype=self.transformer_dtype, **common_kwargs
)
else:
transformer = HunyuanVideoTransformer3DModel.from_pretrained(
self.pretrained_model_name_or_path,
subfolder="transformer",
torch_dtype=self.transformer_dtype,
revision=self.revision,
cache_dir=self.cache_dir,
**common_kwargs,
)

scheduler = FlowMatchEulerDiscreteScheduler()
Expand Down Expand Up @@ -256,16 +231,11 @@ def load_pipeline(
pipe.text_encoder_2.to(self.text_encoder_2_dtype)
pipe.vae.to(self.vae_dtype)

_enable_vae_memory_optimizations(pipe.vae, enable_slicing, enable_tiling)
if not training:
pipe.transformer.to(self.transformer_dtype)

if enable_slicing:
pipe.vae.enable_slicing()
if enable_tiling:
pipe.vae.enable_tiling()
if enable_model_cpu_offload:
pipe.enable_model_cpu_offload()

return pipe

@torch.no_grad()
Expand Down
Loading