LTX Video #123

Merged: 11 commits merged into main on Dec 18, 2024

Conversation

@a-r-r-o-w (Member) commented on Dec 16, 2024:

Requires huggingface/diffusers#10228

WIP after rewrite

#!/bin/bash

# export TORCH_LOGS="+dynamo,recompiles,graph_breaks"
# export TORCHDYNAMO_VERBOSE=1
export WANDB_MODE="offline"
export NCCL_P2P_DISABLE=1
export TORCH_NCCL_ENABLE_MONITORING=0

GPU_IDS="2,3"

DATA_ROOT="/raid/aryan/video-dataset-disney"
CAPTION_COLUMN="prompts.txt"
VIDEO_COLUMN="videos.txt"

# Model arguments
model_cmd="--model_name ltx_video \
  --pretrained_model_name_or_path Lightricks/LTX-Video"

# Dataset arguments
dataset_cmd="--data_root $DATA_ROOT \
  --video_column $VIDEO_COLUMN \
  --caption_column $CAPTION_COLUMN \
  --id_token afkx \
  --video_resolution_buckets 49x480x768 \
  --caption_dropout_p 0.05"

# Dataloader arguments
dataloader_cmd="--dataloader_num_workers 0"

# Diffusion arguments
diffusion_cmd="--flow_resolution_shifting"

# Training arguments
training_cmd="--training_type lora \
  --seed 42 \
  --mixed_precision bf16 \
  --batch_size 1 \
  --train_steps 2000 \
  --rank 128 \
  --lora_alpha 64 \
  --target_modules to_q to_k to_v to_out.0 \
  --gradient_accumulation_steps 1 \
  --gradient_checkpointing \
  --checkpointing_steps 500 \
  --checkpointing_limit 2 \
  --enable_slicing \
  --enable_tiling"

# Optimizer arguments
optimizer_cmd="--optimizer adamw \
  --lr 1e-4 \
  --scale_lr \
  --lr_scheduler cosine_with_restarts \
  --lr_warmup_steps 100 \
  --lr_num_cycles 1 \
  --beta1 0.9 \
  --beta2 0.95 \
  --weight_decay 1e-4 \
  --epsilon 1e-8 \
  --max_grad_norm 1.0"

# Validation arguments
validation_cmd="--validation_prompts \"afkx A black and white animated scene unfolds with an anthropomorphic goat surrounded by musical notes and symbols, suggesting a playful environment. Mickey Mouse appears, leaning forward in curiosity as the goat remains still. The goat then engages with Mickey, who bends down to converse or react. The dynamics shift as Mickey grabs the goat, potentially in surprise or playfulness, amidst a minimalistic background. The scene captures the evolving relationship between the two characters in a whimsical, animated setting, emphasizing their interactions and emotions.@@@49x480x768:::A woman with long brown hair and light skin smiles at another woman with long blonde hair. The woman with brown hair wears a black jacket and has a small, barely noticeable mole on her right cheek. The camera angle is a close-up, focused on the woman with brown hair's face. The lighting is warm and natural, likely from the setting sun, casting a soft glow on the scene. The scene appears to be real-life footage@@@49x480x768\" \
  --num_validation_videos 1 \
  --validation_steps 100"

# Miscellaneous arguments
miscellaneous_cmd="--tracker_name finetrainers-ltxv \
  --output_dir /raid/aryan/ltx-video \
  --nccl_timeout 1800 \
  --report_to wandb"

cmd="accelerate launch --config_file accelerate_configs/uncompiled_2.yaml --gpu_ids $GPU_IDS train.py \
  $model_cmd \
  $dataset_cmd \
  $dataloader_cmd \
  $diffusion_cmd \
  $training_cmd \
  $optimizer_cmd \
  $validation_cmd \
  $miscellaneous_cmd"

echo "Running command: $cmd"
eval $cmd
echo -ne "-------------------- Finished executing script --------------------\n\n"

        self.state.model_name = self.args.model_name
        self.model_config = get_config_from_model_name(self.args.model_name)

    def get_memory_statistics(self, precision: int = 3) -> Dict[str, Any]:
sayakpaul (Member):

I think this could be moved to utils. Or we could use a class to keep track of it:
https://github.com/huggingface/peft/blob/ae55fdcc5c4830e0f9fb6e56f16555bafca392de/examples/oft_dreambooth/train_dreambooth.py#L421

Personally, I like the latter approach. Will also check with accelerate folks if they can ship it from accelerate.

a-r-r-o-w (Member, Author):

Yeah, sounds good. For now, I've moved it to a separate file, memory_utils.py. We can refactor later.
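
For reference, a minimal sketch of what such a util could look like (the actual contents of memory_utils.py are not shown in this PR; the names here are illustrative):

# memory_utils.py -- illustrative sketch, not the actual file contents
from typing import Any, Dict

import torch


def bytes_to_gigabytes(x: int, precision: int = 3) -> float:
    return round(x / 1024**3, precision)


def get_memory_statistics(precision: int = 3) -> Dict[str, Any]:
    # Collect CUDA allocator statistics for the current device.
    device = torch.cuda.current_device()
    return {
        "memory_allocated": bytes_to_gigabytes(torch.cuda.memory_allocated(device), precision),
        "memory_reserved": bytes_to_gigabytes(torch.cuda.memory_reserved(device), precision),
        "max_memory_allocated": bytes_to_gigabytes(torch.cuda.max_memory_allocated(device), precision),
        "max_memory_reserved": bytes_to_gigabytes(torch.cuda.max_memory_reserved(device), precision),
    }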

logger = get_logger("finetrainers")
logger.setLevel(FINETRAINERS_LOG_LEVEL)

class Trainer:
@sayakpaul (Member) commented on Dec 16, 2024:

This could provide util methods just like our Mixin classes do, and we could make methods like prepare_models() abstract methods that raise a NotImplementedError as needed. So we'd have something like CogVideoXTrainer(Trainer).

This way we delegate the abstractions better, IMO, and get rid of nasty if/else blocks.
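
A minimal sketch of the proposed subclassing pattern (class and method names are illustrative, not a final API):

# Illustrative sketch of the proposed pattern; names are hypothetical.
class Trainer:
    def prepare_models(self) -> None:
        # Model-specific setup is delegated to subclasses.
        raise NotImplementedError("Subclasses must implement prepare_models()")

    def prepare_dataset(self) -> None:
        raise NotImplementedError("Subclasses must implement prepare_dataset()")


class CogVideoXTrainer(Trainer):
    def prepare_models(self) -> None:
        # Load the CogVideoX-specific components here.
        ...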

a-r-r-o-w (Member, Author):

Yes, that sounds good. I would like to refactor this and make it easier to use, but for this PR, let's keep the sole focus on LTXV. Once I get to Hunyuan and Mochi, it will be easier to find what needs abstraction and how best to do it. And yes, providing abstractions for others to implement their own custom methods outside our codebase sounds good.

sayakpaul (Member):

Sure. I can take either Mochi or Hunyuan to free up your plate a bit.

a-r-r-o-w (Member, Author):

Awesome! Would you like to do Mochi since I have already started on Hunyuan locally?

sayakpaul (Member):

Leave it with me. Over n' out.

        self.dataloader = torch.utils.data.DataLoader(
            self.dataset,
            batch_size=1,
            sampler=BucketSampler(self.dataset, batch_size=self.args.batch_size, shuffle=True),
sayakpaul (Member):

Bucket sampling could be made configurable.

a-r-r-o-w (Member, Author):

For now, I will keep it as-is and work on the abstractions for this later.
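
For context, the idea behind bucket sampling here is to group samples by resolution bucket so that every batch is shape-homogeneous. A minimal sketch of that grouping follows; it is not the repo's actual BucketSampler (which yields batch lists to a batch_size=1 DataLoader), and dataset.get_bucket(idx) is a hypothetical helper returning a sample's (frames, height, width) bucket:

# Minimal sketch of resolution-bucket sampling; not the repo's implementation.
import random
from collections import defaultdict

from torch.utils.data import Sampler


class BucketSampler(Sampler):
    def __init__(self, dataset, batch_size: int, shuffle: bool = True):
        self.batch_size = batch_size
        self.shuffle = shuffle
        # Group sample indices by their resolution bucket.
        self.buckets = defaultdict(list)
        for idx in range(len(dataset)):
            self.buckets[dataset.get_bucket(idx)].append(idx)  # hypothetical helper

    def __iter__(self):
        batches = []
        for indices in self.buckets.values():
            if self.shuffle:
                random.shuffle(indices)
            # Keep only full batches so every batch stays within a single bucket.
            for i in range(0, len(indices) - self.batch_size + 1, self.batch_size):
                batches.append(indices[i : i + self.batch_size])
        if self.shuffle:
            random.shuffle(batches)
        for batch in batches:
            yield from batch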


        # TODO: refactor
        # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
        def save_model_hook(models, weights, output_dir):
sayakpaul (Member):

DeepSpeed is not handled in this case.

a-r-r-o-w (Member, Author):

Yes, I've not added any DeepSpeed support yet. It was very fragile, and I want to take it up in another PR so that any model can be used easily without if-else hell.
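
For reference, DeepSpeed-aware save hooks in similar diffusers training scripts typically guard on the distributed type, roughly like the sketch below; this is not this PR's code, and `accelerator`, `transformer`, and the pipeline class are assumed to be in scope (normally the hook is defined as a closure over them):

# Sketch of a DeepSpeed-aware LoRA save hook, in the style of diffusers
# training scripts. `accelerator`, `transformer`, and `LTXPipeline` are
# assumed to be in the enclosing scope.
from accelerate.utils import DistributedType
from peft.utils import get_peft_model_state_dict

def save_model_hook(models, weights, output_dir):
    if accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED:
        transformer_lora_layers_to_save = None
        for model in models:
            if isinstance(model, type(accelerator.unwrap_model(transformer))):
                transformer_lora_layers_to_save = get_peft_model_state_dict(model)
            if weights:
                # Pop so accelerate does not additionally serialize the full model.
                weights.pop()
        LTXPipeline.save_lora_weights(
            output_dir, transformer_lora_layers=transformer_lora_layers_to_save
        )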


        transformer_lora_config = LoraConfig(
            r=self.args.rank,
            lora_alpha=self.args.lora_alpha,
sayakpaul (Member):

Just saying that alphas different from LoRA ranks aren't supported during diffusers-formatted LoRA loading as we don't serialize the metadata.

@a-r-r-o-w (Member, Author) commented on Dec 17, 2024:

Yes, but that's okay, I think. The recommendation from a recent paper was to always train LoRAs with alpha set to half the rank, and it makes sense to just provide this enablement. Users will have to remember their training settings. In our README, we can talk about how to set the appropriate scale via attention_kwargs or the set_adapters method.
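
For example, with the rank-128 / alpha-64 setup from the script above, a user would re-apply the 64/128 = 0.5 scale themselves at load time, roughly like this (a sketch; it assumes a diffusers version with LTX LoRA support, and the checkpoint path and adapter name are illustrative):

# Sketch: re-applying a trained alpha/rank scale after loading a
# diffusers-format LoRA (which does not store the alpha metadata).
import torch
from diffusers import LTXPipeline

pipe = LTXPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16)
pipe.load_lora_weights("path/to/trained/lora", adapter_name="ltxv-lora")
# alpha / rank = 64 / 128 = 0.5 must be set manually by the user.
pipe.set_adapters(["ltxv-lora"], adapter_weights=[0.5])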

            transformer_lora_layers=transformer_lora_layers_to_save,
        )

        def load_model_hook(models, input_dir):
sayakpaul (Member):

Missing DeepSpeed support.

@a-r-r-o-w (Member, Author) commented on Dec 17, 2024:

Will take this up in another PR, after the more immediate concern of supporting Hunyuan.

        tracker_name = self.args.tracker_name or "finetrainers-experiment"
        self.state.accelerator.init_trackers(tracker_name, config=self.args.to_dict())

    def train(self) -> None:
sayakpaul (Member):

I think this could be factored to include methods like the following:

  • prepare_inputs_for_loss()
  • compute_loss()

Both of these differ in Mochi-1 from the standard ones like Cog.

a-r-r-o-w (Member, Author):

That sounds good to me. For now, I will keep this as-is and look into how best to abstract these details when I get to Mochi, Cog, and Hunyuan. I think we could design it nicely around the most common training strategies.
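
A sketch of how that factoring might look, using the method names suggested above (the bodies and the transformer call are placeholders, not a final API):

# Illustrative factoring of the training step; names follow the suggestion above.
class Trainer:
    def prepare_inputs_for_loss(self, batch):
        # Model-specific: sample timesteps, noise the latents, build conditioning.
        raise NotImplementedError

    def compute_loss(self, model_pred, batch):
        # Model-specific: e.g. a flow-matching target vs. epsilon-prediction.
        raise NotImplementedError

    def train_step(self, batch):
        model_inputs = self.prepare_inputs_for_loss(batch)
        model_pred = self.transformer(**model_inputs).sample  # placeholder call
        return self.compute_loss(model_pred, batch)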

Comment on lines 418 to 439
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
    progress_bar.update(1)
    global_step += 1

    # Checkpointing
    if accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED:
        logger.info(f"Checkpointing at step {global_step}")
        if global_step % self.args.checkpointing_steps == 0:
            # _before_ saving state, check if this save would set us over the `checkpointing_limit`
            if self.args.checkpointing_limit is not None:
                checkpoints = os.listdir(self.args.output_dir)
                checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
                checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))

                # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
                if len(checkpoints) >= self.args.checkpointing_limit:
                    num_to_remove = len(checkpoints) - self.args.checkpointing_limit + 1
                    removing_checkpoints = checkpoints[0:num_to_remove]

                    logger.info(
                        f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
sayakpaul (Member):

This could be factored out into a separate method.

a-r-r-o-w (Member, Author):

Cleaned it up a bit, but yeah, I will refactor the entire thing in a follow-up as it's not the most important thing.
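
One way the factoring could look, as a hypothetical helper extracted from the logic above (not the code merged here):

# Hypothetical helper factoring out the checkpoint-limit logic above.
import os
import shutil


def purge_old_checkpoints(output_dir: str, checkpointing_limit: int, logger) -> None:
    checkpoints = sorted(
        (d for d in os.listdir(output_dir) if d.startswith("checkpoint")),
        key=lambda x: int(x.split("-")[1]),
    )
    # Keep at most `checkpointing_limit - 1` so the upcoming save stays within the limit.
    if len(checkpoints) >= checkpointing_limit:
        num_to_remove = len(checkpoints) - checkpointing_limit + 1
        for checkpoint in checkpoints[:num_to_remove]:
            logger.info(f"Removing old checkpoint: {checkpoint}")
            shutil.rmtree(os.path.join(output_dir, checkpoint))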



# Default values copied from https://github.com/huggingface/diffusers/blob/8957324363d8b239d82db4909fbf8c0875683e3d/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py#L47
def resolution_dependant_timestep_flow_shift(
sayakpaul (Member):

Suggested change:
- def resolution_dependant_timestep_flow_shift(
+ def resolution_dependent_timestep_flow_shift(
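
For context, the resolution-dependent shift used by the linked scheduler interpolates a shift value mu from the token sequence length and then warps the sigmas toward higher noise levels for longer sequences. A sketch of the idea (defaults taken from the linked file; not necessarily this PR's exact implementation):

# Sketch of resolution-dependent flow shift; mirrors the idea in the linked
# diffusers scheduler, not necessarily this PR's exact implementation.
import math

import torch


def resolution_dependent_timestep_flow_shift(
    sigmas: torch.Tensor,
    image_seq_len: int,
    base_image_seq_len: int = 256,
    max_image_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
) -> torch.Tensor:
    # Linearly interpolate the shift `mu` from the token sequence length.
    m = (max_shift - base_shift) / (max_image_seq_len - base_image_seq_len)
    b = base_shift - m * base_image_seq_len
    mu = image_seq_len * m + b
    # Apply the exponential time shift to the sigmas (expected in (0, 1]).
    return math.exp(mu) / (math.exp(mu) + (1 / sigmas - 1))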

@sayakpaul (Member) left a comment:

This is a terrific start! I left some comments on the initial structure which I think should be easy to incorporate.

@a-r-r-o-w marked this pull request as ready for review on December 17, 2024 at 01:36.
@@ -0,0 +1,779 @@
import argparse
sayakpaul (Member):

Do we wanna change to fine_video_trainers?

from .constants import DEFAULT_IMAGE_RESOLUTION_BUCKETS, DEFAULT_VIDEO_RESOLUTION_BUCKETS


class Args:

@a-r-r-o-w merged commit 9ef58e2 into main on Dec 18, 2024.
@a-r-r-o-w deleted the ltxv branch on December 18, 2024 at 22:57.