LTX Video #123
Conversation
finetrainers/trainer.py (Outdated)
self.state.model_name = self.args.model_name
self.model_config = get_config_from_model_name(self.args.model_name)

def get_memory_statistics(self, precision: int = 3) -> Dict[str, Any]:
I think this could be moved to utils. Or we could use a class to keep track of it:
https://github.com/huggingface/peft/blob/ae55fdcc5c4830e0f9fb6e56f16555bafca392de/examples/oft_dreambooth/train_dreambooth.py#L421
Personally, I like the latter approach. Will also check with the accelerate folks if they can ship it from accelerate.
Yeah, sounds good. For now, I've moved it to a different file, memory_utils.py. We can refactor later.
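A plausible shape for such a utility, shown only as a sketch (the actual contents of memory_utils.py may differ):

import torch
from typing import Any, Dict


def get_memory_statistics(precision: int = 3) -> Dict[str, Any]:
    # Report current and peak CUDA memory usage in GiB, rounded to `precision` digits.
    if not torch.cuda.is_available():
        return {}
    gib = 1024**3
    return {
        "memory_allocated": round(torch.cuda.memory_allocated() / gib, precision),
        "max_memory_allocated": round(torch.cuda.max_memory_allocated() / gib, precision),
        "max_memory_reserved": round(torch.cuda.max_memory_reserved() / gib, precision),
    }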
logger = get_logger("finetrainers")
logger.setLevel(FINETRAINERS_LOG_LEVEL)

class Trainer:
This could provide util methods just like our Mixin classes do, and we could make methods like prepare_models() abstract methods that raise a NotImplementedError as needed. So, we'd have something like CogVideoXTrainer(Trainer). This way we delegate the abstractions better, IMO, and get rid of nasty if/else blocks.
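A rough sketch of that shape (the subclass and method names here are illustrative, not what the PR currently implements):

class Trainer:
    # Shared utilities (logging, checkpointing, memory stats, ...) live here.

    def prepare_models(self) -> None:
        raise NotImplementedError("Each model-specific trainer must load its own components.")

    def prepare_dataset(self) -> None:
        raise NotImplementedError


class CogVideoXTrainer(Trainer):
    def prepare_models(self) -> None:
        # Load the CogVideoX tokenizer, text encoder, VAE and transformer here.
        ...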
Yes, that sounds good. I would like to refactor this and make it easier to use, but for this PR, let's keep the sole focus on LTXV. Once I get to Hunyuan and Mochi, it will be easier to find what needs abstraction and how best to do it. Also, yes, providing abstractions for others to implement their own custom methods that are not part of our codebase sounds good.
Sure. I can take either Mochi or Hunyuan to free up your plate a bit.
Awesome! Would you like to do Mochi since I have already started on Hunyuan locally?
Leave it with me. Over n' out.
self.dataloader = torch.utils.data.DataLoader(
    self.dataset,
    batch_size=1,
    sampler=BucketSampler(self.dataset, batch_size=self.args.batch_size, shuffle=True),
Bucket sampling could be made configurable.
For now, I will keep it as is and work on the abstractions for this later.
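If it does become configurable later, one minimal way to express it (the use_bucket_sampler flag is hypothetical and not part of the current Args):

if self.args.use_bucket_sampler:  # hypothetical flag
    self.dataloader = torch.utils.data.DataLoader(
        self.dataset,
        batch_size=1,
        sampler=BucketSampler(self.dataset, batch_size=self.args.batch_size, shuffle=True),
    )
else:
    self.dataloader = torch.utils.data.DataLoader(
        self.dataset,
        batch_size=self.args.batch_size,
        shuffle=True,
    )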
# TODO: refactor
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def save_model_hook(models, weights, output_dir):
DeepSpeed is not handled in this case.
Yes, I've not added any DeepSpeed support yet. This was very fragile, and I want to take it up in another PR so that any model can be used easily without an if-else hell.
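For reference, the kind of guard that similar diffusers training scripts add for DeepSpeed; this is only a sketch of a hook nested inside the trainer (accelerator, self.transformer, and the surrounding context are assumed), and the eventual fix here may look different:

from peft.utils import get_peft_model_state_dict

def save_model_hook(models, weights, output_dir):
    if accelerator.is_main_process:
        transformer_lora_layers_to_save = None
        for model in models:
            # DeepSpeed wraps modules, so compare against the unwrapped class.
            if isinstance(accelerator.unwrap_model(model), type(accelerator.unwrap_model(self.transformer))):
                transformer_lora_layers_to_save = get_peft_model_state_dict(model)
            # Under DeepSpeed, `weights` can arrive empty because the engine manages state itself.
            if weights:
                weights.pop()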
transformer_lora_config = LoraConfig(
    r=self.args.rank,
    lora_alpha=self.args.lora_alpha,
Just saying that alphas different from LoRA ranks aren't supported during diffusers-formatted LoRA loading as we don't serialize the metadata.
Yes, but that's okay, I think. The recommendation from a recent paper was to always train LoRAs with alpha set to half the rank, and it makes sense to just provide this enablement. Users will have to remember their training settings. In our README, we can talk about how to set the appropriate scale via attention_kwargs or the set_adapter methods.
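As an illustration of what that README note could show, assuming a diffusers LTX-Video pipeline and a LoRA trained with alpha = rank / 2 (the adapter name, path, and scale values are illustrative):

import torch
from diffusers import LTXPipeline

pipe = LTXPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16).to("cuda")

# Without serialized metadata, loading assumes alpha == rank, so apply the
# trained alpha/rank ratio (here 0.5) explicitly as the adapter scale.
pipe.load_lora_weights("path/to/lora", adapter_name="ltxv-lora")
pipe.set_adapters(["ltxv-lora"], adapter_weights=[0.5])

# Or pass the scale per call through the pipeline's attention kwargs.
output = pipe(prompt="...", attention_kwargs={"scale": 0.5})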
    transformer_lora_layers=transformer_lora_layers_to_save,
)

def load_model_hook(models, input_dir):
Missing DeepSpeed support.
Will take this up in another PR, after the more immediate concern of supporting Hunyuan.
tracker_name = self.args.tracker_name or "finetrainers-experiment"
self.state.accelerator.init_trackers(tracker_name, config=self.args.to_dict())

def train(self) -> None:
I think this could be factored to include methods like the following:
prepare_inputs_for_loss()
compute_loss()
Both of these differ in Mochi-1 from the standard ones, like Cog's.
That sounds good to me. For now, I will keep this as-is and look into how best to abstract these details when I get to Mochi, Cog, and Hunyuan. I think we could nicely design it around the most common training strategies.
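A sketch of how those hooks could slot into the loop, shown as methods on the Trainer (names mirror the suggestion above; the body is heavily simplified relative to the actual train()):

def prepare_inputs_for_loss(self, batch):
    # Model-specific: encode the batch, sample noise and timesteps, build conditioning.
    raise NotImplementedError

def compute_loss(self, inputs):
    # Model-specific: run the transformer and compute the training objective.
    raise NotImplementedError

def train(self) -> None:
    # The shared loop stays model-agnostic.
    for batch in self.dataloader:
        with self.state.accelerator.accumulate(self.transformer):
            inputs = self.prepare_inputs_for_loss(batch)
            loss = self.compute_loss(inputs)
            self.state.accelerator.backward(loss)
            self.optimizer.step()
            self.optimizer.zero_grad()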
finetrainers/trainer.py (Outdated)
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
    progress_bar.update(1)
    global_step += 1

    # Checkpointing
    if accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED:
        logger.info(f"Checkpointing at step {global_step}")
        if global_step % self.args.checkpointing_steps == 0:
            # _before_ saving state, check if this save would set us over the `checkpointing_limit`
            if self.args.checkpointing_limit is not None:
                checkpoints = os.listdir(self.args.output_dir)
                checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
                checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))

                # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
                if len(checkpoints) >= self.args.checkpointing_limit:
                    num_to_remove = len(checkpoints) - self.args.checkpointing_limit + 1
                    removing_checkpoints = checkpoints[0:num_to_remove]

                    logger.info(
                        f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
This could be factored out into a separate method.
Cleaned it up a bit, but yeah, will refactor the entire thing in a follow-up as it's not the most important thing.
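A sketch of what such a helper might look like (the method name is illustrative; it relies on os and shutil and mirrors the inline logic above):

def _rotate_checkpoints(self) -> None:
    # Keep at most `checkpointing_limit - 1` checkpoints before saving a new one.
    if self.args.checkpointing_limit is None:
        return
    checkpoints = [d for d in os.listdir(self.args.output_dir) if d.startswith("checkpoint")]
    checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
    if len(checkpoints) >= self.args.checkpointing_limit:
        num_to_remove = len(checkpoints) - self.args.checkpointing_limit + 1
        for folder in checkpoints[:num_to_remove]:
            logger.info(f"Removing old checkpoint: {folder}")
            shutil.rmtree(os.path.join(self.args.output_dir, folder))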
# Default values copied from https://github.com/huggingface/diffusers/blob/8957324363d8b239d82db4909fbf8c0875683e3d/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py#L47
def resolution_dependant_timestep_flow_shift(
Suggested change:
- def resolution_dependant_timestep_flow_shift(
+ def resolution_dependent_timestep_flow_shift(
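For context, the computation such a function typically performs, following the linear shift defaults the link above points to (a sketch only; the PR's actual implementation may differ):

def resolution_dependent_timestep_flow_shift(
    latent_seq_len: int,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
) -> float:
    # Linearly interpolate the flow-matching shift `mu` with the latent sequence length,
    # so longer sequences (higher resolutions / more frames) get a larger shift.
    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    b = base_shift - m * base_seq_len
    return latent_seq_len * m + b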
This is a terrific start! I left some comments on the initial structure which I think should be easy to incorporate.
@@ -0,0 +1,779 @@
import argparse
Do we want to change to fine_video_trainers?
from .constants import DEFAULT_IMAGE_RESOLUTION_BUCKETS, DEFAULT_VIDEO_RESOLUTION_BUCKETS


class Args:
This could be managed with a dataclass. Examples:
Requires huggingface/diffusers#10228
WIP after rewrite
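For reference, a minimal sketch of a dataclass-based Args (only fields visible in the diff are used; the defaults are illustrative):

from dataclasses import asdict, dataclass, field
from typing import Any, Dict, List, Optional, Tuple

from .constants import DEFAULT_VIDEO_RESOLUTION_BUCKETS


@dataclass
class Args:
    model_name: str = "ltx_video"
    rank: int = 64
    lora_alpha: int = 64
    batch_size: int = 1
    checkpointing_steps: int = 500
    checkpointing_limit: Optional[int] = None
    output_dir: str = "finetrainers-output"
    tracker_name: Optional[str] = None
    video_resolution_buckets: List[Tuple[int, int, int]] = field(
        default_factory=lambda: list(DEFAULT_VIDEO_RESOLUTION_BUCKETS)
    )

    def to_dict(self) -> Dict[str, Any]:
        return asdict(self)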