
Support CogVideoX T2V #165


Merged: 17 commits merged into main from cog on Jan 3, 2025

Conversation


@sayakpaul sayakpaul commented Dec 30, 2024

Example command:

Command
export NCCL_P2P_DISABLE=1
export TORCH_NCCL_ENABLE_MONITORING=0
export FINETRAINERS_LOG_LEVEL=DEBUG

GPU_IDS="1,2"

DATA_ROOT="/home/sayak/finetrainers/video-dataset-disney"
CAPTION_COLUMN="prompt.txt"
VIDEO_COLUMN="videos.txt"
OUTPUT_DIR="/raid/.cache/huggingface/sayak/cog2b/cog2b_disney"
ID_TOKEN="BW_STYLE"

# Model arguments
model_cmd="--model_name cogvideox \
  --pretrained_model_name_or_path THUDM/CogVideoX-2b"

# Dataset arguments
dataset_cmd="--data_root $DATA_ROOT \
  --video_column $VIDEO_COLUMN \
  --caption_column $CAPTION_COLUMN \
  --id_token $ID_TOKEN \
  --video_resolution_buckets 49x480x720 \
  --caption_dropout_p 0.05"

# Dataloader arguments
dataloader_cmd="--dataloader_num_workers 4"

# Diffusion arguments
diffusion_cmd="--flow_resolution_shifting"

# Training arguments
training_cmd="--training_type lora \
  --seed 42 \
  --mixed_precision fp16 \
  --transformer_dtype fp16 \
  --text_encoder_dtype fp16 \
  --vae_dtype fp16 \
  --batch_size 1 \
  --precompute_conditions \
  --train_steps 10 \
  --rank 128 \
  --lora_alpha 128 \
  --target_modules to_q to_k to_v to_out.0 \
  --gradient_accumulation_steps 1 \
  --gradient_checkpointing \
  --checkpointing_steps 5 \
  --checkpointing_limit 2 \
  --resume_from_checkpoint=latest \
  --enable_slicing \
  --enable_tiling"

# Optimizer arguments
optimizer_cmd="--optimizer adamw \
  --use_8bit_bnb \
  --lr 3e-5 \
  --lr_scheduler constant_with_warmup \
  --lr_warmup_steps 100 \
  --lr_num_cycles 1 \
  --beta1 0.9 \
  --beta2 0.95 \
  --weight_decay 1e-4 \
  --epsilon 1e-8 \
  --max_grad_norm 1.0"

# Validation arguments
validation_cmd="--validation_prompts \"$ID_TOKEN A black and white animated scene unfolds with an anthropomorphic goat surrounded by musical notes and symbols, suggesting a playful environment. Mickey Mouse appears, leaning forward in curiosity as the goat remains still. The goat then engages with Mickey, who bends down to converse or react. The dynamics shift as Mickey grabs the goat, potentially in surprise or playfulness, amidst a minimalistic background. The scene captures the evolving relationship between the two characters in a whimsical, animated setting, emphasizing their interactions and emotions.@@@49x512x768:::A woman with long brown hair and light skin smiles at another woman with long blonde hair. The woman with brown hair wears a black jacket and has a small, barely noticeable mole on her right cheek. The camera angle is a close-up, focused on the woman with brown hair's face. The lighting is warm and natural, likely from the setting sun, casting a soft glow on the scene. The scene appears to be real-life footage@@@49x512x768\" \
  --num_validation_videos 1 \
  --validation_steps 100"

# Miscellaneous arguments
miscellaneous_cmd="--tracker_name finetrainers-cog \
  --output_dir $OUTPUT_DIR \
  --nccl_timeout 1800 \
  --report_to wandb"

cmd="accelerate launch --config_file accelerate_configs/deepspeed.yaml --gpu_ids $GPU_IDS train.py \
  $model_cmd \
  $dataset_cmd \
  $dataloader_cmd \
  $diffusion_cmd \
  $training_cmd \
  $optimizer_cmd \
  $validation_cmd \
  $miscellaneous_cmd"

echo "Running command: $cmd"
eval $cmd
echo -ne "-------------------- Finished executing script --------------------\n\n"

Run: https://wandb.ai/sayakpaul/finetrainers-cog/runs/he705j4z

Testing Cog 1.5 as well as the 5B variant from 1.0.


@sayakpaul sayakpaul marked this pull request as ready for review December 30, 2024 12:35
return [("video", output)]


def _get_t5_prompt_embeds(
Member Author

In the future, I would love to think of a way to reuse encode_prompt() but that is a battle for a different day.
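For context, a minimal, self-contained sketch of what such a T5 helper typically computes; the function name, the defaults, and the 226-token max length are illustrative assumptions, not the PR's exact code.

import torch
from transformers import AutoTokenizer, T5EncoderModel

def get_t5_prompt_embeds(
    tokenizer: AutoTokenizer,
    text_encoder: T5EncoderModel,
    prompt: str,
    max_sequence_length: int = 226,  # assumed CogVideoX T5 sequence length
    device: torch.device = None,
    dtype: torch.dtype = None,
) -> torch.Tensor:
    device = device or text_encoder.device
    dtype = dtype or text_encoder.dtype
    # Tokenize to a fixed length so embeddings can be batched or precomputed
    inputs = tokenizer(
        [prompt],
        padding="max_length",
        max_length=max_sequence_length,
        truncation=True,
        return_tensors="pt",
    )
    with torch.no_grad():
        prompt_embeds = text_encoder(inputs.input_ids.to(device))[0]
    return prompt_embeds.to(device=device, dtype=dtype)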

Contributor

During fine-tuning, if CFG is not used, should the negative prompt ("") also be encoded here? That is, should we encode both the prompt and negative_prompt = "" for all LoRA/SFT datasets?

Member Author

We randomly drop out the caption or zero it out based on the scheme chosen. So, that should suffice I guess?
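For readers unfamiliar with the two schemes, a minimal sketch of caption dropout; the function names are illustrative, and 0.05 simply mirrors the --caption_dropout_p value in the command above, not finetrainers' exact implementation.

import random
import torch

def drop_caption(caption: str, p: float = 0.05) -> str:
    # "empty" scheme: with probability p, train on an empty caption
    return "" if random.random() < p else caption

def zero_prompt_embeds(prompt_embeds: torch.Tensor, p: float = 0.05) -> torch.Tensor:
    # "zero" scheme: with probability p, zero out the precomputed text embedding
    return torch.zeros_like(prompt_embeds) if random.random() < p else prompt_embeds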

@sayakpaul sayakpaul requested a review from a-r-r-o-w December 30, 2024 12:41
a-r-r-o-w (Member) left a comment

From a quick look, this is different from how I would have tried to implement the CogVideoX integration, and this is where we could introduce some better abstraction for loss calculation. Will do a refactor tomorrow if I find time or whenever I'm back from new year's holidays

As such, the trainer should not contain anything model-specific. If something needs to be implemented per-model, then it should be done via a helper utility invoked from the trainer.

For schedulers, we need to support three kinds (a rough dispatch sketch follows this list):

  • CogVideoXDDPM
  • FlowMatching
  • Normal DDPM (we don't have any model yet, so this is not actionable)
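A hypothetical sketch of what dispatching on these families could look like; the enum and helper below are illustrative only and not existing finetrainers API.

from enum import Enum
import torch

class SchedulerFamily(Enum):
    COGVIDEOX_DDPM = "cogvideox_ddpm"
    FLOW_MATCHING = "flow_matching"
    DDPM = "ddpm"

def sample_timesteps(
    family: SchedulerFamily,
    batch_size: int,
    num_train_timesteps: int = 1000,
    generator: torch.Generator = None,
) -> torch.Tensor:
    if family == SchedulerFamily.FLOW_MATCHING:
        # Flow matching draws continuous sigmas in (0, 1) and maps them to
        # integer timesteps, similar to the `else` branch discussed later.
        sigmas = torch.rand(batch_size, generator=generator)
        return (sigmas * num_train_timesteps).long()
    # Both DDPM variants sample discrete integer timesteps
    return torch.randint(0, num_train_timesteps, (batch_size,), generator=generator)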



def post_latent_preparation(latents: torch.Tensor, **kwargs) -> torch.Tensor:
return {"latents": latents}
Member

We also have to apply the VAE scaling factor here. The same bug exists in Hunyuan Video (will update in a separate PR).
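Concretely, applying the scaling factor amounts to something like the following sketch, consistent with the snippets quoted later in this thread (the helper name is illustrative):

import torch

def scale_latents(latents: torch.Tensor, vae_config) -> torch.Tensor:
    # Some VAE configs (e.g. CogVideoX 1.5) set `invert_scale_latents`, which
    # flips the direction of the scaling compared to CogVideoX 1.0.
    if getattr(vae_config, "invert_scale_latents", False):
        return latents / vae_config.scaling_factor
    return latents * vae_config.scaling_factor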

from diffusers.pipelines.cogvideo.pipeline_cogvideox import get_resize_crop_region_for_grid


def prepare_rotary_positional_embeddings(
Member

In diffusers, we plan to make RoPE part of the modeling components themselves, as done for Flux, Mochi, Hunyuan Video, etc. Will do it for Cog too so that we don't have to do it here.

Member Author

Until that happens, I guess it will stay here?
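For reference, a rough sketch of the grid sizes such a helper derives before calling the RoPE utilities; the constants are the usual CogVideoX defaults (8x spatial VAE compression, 4x temporal compression, patch size 2), and the exact diffusers call signatures are deliberately left out.

def rope_grid_sizes(
    height: int,
    width: int,
    num_frames: int,
    vae_scale_factor_spatial: int = 8,
    patch_size: int = 2,
    temporal_compression_ratio: int = 4,
) -> tuple:
    # Spatial grid of the transformer's patchified latents
    grid_height = height // (vae_scale_factor_spatial * patch_size)
    grid_width = width // (vae_scale_factor_spatial * patch_size)
    # Number of latent frames produced by the causal 3D VAE
    num_latent_frames = (num_frames - 1) // temporal_compression_ratio + 1
    return grid_height, grid_width, num_latent_frames

# e.g. 480x720, 49 frames -> (30, 45, 13)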

@sayakpaul (Member Author)

From a quick look, this is different from how I would have tried to implement the CogVideoX integration, and this is where we could introduce some better abstraction for loss calculation. Will do a refactor tomorrow if I find time or whenever I'm back from new year's holidays

As such, the trainer should not contain anything model-specific. If something needs to be implemented per-model, then it should be done via a helper utility invoked from the trainer.

I think I get the idea you're trying to convey and can do it myself. I will let you do the refactors you mentioned in comments like this.

IMO, in this PR:

  • Implement model-specific changes (Cog) in utilities and have them called from the trainer.
  • Keep the abstraction for loss minimal here, with the fuller version introduced in a separate PR so that we can iterate faster. I have some ideas for that too.

I will revert the generator changes too.

zRzRzRzRzRzRzR (Contributor) left a comment

I asked questions about some of the code and am looking forward to the answers. Thank you for your support of CogVideoX.

@sayakpaul (Member Author)

Thanks @zRzRzRzRzRzRzR! I am working on addressing some of @a-r-r-o-w's feedback and will let both of you know once I have pushed the latest changes.

@sayakpaul (Member Author)

@a-r-r-o-w, in the latest commits, I have

  • Reverted generator related changes.
  • Addressed your comments on post_latent_preparation().
  • Moved model-specific code out of the trainer and into utilities as much as I could.
  • Introduced an abstraction for computing losses.

PTAL and I can iterate further if need be.

@zRzRzRzRzRzRzR, thanks for suggesting the changes related to padding the frames. I have committed your fix here. PTAL; you're welcome to review the other changes too.
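For readers unfamiliar with the fix being referenced: CogVideoX 1.5 patchifies along time (patch_size_t is typically 2), so the number of latent frames has to be padded up to a multiple of it. A hedged sketch of that idea follows; which frame gets repeated is an implementation detail of the actual fix and is not asserted here.

import torch

def pad_frames(latents: torch.Tensor, patch_size_t: int = 2) -> torch.Tensor:
    # latents: [B, C, F, H, W]; pad F up to a multiple of patch_size_t
    if patch_size_t is None or patch_size_t == 1:
        return latents
    num_frames = latents.size(2)
    additional_frames = (-num_frames) % patch_size_t  # 0 if already divisible
    if additional_frames > 0:
        padding = latents[:, :, -1:].repeat(1, 1, additional_frames, 1, 1)
        latents = torch.cat([latents, padding], dim=2)
    return latents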

Comment on lines 705 to 715
if "calculate_timesteps" in self.model_config.keys():
timesteps = self.model_config["calculate_timesteps"](
scheduler=self.scheduler,
latent_conditions=latent_conditions,
generator=self.state.generator
)
else: # As flow-based calculations are more common for now.
sigmas = self.derive_sigmas_for_flow(batch_size=batch_size, scheduler_sigmas=scheduler_sigmas)
timesteps = (sigmas * 1000.0).long()

# Same reason as above.
Member Author

Because the else branch is more common right now, it made sense to do it this way. Open to changes.

a-r-r-o-w (Member) left a comment

Okay to merge now. Just wanted to do a few 5000-step runs on each of the following cases to ensure everything is working as expected:

  • Does SNR weighting for Cog have to be 1 / (1 - alphas_cumprod)? This seems incorrect in both the original implementation and our old script; SNR should be calculated as alphas_cumprod / (1 - alphas_cumprod) (see the sketch at the end of this comment). If the latter works better, that is what we should use.
  • I've commented out the invert_scale_latents bits. I think it is required only for Image-to-Video training, but will do a test run to confirm.
  • CogVideoX 1.0 LoRA with current SNR implementation
  • CogVideoX 1.5 LoRA with current SNR implementation

TODO:

  • Make SNR weighting opt-in. This should be done in a follow-up PR, as I would like to add the different commonly used options for DDIM weighting and test all of them.
  • Update the README like we have for LTX and Hunyuan. We should think about creating a docs/ folder with model-specific guides and training-run examples.

I believe Image-to-Video is not supported yet. Let's do it in a follow-up PR.
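As referenced in the first checklist item above, a small sketch of the two weightings being compared, assuming a DDPM-style scheduler that exposes alphas_cumprod:

import torch

def snr_weights(alphas_cumprod: torch.Tensor, timesteps: torch.Tensor):
    ac = alphas_cumprod[timesteps]
    weight_old = 1.0 / (1.0 - ac)   # what the original script used
    snr = ac / (1.0 - ac)           # the standard signal-to-noise ratio
    return weight_old, snr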

sayakpaul (Member Author) commented on Jan 3, 2025

@a-r-r-o-w, I don't appreciate an intervention like this. I have tried to respect the comments and have done my best to address and sort them out in good spirit.

The changes you have added went in without any discussion or review whatsoever. I don't think this is explained by "we want to ship fast".

a-r-r-o-w (Member) commented on Jan 3, 2025

I thought you said above that you were okay with me taking the refactors up.

To address why the implementation before my latest commits was a no-go:

  • When asked for a review, I assume the PR is complete, but multi-resolution training did not even run. The code had batch_size, num_channels, num_frames, height, width in forward_pass. This creates RoPE with incorrect dimensions, because it should have been batch_size, num_frames, num_channels, height, width (a minimal illustration follows this list).
  • Why is the loss part of the trainer? In my mind, it should be completely decoupled, because it only depends on the training objective, and it should not be part of the trainer. If we have N different trainers, we will then have N different implementations of the same thing (of course, we would then refactor it back to the same/similar state as it is now). I've made it clear that we need to decouple things and introduce good abstractions as needed; this leaves a single true source of any errors, which will help us localize and debug faster when we jump to multiple trainers.
  • In my review comment above, I mentioned that post_latent_preparation should only be called during precomputation. It was still being called in both cases when I was asked for the latest review, and yet the comment was marked as resolved. What does "resolved" mean here? Either:
    • you have addressed it and reverted to the original implementation, or
    • since I mentioned in the review comment that I would take it up, you are okay with me taking it up.
  • I do not understand the point of having separate weighting schemes, or separate timestep preparation, per model. I have never seen this in practice. These should be decoupled completely from the trainer and the modeling utilities. This PR served as one example where a new abstraction should be introduced, but the new abstraction was not okay IMHO.
  • What is the need for another denoiser_config when we already have self.transformer_config?
  • Cloning alphas_cumprod at every forward pass: why should this be different from what we do for sigmas, where we maintain one copy in the trainer?
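To illustrate the first bullet's point about the latent layout (the tensor sizes below are just an example for a 49-frame, 480x720 input):

import torch

# Latents reaching forward_pass are laid out as [B, F, C, H, W]
# (see the permute in post_latent_preparation quoted later in the thread).
latents = torch.randn(1, 13, 16, 60, 90)

# Wrong: swaps the frame and channel axes, so RoPE gets the wrong grid sizes
# batch_size, num_channels, num_frames, height, width = latents.shape

# Correct: matches the [B, F, C, H, W] layout
batch_size, num_frames, num_channels, height, width = latents.shape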

The reason for doing it this way is:

  • I want to eventually implement model parallelism. This would involve pipelining different stages. If you couple things the way they were before (different models having different loss calculations, invoking utility functions within themselves), it is going to be a much harder refactor later than it is now. From the very beginning, I've made it clear many times that the trainer should invoke utilities. Each utility function should do one thing and one thing only, for its respective stage of the preparation/training loop. Following this design will make it easy for me when I get to it.
  • I want to implement FSDP without debugging nightmares from all the weird code paths introduced. As said before, the accelerator and state should not be passed around. The trainer should invoke utilities with the appropriate params; the utilities should have no knowledge of them.
  • Instead of spending more time trying to convey what I mean (like I'm doing with this long message), we want to ship fast, so I did exactly that and tried to move the PR to completion. We've already spent a lot of time on something that is not "new", and I think that is a waste of time compared to arguably more important features for users.

When I said "we want to ship fast", I understood most of how this was to be designed for minimal coupling between different moving parts. If I had implemented it, this would have been a straight shot 30 mins of work because:

  • I implemented the CogVideoX training scripts and know exactly what changes would be needed, and
  • I implemented the Trainer API with exactly this design in mind: decoupling the trainer, models, loss, etc.

From working with all the existing trainers for image models, I have some ideas in mind that I would like to re-use and some that I want to introduce. For me, it is clear how to do this fast and move fast, and the changes that you see are exactly how I would have done it first-try. So, in that sense, I do think we are moving fast.

That said, I was under the assumption that it was okay for me to take up the refactor, and only then made the changes. With that in mind, I will no longer modify your PRs unless explicitly granted permission. Explaining, iterating, and testing all of this would just have taken a lot more time, and I want to prioritize more important things:

  • Layerwise Upcasting
  • Groupwise Offloading
  • Custom triton kernels

This will help us truly be a "memory-efficient" trainer, because currently we are nowhere near that for the decent resolutions people want to use.

Eventually, the functions that we use as "model utilities" can be made into ModelSpec classes with some syntactic sugar, so that we don't have to work with all the ignored kwargs and the other hacks made to support the three models we have now.
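A hypothetical sketch of that ModelSpec idea; nothing like this exists in the codebase yet, and fields not quoted elsewhere in this thread (e.g. prepare_conditions) are placeholders.

from dataclasses import dataclass
from typing import Any, Callable, Dict, Optional

@dataclass(frozen=True)
class ModelSpec:
    # Each field replaces one entry of today's per-model dict of functions.
    prepare_conditions: Callable[..., Dict[str, Any]]
    prepare_latents: Callable[..., Dict[str, Any]]
    post_latent_preparation: Callable[..., Dict[str, Any]]
    forward_pass: Callable[..., Dict[str, Any]]
    calculate_timesteps: Optional[Callable[..., Any]] = None  # optional model-specific override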

@sayakpaul (Member Author)

In my review comment above, I mentioned that post_latent_preparation should only be called during precomputation. It was still being called in both cases when I was asked for latest review and the comment was marked as resolved. What does it mean to be resolved here?

I think the latent scaling is required regardless of whether you're precomputing or not, no? But I see that you have moved it to prepare_latents() and post_latent_preparation() accordingly. That is fine by me.

@sayakpaul sayakpaul merged commit b8352ab into main Jan 3, 2025
@sayakpaul sayakpaul deleted the cog branch January 3, 2025 05:24
@sayakpaul (Member Author)

When I said "we want to ship fast", I understood most of how this was to be designed for minimal coupling between different moving parts. If I had implemented it, this would have been a straight shot 30 mins of work because:

Yeah, of course. As also mentioned over DM, if you have collaborators on the repo who are willing to take up work with good intentions, it doesn't hurt to allow them a bit more time and to let them grow in doing so. If I were being slow, I would not have adjusted this PR itself quickly to address your feedback.

It's not repetitive, either. This was the first PR I wanted to take up, and it would have been a good opportunity for me to develop muscle memory for the design we're hoping to have. It doesn't hurt to give this much allowance.

a-r-r-o-w (Member) commented on Jan 3, 2025

I think the latent scaling is required regardless of whether you're precomputing or not, no? But I see that you have moved it to prepare_latents() and post_latent_preparation() accordingly. That is fine by me.

The code for training on non-precomputed latents and multiplying by the scaling factor (in prepare_latents) was not added by me; my commits don't contain any such changes. At the time I was asked for review, the scaling factor was being multiplied twice in the non-precomputation code path: once in prepare_latents and once in post_latent_preparation (via a call to _scale_latents), which was also wrong. It was only correct when using precomputed latents was enabled.

That said, sincere apologies for the misunderstanding on my part that you were okay with me taking up the refactor.

@sayakpaul (Member Author)

At the time I was asked for review, the scaling factor was being multiplied twice in the non-precomputation code path: once in prepare_latents and once in post_latent_preparation (via a call to _scale_latents)

It was properly guarded:

def post_latent_preparation(latents: torch.Tensor, **kwargs) -> torch.Tensor:
    if kwargs.get("precompute_conditions", False) and kwargs.get("vae_config", None) is not None:
        latents = _scale_latents(latents, kwargs.get("vae_config"))
    latents = _pad_frames(latents, kwargs.get("denoier_config", None))
    latents = latents.permute(0, 2, 1, 3, 4)  # [B, F, C, H, W]
    return {"latents": latents}

Isn't this exactly what was done in this commit?

def prepare_latents(
    vae: AutoencoderKLCogVideoX,
    image_or_video: torch.Tensor,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
    generator: Optional[torch.Generator] = None,
    precompute: bool = False,
    **kwargs,
) -> torch.Tensor:
    device = device or vae.device
    dtype = dtype or vae.dtype
    if image_or_video.ndim == 4:
        image_or_video = image_or_video.unsqueeze(2)
    assert image_or_video.ndim == 5, f"Expected 5D tensor, got {image_or_video.ndim}D tensor"
    image_or_video = image_or_video.to(device=device, dtype=vae.dtype)
    image_or_video = image_or_video.permute(0, 2, 1, 3, 4)  # [B, C, F, H, W]
    if not precompute:
        latents = vae.encode(image_or_video).latent_dist.sample(generator=generator)
        if not vae.config.invert_scale_latents:
            latents = latents * vae.config.scaling_factor
        else:

...
def post_latent_preparation(vae_config: Dict[str, Any], latents: torch.Tensor, **kwargs) -> torch.Tensor:
    if not vae_config.invert_scale_latents:
        latents = latents * vae_config.scaling_factor
    else:
        latents = 1 / vae_config.scaling_factor * latents
    latents = _pad_frames(latents, kwargs.get("denoier_config", None))
    latents = latents.permute(0, 2, 1, 3, 4)  # [B, F, C, H, W]
    return {"latents": latents}

Or am I terribly misunderstanding something?

@a-r-r-o-w a-r-r-o-w mentioned this pull request Jan 3, 2025