
Remove hard-coded batch size configuration in LTXModel #340


Merged
merged 1 commit on Mar 20, 2025

Conversation

@Linyou (Contributor) commented Mar 20, 2025

No description provided.

@a-r-r-o-w (Member) commented:

Shouldn't broadcasting automatically handle batch_size > 1 case?

@Linyou (Contributor, Author) commented Mar 20, 2025

The broadcasting isn't the issue here.

When batch_size > 1, latent_mean and latent_std carry a leading batch_size dimension. Reshaping them via .view(1, -1, 1, 1, 1) flattens that batch dimension into the second axis, so the per-sample values end up concatenated along the channel dimension and the reshaped tensors no longer broadcast correctly against the latents.
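
A minimal sketch of the reshape problem, with one way of keeping the batch dimension explicit (tensor names and shapes here are illustrative assumptions, not the exact LTXModel code):

```python
import torch

batch_size, num_channels, frames, height, width = 2, 128, 8, 16, 16
latents = torch.randn(batch_size, num_channels, frames, height, width)

# Per-channel statistics that arrive with a leading batch dimension (illustrative).
latents_mean = torch.randn(batch_size, num_channels)
latents_std = torch.rand(batch_size, num_channels) + 0.5

# Hard-coding the leading dimension to 1 folds the batch into the channel axis:
# [1, batch_size * num_channels, 1, 1, 1] no longer broadcasts against
# [batch_size, num_channels, F, H, W].
bad_mean = latents_mean.view(1, -1, 1, 1, 1)
print(bad_mean.shape)  # torch.Size([1, 256, 1, 1, 1])
# (latents - bad_mean)  # RuntimeError: size mismatch on dim 1 (128 vs 256)

# Keeping the batch dimension explicit preserves broadcasting.
good_mean = latents_mean.view(batch_size, -1, 1, 1, 1)
good_std = latents_std.view(batch_size, -1, 1, 1, 1)
normalized = (latents - good_mean) / good_std
print(normalized.shape)  # torch.Size([2, 128, 8, 16, 16])
```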

@a-r-r-o-w (Member) left a comment


Oh, I see. Makes sense, thanks a lot! I'll try to handle this better in the future, since we shouldn't need the same latents_mean and latents_std replicated multiple times for batch_size > 1.

@a-r-r-o-w merged commit e9cf70d into huggingface:main on Mar 20, 2025
1 check passed
@a-r-r-o-w (Member) commented:

@Linyou Are you able to pass the batch_size > 1 test for LTX Video? I merged the PR earlier since it seemed correct, but I was running tests today and this fails:

torchrun --nnodes=1 --nproc_per_node 2 -m pytest -s tests/trainer/test_sft_trainer.py -k "test___dp_degree_2___batch_size_2 and ___PTD and LoRA and LTXVideo"

I see that we ignore latents_mean and latents_std for collation (see here), so they should not be replicated when batch_size > 1 and should work with broadcasting. Am I missing something important in my understanding of this?
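
For reference, a minimal sketch of the broadcasting case being described, where the statistics stay 1-D and are not replicated per sample (names and shapes are illustrative assumptions):

```python
import torch

batch_size, num_channels = 2, 128
latents = torch.randn(batch_size, num_channels, 8, 16, 16)

# Un-collated per-channel statistics: plain 1-D tensors of shape [num_channels].
latents_mean = torch.randn(num_channels)
latents_std = torch.rand(num_channels) + 0.5

# With 1-D statistics, .view(1, -1, 1, 1, 1) broadcasts for any batch size.
normalized = (latents - latents_mean.view(1, -1, 1, 1, 1)) / latents_std.view(1, -1, 1, 1, 1)
print(normalized.shape)  # torch.Size([2, 128, 8, 16, 16])
```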
