Unofficial PyTorch implementation of the paper "Mean Flows for One-step Generative Modeling" by Geng et al.

haidog-yaqub/MeanFlow

MeanFlow

😈 This repository offers an unofficial PyTorch implementation of the paper Mean Flows for One-step Generative Modeling, building upon Just-a-DiT and EzAudio.

💬 Contributions and feedback are very welcome — feel free to open an issue or pull request if you spot something or have ideas!

🛠️ This codebase is kept as clean and minimal as possible so that it is easy to integrate into your own projects. For that reason, frameworks like Wandb are intentionally excluded.

📢 Shameless Plug (Ad Spot)

🚀 Check out our recent project on stylized text-to-speech: 🧢 CapSpeech.

(Yes, this is an ad 😂)

Examples

MNIST -- 10k training steps, 1-step sample result:

MNIST

MNIST -- 6k training steps, 1-step CFG (w=2.0) sample result:

MNIST-cfg

CIFAR-10 -- 200k training steps, 1-step CFG (w=2.0) sample result:

CIFAR-10-cfg
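
For reference, 1-step sampling as described in the paper boils down to a single network call: start from noise at t = 1 and subtract the predicted average velocity over the whole interval [0, 1]. Below is a minimal sketch, assuming a network called as `model(z, r, t)` that returns the average velocity with the same shape as `z`. The function name `sample_one_step` and that argument order are illustrative assumptions and may not match this repo's API.

```python
import torch


@torch.no_grad()
def sample_one_step(model, shape, generator=None):
    """Draw samples with a single network evaluation (illustrative sketch).

    model(z, r, t) is ASSUMED to return the average velocity u(z, r, t).
    """
    eps = torch.randn(shape, generator=generator)  # z_1: pure noise at t = 1
    batch = shape[0]
    r = torch.zeros(batch)  # start of the time interval
    t = torch.ones(batch)   # end of the time interval
    # z_r = z_t - (t - r) * u(z_t, r, t); with r = 0 and t = 1 the factor is 1.
    return eps - model(eps, r, t)
```

A call such as `sample_one_step(model, (16, 1, 28, 28))` would return a batch of 16 MNIST-sized tensors. Class conditioning is left out here to keep the sketch short.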

TODO

  • Implement basic training and inference
  • Enable multi-GPU training via 🤗 Accelerate
  • Add support for Classifier-Free Guidance (CFG)
  • Integrate latent image representation support
  • Add tricks such as the improved CFG described in the paper's appendix
  • Improve code clarity and structure, following 🤗 Diffusers style
  • Extend to additional modalities (e.g., audio, speech)

Known Issues (PyTorch)

  • jvp is incompatible with Flash Attention and likely also with Triton, Mamba, and similar libraries.
  • jvp significantly increases GPU memory usage, even when using torch.utils.checkpoint.
  • CFG is implemented implicitly, leading to some limitations:
    • The CFG scale is fixed at training time and cannot be adjusted during inference.
    • Negative prompts (such as "noise" or "low quality", commonly used in text-to-image diffusion models) are not supported.
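
To make the jvp-related issues above more concrete, here is a rough sketch of where the jvp enters a MeanFlow-style loss. It assumes `torch.func.jvp`, a network called as `model(z, r, t)`, and the linear path z_t = (1 - t) x + t ε from the paper. The name `meanflow_loss` and the argument order are hypothetical, and the plain squared error stands in for whatever loss weighting the actual training code uses.

```python
import torch
from torch.func import jvp


def meanflow_loss(model, x, t, r, eps=None):
    """Sketch of a MeanFlow regression loss for one batch (not this repo's code).

    model(z, r, t) is ASSUMED to return the average velocity u with z's shape;
    t and r have shape (batch,) with 0 <= r <= t <= 1.
    """
    eps = torch.randn_like(x) if eps is None else eps
    # Broadcast the per-sample times over the remaining data dimensions.
    shape = (-1,) + (1,) * (x.dim() - 1)
    t_, r_ = t.view(shape), r.view(shape)

    z = (1 - t_) * x + t_ * eps  # point on the straight path from data to noise
    v = eps - x                  # instantaneous velocity along that path

    # Total derivative d/dt u(z_t, r, t) = v * du/dz + 0 * du/dr + 1 * du/dt,
    # computed in one forward-mode pass with tangents (v, 0, 1).
    tangents = (v, torch.zeros_like(r), torch.ones_like(t))
    u, dudt = jvp(model, (z, r, t), tangents)

    # MeanFlow identity used as a regression target; detach is the stop-gradient.
    u_tgt = (v - (t_ - r_) * dudt).detach()
    return ((u - u_tgt) ** 2).mean()
```

The jvp call is the part that needs forward-mode support from every op in the network, which is where fused attention kernels tend to break. As for the CFG limitation: in the paper, guidance is folded into the training target by mixing `v` with an unconditional prediction using a fixed scale, which is consistent with the scale being baked in at training time here.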

🌟 Like This Project?

If you find this repo helpful or interesting, consider dropping a ⭐ — it really helps and means a lot!
