Add validation and batched inference to flux #1205
base: flux-train
Conversation
…ng dataset index logic
Hi @CarlosGomes98! Thank you for your pull request and welcome to our community.
Action Required: In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.
Process: In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.
Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with
If you have received this in error or have any questions, please contact us at [email protected]. Thanks!
If we take the first 30_000 samples as the validation dataset, will it overlap with the training dataset?
An alternative is to specify the splits via data_files, e.g. dataset = load_dataset("json", data_files={"train": base_url + "train-v1.1.json", "validation": base_url + "dev-v1.1.json"}, field="data") (see https://huggingface.co/docs/datasets/en/loading), if we are loading the dataset from Hugging Face directly.
If we are loading data locally, we could keep an _info.json (https://huggingface.co/datasets/pixparse/cc12m-wds/blob/main/_info.json) to specify the train / validation split.
Yes, it will. This was just a temporary solution to functionally verify the validation loop. I wanted to ask if you had some insights on how we should include the COCO 2014 dataset, given that it's not easily available on the HF Hub.
Would we add download instructions to the README and load it locally?
Do you want to use the COCO dataset because of the stable diffusion paper? I think we should keep it simple and just cut some part of the cc12m dataset to serve as the validation set.
return result


def _coco2014_data_processor(
If we are not using the COCO dataset as the validation set right now, we should remove this function.
continue
except (UnicodeDecodeError, SyntaxError, OSError) as e:
In my training, I added this line to capture data loading errors, e.g. a corrupted image header when PIL.Image is reading, a corrupted .tar file header, etc.
Makes sense, I must have removed it by accident.
@@ -176,9 +237,9 @@ def denoise(
    # create positional encodings
    POSITION_DIM = 3
    latent_pos_enc = create_position_encoding_for_latents(
-        bsz, latent_height, latent_width, POSITION_DIM
+        1, latent_height, latent_width, POSITION_DIM
Can you explain this line a little bit more? Why do we change batch_size to 1 here?
The changes in this method from bsz to 1 allow the denoise method to handle batches of images plus the possible doubling of the batch size due to classifier-free guidance.
In this case, since all the images share the same position encoding, we can set the batch dimension to 1 and let PyTorch broadcasting expand it to whatever the actual batch dimension turns out to be.
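The broadcasting behavior described here can be sketched with NumPy, which follows the same broadcasting rules as PyTorch (the shapes below are illustrative, not the real flux latent shapes):

```python
import numpy as np

# Position encoding built once with batch dimension 1.
latent_height, latent_width, position_dim = 4, 4, 3
pos_enc = np.zeros((1, latent_height * latent_width, position_dim))

# With classifier-free guidance the effective batch doubles
# (conditional + unconditional latents), i.e. 2 * bsz.
bsz = 5
latents = np.ones((2 * bsz, latent_height * latent_width, position_dim))

# Broadcasting stretches the size-1 batch dimension to match latents.
combined = latents + pos_enc
print(combined.shape)  # (10, 16, 3)
```

Because the leading dimension of `pos_enc` is 1, the same tensor works unchanged whether or not the batch was doubled for guidance.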
@wwwjn, would it be possible for you to test this batched inference code, with and without classifier-free guidance, on your trained model? The code runs functionally, but I have not tested whether it correctly produces images, as I don't have a properly trained checkpoint.
output_name = os.path.join(output_dir, name)
# bring into PIL format and save
x = x.clamp(-1, 1)
x = rearrange(x[0], "c h w -> h w c")
if len(x.shape) == 4:
Why do we add len(x.shape) == 4 here? Under which cases will this happen?
This change allows the save_image method to correctly handle a single image passed with or without the batch dimension. In the current code, the image to be saved must always be passed with 4 dimensions, from which we take x[0].
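A minimal sketch of the shape handling being described (the function name and the clamp range are assumptions based on the diff, not the exact torchtitan code):

```python
import numpy as np

def to_hwc_image(x):
    """Accept a single image as CHW or as a 1-image NCHW batch."""
    if len(x.shape) == 4:   # batch dimension present: take the first image
        x = x[0]
    x = np.clip(x, -1, 1)   # clamp to the model's output range
    return np.transpose(x, (1, 2, 0))  # "c h w -> h w c"

batched = np.zeros((1, 3, 8, 8))
single = np.zeros((3, 8, 8))
print(to_hwc_image(batched).shape, to_hwc_image(single).shape)  # (8, 8, 3) (8, 8, 3)
```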
torchtitan/train.py
Outdated
    time.perf_counter() - data_load_start
)


def batch_generator(
I think it's better to separate this diff out, as we did before (#1138); it's easier to track changes separately.
Agree, will revert, and we can merge this in later.
if (
    parallel_dims.dp_replicate_enabled
    or parallel_dims.dp_shard_enabled
    or parallel_dims.cp_enabled
nit: currently we are not enabling CP for the Flux model. We could remove this line.
I just copied this from the train step. For consistency, I would either keep it or remove it in both places.
torchtitan/experiments/flux/train.py
Outdated
    self.step, force=(self.step == job_config.training.steps)
)

if self.step % job_config.eval.eval_freq == 0 and job_config.eval.dataset:
I think we could wrap all these parts into eval(), and make the main train() loop easier to read.
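The suggested refactor could look roughly like the sketch below. All names here (eval_freq, the eval() body) are stand-ins for the real job-config fields and validation logic, not the actual torchtitan API:

```python
class TrainerSketch:
    """Illustrative only: pull validation out of the train loop body."""

    def __init__(self, steps, eval_freq):
        self.steps = steps
        self.eval_freq = eval_freq
        self.eval_calls = []  # record when eval ran, for demonstration

    def eval(self):
        # validation loss, image generation, metric logging, etc.
        # would live here instead of inline in train()
        self.eval_calls.append(self.step)

    def train(self):
        for self.step in range(1, self.steps + 1):
            # ... forward / backward / optimizer step ...
            if self.step % self.eval_freq == 0:
                self.eval()

trainer = TrainerSketch(steps=10, eval_freq=4)
trainer.train()
print(trainer.eval_calls)  # [4, 8]
```

The payoff is that train() reads as a sequence of named phases rather than one long block.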
torchtitan/experiments/flux/train.py
Outdated
return global_loss_per_timestep, global_timestep_counts


@record
def inference(self, prompts: list[str], bs: int = 1):
Can we add a line in the README, or add a file similar to run_train.sh, to run the inference?
Or, a better way would be to move the inference code outside of train.py and create another subclass of FluxTrainer() to do so.
With simplicity in mind, I think I agree with your second suggestion.
I'm somewhat conflicted about using Trainer for inference. On one hand, we definitely would like to reuse all the logic for model loading and parallelization.
On the other hand, it forces us to do things like loading the training dataset, which doesn't really make much sense.
For now I have left it like that, but in the future creating a more lightweight Trainer-like class for inference only may be better.
torchtitan/experiments/flux/train.py
Outdated
results = torch.cat(results, dim=0)
return results


def generate_and_save_images(self, inputs) -> torch.Tensor:
Can we move this function to sampling.py and reuse some of the functions there? We could calculate empty_batch in train.py and pass it into the function call.
In general, we want to keep train.py simple and clean to read.
torchtitan/experiments/flux/train.py
Outdated
)
return images


def generate_val_timesteps(self, cur_val_timestep, samples):
Also, this can be moved to sampling.py.
This one, I would argue, belongs here. It is really only relevant for validation during the training process, not for sampling in general.
Thanks for working on this!
It seems a lot of good stuff is being added. While I can clearly see the value of most changes, to be honest it's a bit difficult for reviewers to keep track of all the changes and their motivations.
Do you think it's doable to split the changes into several PRs, each with its own theme and documentation as PR summary / doc string / comments?
@@ -21,6 +21,36 @@
from torchtitan.tools.utils import device_module, device_type


def dist_collect(
If the only difference is the .item() call, we should just move it out of _dist_reduce and reuse that function wherever you would use this dist_collect.
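The suggested refactor can be sketched as follows. The real torchtitan helpers wrap a functional collective (all_reduce); here the collective is stubbed out as an identity so the control flow is visible, and the function names are assumptions based on the comment:

```python
import numpy as np

def _dist_reduce(x, reduce_op):
    # stand-in for funcol.all_reduce(x, reduceOp=reduce_op, ...);
    # in a single process the reduction is a no-op
    return x

def dist_mean(x):
    # existing-style helper: returns a Python scalar for logging
    return _dist_reduce(x, "avg").item()

def dist_collect(x):
    # proposed variant: identical except it skips .item(),
    # so it can simply reuse _dist_reduce
    return _dist_reduce(x, "avg")

loss = np.array(3.0)
print(dist_mean(loss), dist_collect(loss))
```

Keeping one reduction helper avoids duplicating the collective logic just to change the return type.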
name="flux",
cls=FluxModel,
config=flux_configs,
parallelize_fn=parallelize_flux,
pipelining_fn=None,
build_optimizers_fn=build_optimizers,
build_lr_schedulers_fn=build_lr_schedulers,
-build_dataloader_fn=build_flux_dataloader,
+build_dataloader_fn=build_flux_train_dataloader,
I think this aligns with my proposal to do validation in torchtitan (not just for flux but also for other models). #1210
I would hope we can take a more principled approach and make general improvements, instead of doing an ad hoc change here.
Absolutely. I wanted to enable this functionality for flux ASAP, so this is hacky.
Since it will involve changes to some central components in torchtitan, I didn't want to attempt a full implementation just yet, and I'm not sure I'd have the bandwidth for this, especially if it's work that someone is already doing or plans on doing.
I'm happy to remove the validation dataset bit and wait for a proper implementation to be added to main. Until then, the validation metrics I added in this PR could instead target a subset of the training set, for example.
) -> Optional[torch.Tensor]:
    """Process CC12M image to the desired size."""

    width, height = img.size
    # Skip low resolution images
-    if width < output_size or height < output_size:
+    if skip_low_resolution and (width < output_size or height < output_size):
The code still seems right with this flag being False: the smaller dimension will be enlarged to output_size, and the other dimension will be enlarged proportionally and then cropped.
But is this used anywhere?
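The resize-then-crop behavior described in the comment reduces to simple size arithmetic. This is a sketch of that reasoning, not the actual preprocessing code; the center-crop choice and the function name are assumptions:

```python
def resized_then_cropped(width, height, output_size):
    """Scale so the smaller side equals output_size (the other side
    scales proportionally), then crop back down to output_size."""
    scale = output_size / min(width, height)
    new_w, new_h = round(width * scale), round(height * scale)
    # after the crop, both dimensions are exactly output_size
    return min(new_w, output_size), min(new_h, output_size)

# A low-resolution 100x200 image gets enlarged to 256x512, then cropped.
print(resized_then_cropped(100, 200, 256))  # (256, 256)
print(resized_then_cropped(640, 480, 256))  # (256, 256)
```

So even when skip_low_resolution is False, the output shape is consistent; the only cost is upsampling artifacts in small images.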
@@ -106,14 +109,14 @@ def _cc12m_wds_data_processor(
result = {
    "image": img,
    "clip_tokens": clip_tokens,  # type: List[int]
-    "t5_tokens": t5_tokens,  # type: List[int]
+    "t5_tokens": t5_tokens,  # type: List[int],
+    "txt": sample["txt"],
Hmm, why are we adding this?
@@ -285,43 +292,50 @@ def __init__(

    # Variables for checkpointing
    self._sample_idx = 0
    self._all_samples: list[dict[str, Any]] = []
    self._epoch = 0
What's the purpose of adding this variable, and in general what's the purpose of these changes around the data loader?
dp_world_size: int,
dp_rank: int,
job_config: JobConfig,
# This parameter is not used, keep it for compatibility
tokenizer: FluxTokenizer | None,
infinite: bool = True,
include_sample_id: bool = False,
batch_size: int = 4,
Why this magic number?
torchtitan/experiments/flux/infer.py
Outdated
Why are we adding this file, and what's its relationship with https://github.com/pytorch/torchtitan/blob/main/torchtitan/experiments/flux/tests/test_generate_image.py or https://github.com/pytorch/torchtitan/blob/main/torchtitan/experiments/flux/sampling.py?
"""Common MSE loss function for Transformer models training."""
-return torch.nn.functional.mse_loss(pred.float(), labels.float().detach())
+return torch.nn.functional.mse_loss(pred.float(), labels.float().detach(), reduction=reduction)
I'm maybe not 100% sure, but I think FSDP / PP doesn't work with a reduction other than "mean".
Also, why do we need to alter this?
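For context, the reduction semantics under discussion can be illustrated with a plain-Python sketch (the real call is torch.nn.functional.mse_loss; this just shows what "none" vs "mean" returns, e.g. for per-timestep validation losses):

```python
def mse_loss(pred, labels, reduction="mean"):
    """Toy MSE over flat lists; mirrors the reduction= semantics."""
    per_elem = [(p - t) ** 2 for p, t in zip(pred, labels)]
    if reduction == "none":
        return per_elem  # one squared error per element, unreduced
    return sum(per_elem) / len(per_elem)

print(mse_loss([1.0, 3.0], [0.0, 1.0]))                    # 2.5
print(mse_loss([1.0, 3.0], [0.0, 1.0], reduction="none"))  # [1.0, 4.0]
```

The unreduced form is what lets validation attribute loss to individual samples or timesteps, which a pre-averaged scalar cannot do.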
Yes, it did grow a bit out of hand. I can definitely split it at least into inference and validation. Will see if I can make it more granular than that.
@CarlosGomes98 one quick note is
Ideally we would also add COCO 2014 as a dataset. However, I haven't been able to find an HF dataset containing both the images and the captions. So, for now, I've added a dataset which is just the first 30k samples of the training dataset, for functional verification.
This also includes changes from #1138.