diff --git a/src/diffusers/loaders/__init__.py b/src/diffusers/loaders/__init__.py index b59150376599..6ea382d721de 100644 --- a/src/diffusers/loaders/__init__.py +++ b/src/diffusers/loaders/__init__.py @@ -70,6 +70,7 @@ def text_encoder_attn_modules(text_encoder): "FluxLoraLoaderMixin", "CogVideoXLoraLoaderMixin", "Mochi1LoraLoaderMixin", + "HunyuanVideoLoraLoaderMixin", "SanaLoraLoaderMixin", ] _import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"] @@ -90,6 +91,7 @@ def text_encoder_attn_modules(text_encoder): AmusedLoraLoaderMixin, CogVideoXLoraLoaderMixin, FluxLoraLoaderMixin, + HunyuanVideoLoraLoaderMixin, LoraLoaderMixin, LTXVideoLoraLoaderMixin, Mochi1LoraLoaderMixin, diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index b8c44e480093..46d744233014 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -3870,6 +3870,314 @@ def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], * super().unfuse_lora(components=components) +class HunyuanVideoLoraLoaderMixin(LoraBaseMixin): + r""" + Load LoRA layers into [`HunyuanVideoTransformer3DModel`]. Specific to [`HunyuanVideoPipeline`]. + """ + + _lora_loadable_modules = ["transformer"] + transformer_name = TRANSFORMER_NAME + + @classmethod + @validate_hf_hub_args + # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict + def lora_state_dict( + cls, + pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], + **kwargs, + ): + r""" + Return state dict for lora weights and the network alphas. + + + + We support loading A1111 formatted LoRA checkpoints in a limited capacity. + + This function is experimental and might change in the future. + + + + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + Can be either: + + - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on + the Hub. + - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved + with [`ModelMixin.save_pretrained`]. + - A [torch state + dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). + + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory where a downloaded pretrained model configuration is cached if the standard cache + is not used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + local_files_only (`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files or not. If set to `True`, the model + won't be downloaded from the Hub. + token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from + `diffusers-cli login` (stored in `~/.huggingface`) is used. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier + allowed by Git. + subfolder (`str`, *optional*, defaults to `""`): + The subfolder location of a model file within a larger model repository on the Hub or locally. + + """ + # Load the main state dict first which has the LoRA layers for either of + # transformer and text encoder or both. + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", None) + token = kwargs.pop("token", None) + revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", None) + weight_name = kwargs.pop("weight_name", None) + use_safetensors = kwargs.pop("use_safetensors", None) + + allow_pickle = False + if use_safetensors is None: + use_safetensors = True + allow_pickle = True + + user_agent = { + "file_type": "attn_procs_weights", + "framework": "pytorch", + } + + state_dict = _fetch_state_dict( + pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, + weight_name=weight_name, + use_safetensors=use_safetensors, + local_files_only=local_files_only, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + allow_pickle=allow_pickle, + ) + + is_dora_scale_present = any("dora_scale" in k for k in state_dict) + if is_dora_scale_present: + warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new." + logger.warning(warn_msg) + state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} + + return state_dict + + # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights + def load_lora_weights( + self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs + ): + """ + Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and + `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See + [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded. + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state + dict is loaded into `self.transformer`. + + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. + adapter_name (`str`, *optional*): + Adapter name to be used for referencing the loaded adapter model. If not specified, it will use + `default_{i}` where i is the total number of adapters being loaded. + low_cpu_mem_usage (`bool`, *optional*): + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random + weights. + kwargs (`dict`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. + """ + if not USE_PEFT_BACKEND: + raise ValueError("PEFT backend is required for this method.") + + low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA) + if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + + # if a dict is passed, copy it instead of modifying it inplace + if isinstance(pretrained_model_name_or_path_or_dict, dict): + pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() + + # First, ensure that the checkpoint is a compatible one and can be successfully loaded. + state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) + + is_correct_format = all("lora" in key for key in state_dict.keys()) + if not is_correct_format: + raise ValueError("Invalid LoRA checkpoint.") + + self.load_lora_into_transformer( + state_dict, + transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, + adapter_name=adapter_name, + _pipeline=self, + low_cpu_mem_usage=low_cpu_mem_usage, + ) + + @classmethod + # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->HunyuanVideoTransformer3DModel + def load_lora_into_transformer( + cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False + ): + """ + This will load the LoRA layers specified in `state_dict` into `transformer`. + + Parameters: + state_dict (`dict`): + A standard state dict containing the lora layer parameters. The keys can either be indexed directly + into the unet or prefixed with an additional `unet` which can be used to distinguish between text + encoder lora layers. + transformer (`HunyuanVideoTransformer3DModel`): + The Transformer model to load the LoRA layers into. + adapter_name (`str`, *optional*): + Adapter name to be used for referencing the loaded adapter model. If not specified, it will use + `default_{i}` where i is the total number of adapters being loaded. + low_cpu_mem_usage (`bool`, *optional*): + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random + weights. + """ + if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + + # Load the layers corresponding to transformer. + logger.info(f"Loading {cls.transformer_name}.") + transformer.load_lora_adapter( + state_dict, + network_alphas=None, + adapter_name=adapter_name, + _pipeline=_pipeline, + low_cpu_mem_usage=low_cpu_mem_usage, + ) + + @classmethod + # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights + def save_lora_weights( + cls, + save_directory: Union[str, os.PathLike], + transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, + is_main_process: bool = True, + weight_name: str = None, + save_function: Callable = None, + safe_serialization: bool = True, + ): + r""" + Save the LoRA parameters corresponding to the UNet and text encoder. + + Arguments: + save_directory (`str` or `os.PathLike`): + Directory to save LoRA parameters to. Will be created if it doesn't exist. + transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): + State dict of the LoRA layers corresponding to the `transformer`. + is_main_process (`bool`, *optional*, defaults to `True`): + Whether the process calling this is the main process or not. Useful during distributed training and you + need to call this function on all processes. In this case, set `is_main_process=True` only on the main + process to avoid race conditions. + save_function (`Callable`): + The function to use to save the state dictionary. Useful during distributed training when you need to + replace `torch.save` with another method. Can be configured with the environment variable + `DIFFUSERS_SAVE_MODE`. + safe_serialization (`bool`, *optional*, defaults to `True`): + Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. + """ + state_dict = {} + + if not transformer_lora_layers: + raise ValueError("You must pass `transformer_lora_layers`.") + + if transformer_lora_layers: + state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + + # Save the model + cls.write_lora_layers( + state_dict=state_dict, + save_directory=save_directory, + is_main_process=is_main_process, + weight_name=weight_name, + save_function=save_function, + safe_serialization=safe_serialization, + ) + + # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.fuse_lora with unet->transformer + def fuse_lora( + self, + components: List[str] = ["transformer", "text_encoder"], + lora_scale: float = 1.0, + safe_fusing: bool = False, + adapter_names: Optional[List[str]] = None, + **kwargs, + ): + r""" + Fuses the LoRA parameters into the original parameters of the corresponding blocks. + + + + This is an experimental API. + + + + Args: + components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into. + lora_scale (`float`, defaults to 1.0): + Controls how much to influence the outputs with the LoRA parameters. + safe_fusing (`bool`, defaults to `False`): + Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them. + adapter_names (`List[str]`, *optional*): + Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused. + + Example: + + ```py + from diffusers import DiffusionPipeline + import torch + + pipeline = DiffusionPipeline.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16 + ).to("cuda") + pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel") + pipeline.fuse_lora(lora_scale=0.7) + ``` + """ + super().fuse_lora( + components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names + ) + + # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.unfuse_lora with unet->transformer + def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs): + r""" + Reverses the effect of + [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora). + + + + This is an experimental API. + + + + Args: + components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from. + unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters. + unfuse_text_encoder (`bool`, defaults to `True`): + Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the + LoRA parameters then it won't have any effect. + """ + super().unfuse_lora(components=components) + + class LoraLoaderMixin(StableDiffusionLoraLoaderMixin): def __init__(self, *args, **kwargs): deprecation_message = "LoraLoaderMixin is deprecated and this will be removed in a future version. Please use `StableDiffusionLoraLoaderMixin`, instead." diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py index a791a250af08..9c00012ebc65 100644 --- a/src/diffusers/loaders/peft.py +++ b/src/diffusers/loaders/peft.py @@ -53,6 +53,7 @@ "FluxTransformer2DModel": lambda model_cls, weights: weights, "CogVideoXTransformer3DModel": lambda model_cls, weights: weights, "MochiTransformer3DModel": lambda model_cls, weights: weights, + "HunyuanVideoTransformer3DModel": lambda model_cls, weights: weights, "LTXVideoTransformer3DModel": lambda model_cls, weights: weights, "SanaTransformer2DModel": lambda model_cls, weights: weights, } diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video.py b/src/diffusers/models/transformers/transformer_hunyuan_video.py index 737be99c5a10..089389b5f9ad 100644 --- a/src/diffusers/models/transformers/transformer_hunyuan_video.py +++ b/src/diffusers/models/transformers/transformer_hunyuan_video.py @@ -19,7 +19,8 @@ import torch.nn.functional as F from ...configuration_utils import ConfigMixin, register_to_config -from ...utils import is_torch_version +from ...loaders import PeftAdapterMixin +from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers from ..attention import FeedForward from ..attention_processor import Attention, AttentionProcessor from ..embeddings import ( @@ -32,6 +33,9 @@ from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + class HunyuanVideoAttnProcessor2_0: def __init__(self): if not hasattr(F, "scaled_dot_product_attention"): @@ -496,7 +500,7 @@ def forward( return hidden_states, encoder_hidden_states -class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin): +class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): r""" A Transformer model for video-like data used in [HunyuanVideo](https://huggingface.co/tencent/HunyuanVideo). @@ -670,8 +674,24 @@ def forward( encoder_attention_mask: torch.Tensor, pooled_projections: torch.Tensor, guidance: torch.Tensor = None, + attention_kwargs: Optional[Dict[str, Any]] = None, return_dict: bool = True, ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: + if attention_kwargs is not None: + attention_kwargs = attention_kwargs.copy() + lora_scale = attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + else: + if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: + logger.warning( + "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." + ) + batch_size, num_channels, num_frames, height, width = hidden_states.shape p, p_t = self.config.patch_size, self.config.patch_size_t post_patch_num_frames = num_frames // p_t @@ -757,6 +777,10 @@ def custom_forward(*inputs): hidden_states = hidden_states.permute(0, 4, 1, 5, 2, 6, 3, 7) hidden_states = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3) + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + if not return_dict: return (hidden_states,) diff --git a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py index bd3d3c1e8485..4423ccf97932 100644 --- a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +++ b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py @@ -20,6 +20,7 @@ from transformers import CLIPTextModel, CLIPTokenizer, LlamaModel, LlamaTokenizerFast from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...loaders import HunyuanVideoLoraLoaderMixin from ...models import AutoencoderKLHunyuanVideo, HunyuanVideoTransformer3DModel from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import logging, replace_example_docstring @@ -132,7 +133,7 @@ def retrieve_timesteps( return timesteps, num_inference_steps -class HunyuanVideoPipeline(DiffusionPipeline): +class HunyuanVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin): r""" Pipeline for text-to-video generation using HunyuanVideo. @@ -447,6 +448,10 @@ def guidance_scale(self): def num_timesteps(self): return self._num_timesteps + @property + def attention_kwargs(self): + return self._attention_kwargs + @property def interrupt(self): return self._interrupt @@ -471,6 +476,7 @@ def __call__( prompt_attention_mask: Optional[torch.Tensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, callback_on_step_end: Optional[ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] ] = None, @@ -525,6 +531,10 @@ def __call__( The output format of the generated image. Choose between `PIL.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`HunyuanVideoPipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. @@ -562,6 +572,7 @@ def __call__( ) self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs self._interrupt = False device = self._execution_device @@ -640,6 +651,7 @@ def __call__( encoder_attention_mask=prompt_attention_mask, pooled_projections=pooled_prompt_embeds, guidance=guidance, + attention_kwargs=attention_kwargs, return_dict=False, )[0] diff --git a/tests/lora/test_lora_layers_hunyuanvideo.py b/tests/lora/test_lora_layers_hunyuanvideo.py new file mode 100644 index 000000000000..59464c052684 --- /dev/null +++ b/tests/lora/test_lora_layers_hunyuanvideo.py @@ -0,0 +1,228 @@ +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import unittest + +import numpy as np +import pytest +import torch +from transformers import CLIPTextModel, CLIPTokenizer, LlamaModel, LlamaTokenizerFast + +from diffusers import ( + AutoencoderKLHunyuanVideo, + FlowMatchEulerDiscreteScheduler, + HunyuanVideoPipeline, + HunyuanVideoTransformer3DModel, +) +from diffusers.utils.testing_utils import ( + floats_tensor, + is_torch_version, + require_peft_backend, + skip_mps, + torch_device, +) + + +sys.path.append(".") + +from utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set # noqa: E402 + + +@require_peft_backend +@skip_mps +class HunyuanVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): + pipeline_class = HunyuanVideoPipeline + scheduler_cls = FlowMatchEulerDiscreteScheduler + scheduler_classes = [FlowMatchEulerDiscreteScheduler] + scheduler_kwargs = {} + + transformer_kwargs = { + "in_channels": 4, + "out_channels": 4, + "num_attention_heads": 2, + "attention_head_dim": 10, + "num_layers": 1, + "num_single_layers": 1, + "num_refiner_layers": 1, + "patch_size": 1, + "patch_size_t": 1, + "guidance_embeds": True, + "text_embed_dim": 16, + "pooled_projection_dim": 8, + "rope_axes_dim": (2, 4, 4), + } + transformer_cls = HunyuanVideoTransformer3DModel + vae_kwargs = { + "in_channels": 3, + "out_channels": 3, + "latent_channels": 4, + "down_block_types": ( + "HunyuanVideoDownBlock3D", + "HunyuanVideoDownBlock3D", + "HunyuanVideoDownBlock3D", + "HunyuanVideoDownBlock3D", + ), + "up_block_types": ( + "HunyuanVideoUpBlock3D", + "HunyuanVideoUpBlock3D", + "HunyuanVideoUpBlock3D", + "HunyuanVideoUpBlock3D", + ), + "block_out_channels": (8, 8, 8, 8), + "layers_per_block": 1, + "act_fn": "silu", + "norm_num_groups": 4, + "scaling_factor": 0.476986, + "spatial_compression_ratio": 8, + "temporal_compression_ratio": 4, + "mid_block_add_attention": True, + } + vae_cls = AutoencoderKLHunyuanVideo + has_two_text_encoders = True + tokenizer_cls, tokenizer_id, tokenizer_subfolder = ( + LlamaTokenizerFast, + "hf-internal-testing/tiny-random-hunyuanvideo", + "tokenizer", + ) + tokenizer_2_cls, tokenizer_2_id, tokenizer_2_subfolder = ( + CLIPTokenizer, + "hf-internal-testing/tiny-random-hunyuanvideo", + "tokenizer_2", + ) + text_encoder_cls, text_encoder_id, text_encoder_subfolder = ( + LlamaModel, + "hf-internal-testing/tiny-random-hunyuanvideo", + "text_encoder", + ) + text_encoder_2_cls, text_encoder_2_id, text_encoder_2_subfolder = ( + CLIPTextModel, + "hf-internal-testing/tiny-random-hunyuanvideo", + "text_encoder_2", + ) + + @property + def output_shape(self): + return (1, 9, 32, 32, 3) + + def get_dummy_inputs(self, with_generator=True): + batch_size = 1 + sequence_length = 16 + num_channels = 4 + num_frames = 9 + num_latent_frames = 3 # (num_frames - 1) // temporal_compression_ratio + 1 + sizes = (4, 4) + + generator = torch.manual_seed(0) + noise = floats_tensor((batch_size, num_latent_frames, num_channels) + sizes) + input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator) + + pipeline_inputs = { + "prompt": "", + "num_frames": num_frames, + "num_inference_steps": 1, + "guidance_scale": 6.0, + "height": 32, + "width": 32, + "max_sequence_length": sequence_length, + "prompt_template": {"template": "{}", "crop_start": 0}, + "output_type": "np", + } + if with_generator: + pipeline_inputs.update({"generator": generator}) + + return noise, input_ids, pipeline_inputs + + @pytest.mark.xfail( + condition=torch.device(torch_device).type == "cpu" and is_torch_version(">=", "2.5"), + reason="Test currently fails on CPU and PyTorch 2.5.1 but not on PyTorch 2.4.1.", + strict=True, + ) + def test_lora_fuse_nan(self): + for scheduler_cls in self.scheduler_classes: + components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + _, _, inputs = self.get_dummy_inputs(with_generator=False) + + pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1") + + self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser") + + # corrupt one LoRA weight with `inf` values + with torch.no_grad(): + pipe.transformer.transformer_blocks[0].attn.to_q.lora_A["adapter-1"].weight += float("inf") + + # with `safe_fusing=True` we should see an Error + with self.assertRaises(ValueError): + pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=True) + + # without we should not see an error, but every image will be black + pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False) + + out = pipe( + prompt=inputs["prompt"], + height=inputs["height"], + width=inputs["width"], + num_frames=inputs["num_frames"], + num_inference_steps=inputs["num_inference_steps"], + max_sequence_length=inputs["max_sequence_length"], + output_type="np", + )[0] + + self.assertTrue(np.isnan(out).all()) + + def test_simple_inference_with_text_lora_denoiser_fused_multi(self): + super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3) + + def test_simple_inference_with_text_denoiser_lora_unfused(self): + super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3) + + # TODO(aryan): Fix the following test + @unittest.skip("This test fails with an error I haven't been able to debug yet.") + def test_simple_inference_save_pretrained(self): + pass + + @unittest.skip("Not supported in HunyuanVideo.") + def test_simple_inference_with_text_denoiser_block_scale(self): + pass + + @unittest.skip("Not supported in HunyuanVideo.") + def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self): + pass + + @unittest.skip("Not supported in HunyuanVideo.") + def test_modify_padding_mode(self): + pass + + @unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.") + def test_simple_inference_with_partial_text_lora(self): + pass + + @unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.") + def test_simple_inference_with_text_lora(self): + pass + + @unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.") + def test_simple_inference_with_text_lora_and_scale(self): + pass + + @unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.") + def test_simple_inference_with_text_lora_fused(self): + pass + + @unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.") + def test_simple_inference_with_text_lora_save_load(self): + pass diff --git a/tests/lora/utils.py b/tests/lora/utils.py index ac7a944cd026..73ed17049c1b 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -89,12 +89,12 @@ class PeftLoraLoaderMixinTests: has_two_text_encoders = False has_three_text_encoders = False - text_encoder_cls, text_encoder_id = None, None - text_encoder_2_cls, text_encoder_2_id = None, None - text_encoder_3_cls, text_encoder_3_id = None, None - tokenizer_cls, tokenizer_id = None, None - tokenizer_2_cls, tokenizer_2_id = None, None - tokenizer_3_cls, tokenizer_3_id = None, None + text_encoder_cls, text_encoder_id, text_encoder_subfolder = None, None, None + text_encoder_2_cls, text_encoder_2_id, text_encoder_2_subfolder = None, None, None + text_encoder_3_cls, text_encoder_3_id, text_encoder_3_subfolder = None, None, None + tokenizer_cls, tokenizer_id, tokenizer_subfolder = None, None, None + tokenizer_2_cls, tokenizer_2_id, tokenizer_2_subfolder = None, None, None + tokenizer_3_cls, tokenizer_3_id, tokenizer_3_subfolder = None, None, None unet_kwargs = None transformer_cls = None @@ -124,16 +124,26 @@ def get_dummy_components(self, scheduler_cls=None, use_dora=False): torch.manual_seed(0) vae = self.vae_cls(**self.vae_kwargs) - text_encoder = self.text_encoder_cls.from_pretrained(self.text_encoder_id) - tokenizer = self.tokenizer_cls.from_pretrained(self.tokenizer_id) + text_encoder = self.text_encoder_cls.from_pretrained( + self.text_encoder_id, subfolder=self.text_encoder_subfolder + ) + tokenizer = self.tokenizer_cls.from_pretrained(self.tokenizer_id, subfolder=self.tokenizer_subfolder) if self.text_encoder_2_cls is not None: - text_encoder_2 = self.text_encoder_2_cls.from_pretrained(self.text_encoder_2_id) - tokenizer_2 = self.tokenizer_2_cls.from_pretrained(self.tokenizer_2_id) + text_encoder_2 = self.text_encoder_2_cls.from_pretrained( + self.text_encoder_2_id, subfolder=self.text_encoder_2_subfolder + ) + tokenizer_2 = self.tokenizer_2_cls.from_pretrained( + self.tokenizer_2_id, subfolder=self.tokenizer_2_subfolder + ) if self.text_encoder_3_cls is not None: - text_encoder_3 = self.text_encoder_3_cls.from_pretrained(self.text_encoder_3_id) - tokenizer_3 = self.tokenizer_3_cls.from_pretrained(self.tokenizer_3_id) + text_encoder_3 = self.text_encoder_3_cls.from_pretrained( + self.text_encoder_3_id, subfolder=self.text_encoder_3_subfolder + ) + tokenizer_3 = self.tokenizer_3_cls.from_pretrained( + self.tokenizer_3_id, subfolder=self.tokenizer_3_subfolder + ) text_lora_config = LoraConfig( r=rank,