Add qwen3 family #1948

Merged (27 commits, Jul 2, 2025)

Commits
3e80e54  Add qwen3 (tianyuan211, Apr 24, 2025)
c847b97  Merge branch 'add-qwen3' of https://github.com/tianyuan211/optimum-ha… (tianyuan211, May 7, 2025)
479203d  Merge branch 'huggingface:main' into add-qwen3 (tianyuan211, May 8, 2025)
87d6726  Merge branch 'huggingface:main' into add-qwen3 (tianyuan211, May 8, 2025)
9ff70d6  fix qwen3 related files (tianyuan211, May 8, 2025)
a15ea22  Merge branch 'huggingface:main' into add-qwen3 (tianyuan211, May 12, 2025)
1b88968  Update modeling_qwen3.py (tianyuan211, May 22, 2025)
0b58788  Merge branch 'huggingface:main' into add-qwen3 (tianyuan211, May 22, 2025)
9fe502f  Merge branch 'huggingface:main' into add-qwen3 (tianyuan211, May 23, 2025)
0f93054  add qwen3 moe (tianyuan211, May 28, 2025)
92166a0  Update test_text_generation_example.json (tianyuan211, May 28, 2025)
44fd14d  Update utils.py (tianyuan211, May 29, 2025)
9056179  Merge branch 'huggingface:main' into add-qwen3 (tianyuan211, Jun 3, 2025)
79ac0ec  Merge branch 'huggingface:main' into add-qwen3 (tianyuan211, Jun 10, 2025)
73b835f  update qwen3 moe (tianyuan211, Jun 10, 2025)
a1447cf  Merge branch 'huggingface:main' into add-qwen3 (tianyuan211, Jun 11, 2025)
2a9d765  Merge branch 'huggingface:main' into add-qwen3 (tianyuan211, Jun 13, 2025)
4a57d4b  Merge branch 'huggingface:main' into add-qwen3 (tianyuan211, Jun 18, 2025)
9526267  Merge branch 'huggingface:main' into add-qwen3 (tianyuan211, Jun 24, 2025)
1736a56  Merge branch 'huggingface:main' into add-qwen3 (tianyuan211, Jun 30, 2025)
c7e5c98  update qwen3moe related files (tianyuan211, Jun 30, 2025)
c29f66f  update (tianyuan211, Jun 30, 2025)
576c5e9  Update test_text_generation_example.json (tianyuan211, Jun 30, 2025)
c6befe4  Update modeling_utils.py (tianyuan211, Jul 1, 2025)
8c27ba1  Update modeling_qwen3_moe.py (tianyuan211, Jul 2, 2025)
9ced9f3  Update modeling_qwen3_moe.py (tianyuan211, Jul 2, 2025)
2f1b687  Update modeling_qwen3_moe.py (tianyuan211, Jul 2, 2025)
2 changes: 1 addition & 1 deletion README.md
@@ -250,7 +250,7 @@ The following model architectures, tasks and device distributions have been vali
| Phi | :heavy_check_mark: | <li>Single card</li> | <li>[language modeling](https://github.com/huggingface/optimum-habana/tree/main/examples/language-modeling)</li><li>[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)</li> |
| Mixtral | | <li>Single card</li> | <li>[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)</li> |
| Persimmon | | <li>Single card</li> | <li>[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)</li> |
| Qwen2 | <li>Single card</li> | <li>Single card</li> | <li>[language modeling](https://github.com/huggingface/optimum-habana/tree/main/examples/language-modeling)</li><li>[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)</li> |
| Qwen2 / Qwen3 | <li>Single card</li> | <li>Single card</li> | <li>[language modeling](https://github.com/huggingface/optimum-habana/tree/main/examples/language-modeling)</li><li>[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)</li> |
| Qwen2-MoE | | <li>Single card</li> | <li>[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)</li> |
| Gemma | :heavy_check_mark: | <li>Single card</li> | <li>[language modeling](https://github.com/huggingface/optimum-habana/tree/main/examples/language-modeling)</li><li>[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)</li> |
| Gemma2 | | :heavy_check_mark: | <li>[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)</li> |
2 changes: 1 addition & 1 deletion docs/source/index.mdx
@@ -79,7 +79,7 @@ In the tables below, ✅ means single-card, multi-card and DeepSpeed have all be
| Mixtral | | <div style="text-align:left"><li>Single card</li></div> | <li>[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)</li> |
| Gemma | ✅ | <div style="text-align:left"><li>Single card</li></div> | <li>[language modeling](https://github.com/huggingface/optimum-habana/tree/main/examples/language-modeling)</li><li>[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)</li> |
| Gemma2 | | ✅ | <li>[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)</li> |
| Qwen2 | <div style="text-align:left"><li>Single card</li></div> | <div style="text-align:left"><li>Single card</li></div> | <li>[language modeling](https://github.com/huggingface/optimum-habana/tree/main/examples/language-modeling)</li><li>[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)</li> |
| Qwen2 / Qwen3 | <div style="text-align:left"><li>Single card</li></div> | <div style="text-align:left"><li>Single card</li></div> | <li>[language modeling](https://github.com/huggingface/optimum-habana/tree/main/examples/language-modeling)</li><li>[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)</li> |
| Qwen2-MoE | | <div style="text-align:left"><li>Single card</li></div> | <li>[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)</li> |
| Persimmon | | <div style="text-align:left"><li>Single card</li></div> | <li>[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)</li> |
| XGLM | | <div style="text-align:left"><li>Single card</li></div> | <li>[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)</li> |
8 changes: 7 additions & 1 deletion optimum/habana/transformers/generation/utils.py
@@ -133,6 +133,8 @@
"deepseek_v3",
"chatglm",
"qwen2_vl",
"qwen3",
"qwen3_moe",
]

# Initial generated token index is set to 1 to accomodate SOS (start of string) token.
@@ -1350,8 +1352,10 @@ def generate(
"chatglm",
"deepseek_v2",
"deepseek_v3",
"qwen3",
"qwen3_moe",
], (
"reuse_cache only supported by llama, mistral, falcon, mixtral, phi, qwen2, qwen2_moe, gemma, gemma2, starcoder2, baichuan, chatglm and deepseek_v2 at the moment"
"reuse_cache only supported by llama, mistral, falcon, mixtral, phi, qwen2, qwen2_moe, qwen3, qwen3_moe, gemma, gemma2, starcoder2, baichuan, chatglm and deepseek_v2 at the moment"
)
if not generation_config.bucket_internal:
assert generation_config.bucket_size <= 0, (
@@ -1565,6 +1569,8 @@ def generate(
"qwen2_moe",
"baichuan",
"deepseek_v2",
"qwen3",
"qwen3_moe",
]:
if (
hasattr(self.config, "max_position_embeddings")
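The hunks above extend per-feature allowlists: capabilities such as `reuse_cache` are gated on `config.model_type` membership, and this PR adds `"qwen3"` and `"qwen3_moe"` to those lists. A minimal sketch of the pattern follows; the list contents and the helper name `check_reuse_cache` are illustrative, not the actual optimum-habana code.

```python
# Illustrative subset of the model-type allowlist that gates reuse_cache.
# The real list in optimum/habana/transformers/generation/utils.py is longer.
MODELS_WITH_REUSE_CACHE = [
    "llama", "mistral", "qwen2", "qwen2_moe",
    "qwen3", "qwen3_moe",  # added by this PR
]

def check_reuse_cache(model_type: str, reuse_cache: bool) -> None:
    """Mirror the assertion pattern: reject reuse_cache for unsupported models."""
    if reuse_cache:
        assert model_type in MODELS_WITH_REUSE_CACHE, (
            f"reuse_cache is not supported for model_type={model_type!r}"
        )

check_reuse_cache("qwen3", True)  # passes once qwen3 is on the allowlist
```

With the feature disabled, the model type is never checked, which is why the allowlist only has to enumerate models that implement the cache-reuse path.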
30 changes: 30 additions & 0 deletions optimum/habana/transformers/modeling_utils.py
@@ -165,6 +165,17 @@
GaudiQwen2VLModel,
GaudiQwen2VLSdpaAttention,
GaudiQwen2VLVisionBlock,
GaudiQwen3Attention,
GaudiQwen3DecoderLayer,
GaudiQwen3ForCausalLM,
GaudiQwen3MLP,
GaudiQwen3Model,
GaudiQwen3MoeAttention,
GaudiQwen3MoeDecoderLayer,
GaudiQwen3MoeForCausalLM,
GaudiQwen3MoeMLP,
GaudiQwen3MoeModel,
GaudiQwen3MoeSparseMoeBlock,
GaudiSiglipAttention,
GaudiSiglipEncoder,
GaudiSiglipEncoderLayer,
@@ -263,6 +274,8 @@
gaudi_qwen2_rmsnorm_forward,
gaudi_qwen2moe_block_sparse_moe_forward,
gaudi_qwen2moe_rmsnorm_forward,
gaudi_qwen3_rmsnorm_forward,
gaudi_qwen3moe_rmsnorm_forward,
gaudi_rot_matmul,
gaudi_rot_vec_mul,
gaudi_SeamlessM4TAttention_forward,
@@ -707,6 +720,23 @@ def adapt_transformers_to_gaudi():
GaudiQwen2VLForConditionalGeneration
)

# Optimization for qwen3 on Gaudi
transformers.models.qwen3.modeling_qwen3.Qwen3ForCausalLM = GaudiQwen3ForCausalLM
transformers.models.qwen3.modeling_qwen3.Qwen3Model = GaudiQwen3Model
transformers.models.qwen3.modeling_qwen3.Qwen3Attention = GaudiQwen3Attention
transformers.models.qwen3.modeling_qwen3.Qwen3MLP = GaudiQwen3MLP
transformers.models.qwen3.modeling_qwen3.Qwen3DecoderLayer = GaudiQwen3DecoderLayer
transformers.models.qwen3.modeling_qwen3.Qwen3RMSNorm.forward = gaudi_qwen3_rmsnorm_forward

# Optimization for qwen3Moe on Gaudi
transformers.models.qwen3_moe.modeling_qwen3_moe.Qwen3MoeForCausalLM = GaudiQwen3MoeForCausalLM
transformers.models.qwen3_moe.modeling_qwen3_moe.Qwen3MoeModel = GaudiQwen3MoeModel
transformers.models.qwen3_moe.modeling_qwen3_moe.Qwen3MoeAttention = GaudiQwen3MoeAttention
transformers.models.qwen3_moe.modeling_qwen3_moe.Qwen3MoeMLP = GaudiQwen3MoeMLP
transformers.models.qwen3_moe.modeling_qwen3_moe.Qwen3MoeDecoderLayer = GaudiQwen3MoeDecoderLayer
transformers.models.qwen3_moe.modeling_qwen3_moe.Qwen3MoeSparseMoeBlock = GaudiQwen3MoeSparseMoeBlock
transformers.models.qwen3_moe.modeling_qwen3_moe.Qwen3MoeRMSNorm.forward = gaudi_qwen3moe_rmsnorm_forward

# Optimization for stablelm on Gaudi
transformers.models.stablelm.modeling_stablelm.StableLmAttention = GaudiStableLmAttention
transformers.models.stablelm.modeling_stablelm.StableLmDecoderLayer = GaudiStableLmDecoderLayer
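The `adapt_transformers_to_gaudi()` additions above work by rebinding module-level class names (e.g. `Qwen3Attention`) to Gaudi-optimized subclasses before any model is built. Below is a self-contained sketch of that substitution pattern; the names `DummyAttention`, `GaudiDummyAttention`, and the stand-in `library` module are illustrative and not part of optimum-habana or transformers.

```python
import types

# Stand-in for a module like transformers.models.qwen3.modeling_qwen3.
library = types.ModuleType("library")

class DummyAttention:
    """Reference implementation, analogous to the upstream class."""
    def forward(self, x):
        return x

library.DummyAttention = DummyAttention

class GaudiDummyAttention(DummyAttention):
    """Optimized variant; a real Gaudi class would add hardware-specific paths."""
    def forward(self, x):
        return x * 2

def adapt_library():
    # Mirrors the pattern of adapt_transformers_to_gaudi(): rebind the public
    # name so code that constructs library.DummyAttention gets the new class.
    library.DummyAttention = GaudiDummyAttention

adapt_library()
attn = library.DummyAttention()  # now instantiates GaudiDummyAttention
```

Because the substitution happens at the module attribute level, it must run before model instantiation; classes already constructed keep their original type.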
17 changes: 17 additions & 0 deletions optimum/habana/transformers/models/__init__.py
@@ -271,6 +271,23 @@
GaudiQwen2VLVisionBlock,
GaudiVisionSdpaAttention,
)
from .qwen3 import (
GaudiQwen3Attention,
GaudiQwen3DecoderLayer,
GaudiQwen3ForCausalLM,
GaudiQwen3MLP,
GaudiQwen3Model,
gaudi_qwen3_rmsnorm_forward,
)
from .qwen3_moe import (
GaudiQwen3MoeAttention,
GaudiQwen3MoeDecoderLayer,
GaudiQwen3MoeForCausalLM,
GaudiQwen3MoeMLP,
GaudiQwen3MoeModel,
GaudiQwen3MoeSparseMoeBlock,
gaudi_qwen3moe_rmsnorm_forward,
)
from .seamless_m4t import (
gaudi_SeamlessM4TAttention_forward,
gaudi_SeamlessM4TCodeHifiGan_get_output_hifigan_lengths,
8 changes: 8 additions & 0 deletions optimum/habana/transformers/models/qwen3/__init__.py
@@ -0,0 +1,8 @@
from .modeling_qwen3 import (
GaudiQwen3Attention,
GaudiQwen3DecoderLayer,
GaudiQwen3ForCausalLM,
GaudiQwen3MLP,
GaudiQwen3Model,
gaudi_qwen3_rmsnorm_forward,
)