
Add validation and batched inference to flux #1205


Open · wants to merge 22 commits into base: flux-train

Conversation

@CarlosGomes98 (Contributor) commented May 19, 2025

  • Add val loss
  • Add batched inference

Ideally we would also add COCO2014 as a dataset. However, I haven't been able to find an HF dataset containing both the images and the captions. So, for now, I've added a dataset which is just the first 30k samples of the training dataset, for functional verification.

This also includes changes from #1138

@facebook-github-bot

Hi @CarlosGomes98!

Thank you for your pull request and welcome to our community.

Action Required

In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (e.g., your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.

If you have received this in error or have any questions, please contact us at [email protected]. Thanks!

Contributor

If we take the first 30,000 samples as the validation dataset, won't it overlap with the training dataset?

An alternative is to specify the data_files explicitly, e.g. dataset = load_dataset("json", data_files={"train": base_url + "train-v1.1.json", "validation": base_url + "dev-v1.1.json"}, field="data") (see https://huggingface.co/docs/datasets/en/loading), if we are loading the dataset from the Hugging Face Hub directly.

If we are loading data locally, we could keep an _info.json (https://huggingface.co/datasets/pixparse/cc12m-wds/blob/main/_info.json) to specify the train/validation split.
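
As a concrete sketch of the first option: the base_url value below is filled in from the SQuAD example in the linked Hugging Face docs, not from this repo.

```python
from datasets import load_dataset

# Explicit data_files give `datasets` a real train/validation split up front.
base_url = "https://rajpurkar.github.io/SQuAD-explorer/dataset/"
dataset = load_dataset(
    "json",
    data_files={
        "train": base_url + "train-v1.1.json",
        "validation": base_url + "dev-v1.1.json",
    },
    field="data",
)
train_ds, val_ds = dataset["train"], dataset["validation"]
```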

Contributor Author

Yes, it will. This was just a temporary solution to functionally verify the validation loop. I wanted to ask if you had some insights on how we should include the COCO2014 dataset, given that it's not easily available on the HF Hub.

Would we add download instructions to the README and load it locally?

Contributor

Do you want to use the COCO dataset because of the Stable Diffusion paper? I think we should keep it simple and just cut out part of the cc12m dataset to serve as the validation set.
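
For instance, a minimal sketch of that carve-out, assuming the datasets split-slicing syntax; the slices are disjoint, so validation never overlaps training (the 30k size is only illustrative):

```python
from datasets import load_dataset

# Disjoint slices of the same split: no train/validation overlap.
val_ds = load_dataset("pixparse/cc12m-wds", split="train[:30000]")
train_ds = load_dataset("pixparse/cc12m-wds", split="train[30000:]")
```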


return result

def _coco2014_data_processor(
Contributor

If we are not using the COCO dataset as the validation set right now, we should remove this function.

continue
except (UnicodeDecodeError, SyntaxError, OSError) as e:
Contributor

In my training, I added this line to capture data loading errors, e.g. a corrupted image header when PIL.Image is reading, a corrupted .tar file header, etc.
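
Roughly the shape of that guard, as a sketch rather than the PR's exact code (the "jpg" key is the usual webdataset field name, assumed here):

```python
from io import BytesIO

from PIL import Image


def decode_or_skip(sample):
    try:
        img = Image.open(BytesIO(sample["jpg"]))
        img.load()  # force a full decode so truncated files fail here
        return img
    except (UnicodeDecodeError, SyntaxError, OSError):
        return None  # caller skips the corrupted sample and continues
```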

Contributor Author

Makes sense, I must have removed it by accident

@@ -176,9 +237,9 @@ def denoise(
# create positional encodings
POSITION_DIM = 3
latent_pos_enc = create_position_encoding_for_latents(
-        bsz, latent_height, latent_width, POSITION_DIM
+        1, latent_height, latent_width, POSITION_DIM
Contributor

Can you explain this line a little more: why do we change the batch size to 1 here?

Contributor Author

The changes in this method from bsz to 1 allow the denoise method to deal with batches of images, plus the possible doubling of the batch size due to classifier-free guidance.

In this case, since all samples share the same position encoding, we can set the batch dimension to 1 and let PyTorch broadcasting expand it to whatever the batch dimension turns out to be.
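
To illustrate the broadcasting argument (shapes below are hypothetical; only the batch-dim mechanics matter):

```python
import torch

bsz = 4                                    # CFG may double this to 8
latents = torch.randn(2 * bsz, 1024, 64)   # (batch, seq, dim)
pos_enc = torch.randn(1, 1024, 64)         # built once with batch dim 1

# The size-1 batch dim broadcasts to whatever the runtime batch is,
# so the same encoding works with or without CFG doubling.
out = latents + pos_enc
assert out.shape == latents.shape
```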

Contributor Author

@wwwjn would it be possible for you to test this batched inference code, with and without classifier-free guidance, on your trained model? The code functionally runs, but I have not verified that it produces correct images, as I don't have a properly trained checkpoint.

output_name = os.path.join(output_dir, name)
# bring into PIL format and save
x = x.clamp(-1, 1)
x = rearrange(x[0], "c h w -> h w c")
if len(x.shape) == 4:
Contributor

Why do we add len(x.shape) == 4 here? In which cases will this happen?

Contributor Author

This change allows the save_image method to correctly handle being passed a single image with or without the batch dimension. In the current code, the image to be saved must always be passed with 4 dimensions, from which we take x[0].
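
A sketch of the described behavior (the body is simplified; the [-1, 1] → [0, 255] mapping is an assumption about the output convention):

```python
import torch
from einops import rearrange
from PIL import Image


def save_image(x: torch.Tensor, output_name: str) -> None:
    if len(x.shape) == 4:  # batched input: save the first image
        x = x[0]
    x = x.clamp(-1, 1)
    x = rearrange(x, "c h w -> h w c")
    img = Image.fromarray((127.5 * (x + 1.0)).byte().cpu().numpy())
    img.save(output_name)
```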

time.perf_counter() - data_load_start
)

def batch_generator(
Contributor

I think it's better to separate this diff out, as we did before with #1138; it's easier to track changes separately.

Contributor Author

Agree, will revert and we can merge this in later

@facebook-github-bot added the CLA Signed label (managed by the Meta Open Source bot) on May 19, 2025
if (
parallel_dims.dp_replicate_enabled
or parallel_dims.dp_shard_enabled
or parallel_dims.cp_enabled
Contributor

nit: currently we are not enabling CP for the Flux model. We could remove this line.

Contributor Author

I just copied this from the train step. For consistency I would either keep it or remove it in both places.

self.step, force=(self.step == job_config.training.steps)
)

if self.step % job_config.eval.eval_freq == 0 and job_config.eval.dataset:
Contributor

I think we could wrap all these parts into eval(), and make the main train() loop easier to read.
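
For example, one reading of the suggestion (the method name is hypothetical; the condition is the one from the snippet above):

```python
class FluxTrainer:
    def maybe_eval(self, job_config) -> None:
        # Pull the periodic validation check out of train() so the
        # main loop stays flat and easy to read.
        if self.step % job_config.eval.eval_freq == 0 and job_config.eval.dataset:
            self.eval()  # the validation loop added in this PR
```

train() would then call self.maybe_eval(job_config) where the inlined block sits today.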

return global_loss_per_timestep, global_timestep_counts

@record
def inference(self, prompts: list[str], bs: int = 1):
Contributor

Can we add a line in the README, or add a file similar to run_train.sh, to run the inference?

Contributor

Or a better way is to move the inference code out of train.py and create another subclass of FluxTrainer() to do so.

Contributor Author

With simplicity in mind, I think I agree with your second suggestion.

Contributor Author

I'm somewhat conflicted about using Trainer for inference. On one hand, we definitely would like to re-use all the logic for model loading and parallelization.

On the other hand, it forces us to do things like loading the training dataset, which doesn't really make much sense.

For now I have left it like that, but in the future creating a more lightweight Trainer-like class for inference only may be better.


results = torch.cat(results, dim=0)
return results

def generate_and_save_images(self, inputs) -> torch.Tensor:
Contributor

Can we move this function to sampling.py, and reuse some of the functions there? We could calculate empty_batch in train.py, and pass it to the function call.

In general, we want to keep train.py simple and clean to read.

)
return images

def generate_val_timesteps(self, cur_val_timestep, samples):
Contributor

Also, this can be moved to sampling.py

Contributor Author

This one I would argue belongs here: it is really only relevant for validation during the training process, not for sampling in general.
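
For readers following along, one plausible shape for such a helper, purely as a guess at its intent (the bucket count and the continuous timesteps in (0, 1) are illustrative assumptions, not the PR's code):

```python
import torch


def generate_val_timesteps(cur_val_timestep: int, samples: int,
                           num_buckets: int = 100) -> tuple[torch.Tensor, int]:
    # Sweep timestep buckets deterministically across validation batches,
    # wrapping around, so per-timestep val loss is comparable across runs.
    idx = (cur_val_timestep + torch.arange(samples)) % num_buckets
    timesteps = (idx.float() + 0.5) / num_buckets  # bucket centers in (0, 1)
    next_start = int(cur_val_timestep + samples) % num_buckets
    return timesteps, next_start
```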

Contributor

@tianyu-l left a comment

Thanks for working on this!

It seems a lot of good stuff is being added. While I can clearly sense the value of most changes, to be honest it's a bit difficult for reviewers to keep track of all the changes and their motivations.

Do you think it's doable to split the changes into several PRs, each with its own theme and documentation as PR summary / doc string / comments?

@@ -21,6 +21,36 @@
from torchtitan.tools.utils import device_module, device_type


def dist_collect(
Contributor

If the only difference is the .item() call, we should just move it out of _dist_reduce and reuse that function where you'd use this dist_collect.
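
A sketch of that refactor (_dist_reduce and dist_collect are from the diff; the mean-style scalar wrapper name and the "avg" reduce op are assumed): _dist_reduce returns the reduced tensor, and scalar callers add .item() themselves.

```python
import torch
import torch.distributed._functional_collectives as funcol
from torch.distributed.device_mesh import DeviceMesh


def _dist_reduce(x: torch.Tensor, reduceOp: str, mesh: DeviceMesh) -> torch.Tensor:
    # Return the reduced tensor itself; no .item() here.
    return funcol.all_reduce(x, reduceOp=reduceOp, group=mesh)


def dist_mean(x: torch.Tensor, mesh: DeviceMesh) -> float:
    return _dist_reduce(x, "avg", mesh).item()  # scalar, as before


def dist_collect(x: torch.Tensor, mesh: DeviceMesh) -> torch.Tensor:
    return _dist_reduce(x, "avg", mesh)  # tensor-valued variant
```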

name="flux",
cls=FluxModel,
config=flux_configs,
parallelize_fn=parallelize_flux,
pipelining_fn=None,
build_optimizers_fn=build_optimizers,
build_lr_schedulers_fn=build_lr_schedulers,
-    build_dataloader_fn=build_flux_dataloader,
+    build_dataloader_fn=build_flux_train_dataloader,
Contributor

I think this aligns with my proposal to do validation in torchtitan (not just for Flux but also for other models); see #1210.
I would hope we can take a more principled approach and make general improvements, instead of doing an ad hoc change here.

Contributor Author

Absolutely. I wanted to enable this functionality for flux asap, so this is hacky.

Since it will involve changes to some central components in torchtitan, I didn't want to attempt a full implementation just yet, and I'm not sure I'd have the bandwidth for this, especially if it's work that someone is already doing / plans on doing.

I'm happy to remove the validation dataset bit and wait on a proper implementation being added to main. Until then, the validation metrics I added in this PR could instead target a subset of the training set, for example.

) -> Optional[torch.Tensor]:
"""Process CC12M image to the desired size."""

width, height = img.size
# Skip low resolution images
-    if width < output_size or height < output_size:
+    if skip_low_resolution and (width < output_size or height < output_size):
Contributor

The code still seems right with this flag set to False: the smaller dimension will be enlarged to output_size and the other dimension will be enlarged proportionally and then cropped.

But is this used anywhere?
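
For reference, the resize path being described, as a standalone sketch (function name assumed):

```python
from PIL import Image


def resize_and_center_crop(img: Image.Image, output_size: int) -> Image.Image:
    # Scale so the *smaller* side reaches output_size (this enlarges
    # low-resolution images when skipping is disabled), then center-crop.
    width, height = img.size
    scale = output_size / min(width, height)
    img = img.resize((round(width * scale), round(height * scale)))
    left = (img.width - output_size) // 2
    top = (img.height - output_size) // 2
    return img.crop((left, top, left + output_size, top + output_size))
```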

@@ -106,14 +109,14 @@ def _cc12m_wds_data_processor(
result = {
"image": img,
"clip_tokens": clip_tokens, # type: List[int]
"t5_tokens": t5_tokens, # type: List[int]
"t5_tokens": t5_tokens, # type: List[int],
"txt": sample["txt"],
Contributor

Hmm, why add this?

@@ -285,43 +292,50 @@ def __init__(

# Variables for checkpointing
self._sample_idx = 0
self._all_samples: list[dict[str, Any]] = []
self._epoch = 0
Contributor

What's the purpose of adding this variable, and in general, what's the purpose of these changes around the data loader?

dp_world_size: int,
dp_rank: int,
job_config: JobConfig,
# This parameter is not used, keep it for compatibility
tokenizer: FluxTokenizer | None,
infinite: bool = True,
include_sample_id: bool = False,
batch_size: int = 4,
Contributor

Why this magic number?


"""Common MSE loss function for Transformer models training."""
-    return torch.nn.functional.mse_loss(pred.float(), labels.float().detach())
+    return torch.nn.functional.mse_loss(pred.float(), labels.float().detach(), reduction=reduction)
Contributor

I'm not 100% sure, but I think FSDP / PP don't work with a reduction other than "mean".
Also, why do we need to alter this?
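
For context on why a reduction argument helps at all: with reduction="none" the loss stays per-element, which is what per-timestep validation bucketing needs (a sketch with illustrative shapes):

```python
import torch
import torch.nn.functional as F

pred = torch.randn(8, 16, 64)    # (batch, seq, dim), shapes illustrative
labels = torch.randn(8, 16, 64)

# reduction="none" keeps one loss value per element...
per_elem = F.mse_loss(pred.float(), labels.float().detach(), reduction="none")
# ...which can then be reduced per sample, e.g. to bucket loss by timestep.
per_sample = per_elem.mean(dim=(1, 2))
assert per_sample.shape == (8,)
```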

@CarlosGomes98 (Contributor Author)

> Thanks for working on this!
>
> It seems a lot of good stuff is being added. While I can clearly sense the value of most changes, to be honest it's a bit difficult for reviewers to keep track of all the changes and their motivations.
>
> Do you think it's doable to split the changes into several PRs, each with its own theme and documentation as PR summary / doc string / comments?

Yes, it did grow a bit out of hand. I can definitely split it at least into inference and validation. I'll see if I can make it more granular than that.

@wwwjn (Contributor) commented May 21, 2025

@CarlosGomes98 one quick note: flux-train is a little behind the main branch, so let's just resolve the comments and create a PR against the main branch instead.
