😈 This repository offers an unofficial PyTorch implementation of the paper Mean Flows for One-step Generative Modeling, building upon Just-a-DiT and EzAudio.
💬 Contributions and feedback are very welcome — feel free to open an issue or pull request if you spot something or have ideas!
🛠️ This codebase is kept as clean and minimal as possible for easier integration into your own projects — thus, experiment-tracking frameworks like Wandb are intentionally excluded.
🚀 Check out our recent project on stylized text-to-speech: 🧢 CapSpeech.
(Yes, this is an ad 😂)
MNIST -- 10k training steps, 1-step sample result:
MNIST -- 6k training steps, 1-step CFG (w=2.0) sample result:
CIFAR-10 -- 200k training steps, 1-step CFG (w=2.0) sample result:
- Implement basic training and inference (see the loss sketch after this list)
- Enable multi-GPU training via 🤗 Accelerate
- Add support for Classifier-Free Guidance (CFG)
- Integrate latent image representation support
- Add tricks such as the improved CFG described in the paper's appendix
- Improve code clarity and structure, following 🤗 Diffusers style
- Extend to additional modalities (e.g., audio, speech)
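At the core of training is the MeanFlow identity, u(z_t, r, t) = v(z_t, t) − (t − r) du/dt, where the total time derivative is computed with forward-mode autodiff. Below is a minimal sketch of the loss under some simplifying assumptions: `model(z, r, t)` predicting the average velocity, uniform sampling of `r ≤ t`, and a plain MSE (the paper uses an adaptive weighting) — this is an illustration, not this repo's exact code.

```python
import torch
from torch.func import jvp

def meanflow_loss(model, x):
    # model(z, r, t) is assumed to predict the average velocity u(z_t, r, t).
    b = x.shape[0]
    e = torch.randn_like(x)                       # noise endpoint of the path
    t = torch.rand(b, device=x.device)            # t ~ U[0, 1]
    r = torch.rand(b, device=x.device) * t        # r ~ U[0, t], so r <= t
    t_ = t.view(-1, *([1] * (x.dim() - 1)))       # broadcast to sample shape
    r_ = r.view_as(t_)
    z = (1.0 - t_) * x + t_ * e                   # linear interpolation z_t
    v = e - x                                     # instantaneous velocity v(z_t, t)

    # Total derivative d/dt u(z_t, r, t) along the path, via a single JVP with
    # tangents (dz/dt, dr/dt, dt/dt) = (v, 0, 1).
    u, dudt = jvp(model, (z, r, t), (v, torch.zeros_like(r), torch.ones_like(t)))

    u_tgt = v - (t_ - r_) * dudt                  # MeanFlow identity target
    return torch.mean((u - u_tgt.detach()) ** 2)  # stop-gradient on the target
```

One-step sampling then jumps across the whole interval at once: `x = z - model(z, torch.zeros(b), torch.ones(b))` with `z` drawn from a standard Gaussian. Note that this forward-mode `jvp` is also the source of the Flash Attention incompatibility and extra memory usage listed in the limitations below.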
- `jvp` is incompatible with Flash Attention and likely also with Triton, Mamba, and similar libraries.
- `jvp` significantly increases GPU memory usage, even when using `torch.utils.checkpoint`.
- CFG is implemented implicitly, which leads to some limitations (see the sketch after this list):
  - The CFG scale is fixed at training time and cannot be adjusted during inference.
  - Negative prompts, such as the "noise" or "low quality" prompts commonly used in text-to-image diffusion models, are not supported.
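The reason the scale is frozen: in MeanFlow's implicit CFG, the guidance scale `w` enters the training target itself, roughly as sketched below. This is a simplified illustration; the class-conditional signature `model(z, r, t, y)` and the `null_class` handling are our assumptions.

```python
# Simplified sketch of how w is baked into the target (not this repo's exact code).
v = e - x                                    # conditional instantaneous velocity
u_uncond = model(z, t, t, y=null_class)      # u(z_t, t, t) with the class dropped
v_tilde = w * v + (1.0 - w) * u_uncond       # guided velocity; w is a training-time constant
# v_tilde then replaces v in the MeanFlow target, so the trained network already
# encodes w and a single unguided forward pass yields w-guided samples.
```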
If you find this repo helpful or interesting, consider dropping a ⭐ — it really helps and means a lot!