From 3e80e5490d82851e5301481a70a554ee947941d8 Mon Sep 17 00:00:00 2001 From: tianyuan211 Date: Thu, 24 Apr 2025 17:30:58 +0800 Subject: [PATCH 01/14] Add qwen3 --- .../habana/transformers/generation/utils.py | 4 +- optimum/habana/transformers/modeling_utils.py | 14 + .../habana/transformers/models/__init__.py | 8 + .../transformers/models/qwen3/__init__.py | 8 + .../models/qwen3/modeling_qwen3.py | 1166 +++++++++++++++++ 5 files changed, 1199 insertions(+), 1 deletion(-) create mode 100644 optimum/habana/transformers/models/qwen3/__init__.py create mode 100644 optimum/habana/transformers/models/qwen3/modeling_qwen3.py diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py index ad9db2191b..78ebd05ac0 100755 --- a/optimum/habana/transformers/generation/utils.py +++ b/optimum/habana/transformers/generation/utils.py @@ -130,6 +130,7 @@ "deepseek_v3", "chatglm", "qwen2_vl", + "qwen3", ] # Initial generated token index is set to 1 to accomodate SOS (start of string) token. @@ -1302,8 +1303,9 @@ def generate( "chatglm", "deepseek_v2", "deepseek_v3", + "qwen3", ], ( - "reuse_cache only supported by llama, mistral, falcon, mixtral, phi, qwen2, qwen2_moe, gemma, gemma2, starcoder2, baichuan, chatglm and deepseek_v2 at the moment" + "reuse_cache only supported by llama, mistral, falcon, mixtral, phi, qwen2, qwen2_moe, qwen3, gemma, gemma2, starcoder2, baichuan, chatglm and deepseek_v2 at the moment" ) if not generation_config.bucket_internal: assert generation_config.bucket_size <= 0, ( diff --git a/optimum/habana/transformers/modeling_utils.py b/optimum/habana/transformers/modeling_utils.py index 708dacab9a..ae3a540538 100644 --- a/optimum/habana/transformers/modeling_utils.py +++ b/optimum/habana/transformers/modeling_utils.py @@ -163,6 +163,11 @@ GaudiQwen2VLModel, GaudiQwen2VLSdpaAttention, GaudiQwen2VLVisionBlock, + GaudiQwen3Attention, + GaudiQwen3DecoderLayer, + GaudiQwen3ForCausalLM, + GaudiQwen3MLP, + GaudiQwen3Model, GaudiStableLmAttention, GaudiStableLmDecoderLayer, GaudiStableLmForCausalLM, @@ -254,6 +259,7 @@ gaudi_qwen2_rmsnorm_forward, gaudi_qwen2moe_block_sparse_moe_forward, gaudi_qwen2moe_rmsnorm_forward, + gaudi_qwen3_rmsnorm_forward, gaudi_rot_matmul, gaudi_rot_vec_mul, gaudi_SeamlessM4TAttention_forward, @@ -686,6 +692,14 @@ def adapt_transformers_to_gaudi(): transformers.models.qwen2_vl.modeling_qwen2_vl.Qwen2VLForConditionalGeneration = ( GaudiQwen2VLForConditionalGeneration ) + + # Optimization for qwen3 on Gaudi + transformers.models.qwen3.modeling_qwen3.Qwen3ForCausalLM = GaudiQwen3ForCausalLM + transformers.models.qwen3.modeling_qwen3.Qwen3Model = GaudiQwen3Model + transformers.models.qwen3.modeling_qwen3.Qwen3Attention = GaudiQwen3Attention + transformers.models.qwen3.modeling_qwen3.Qwen3MLP = GaudiQwen3MLP + transformers.models.qwen3.modeling_qwen3.Qwen3DecoderLayer = GaudiQwen3DecoderLayer + transformers.models.qwen3.modeling_qwen3.Qwen3RMSNorm.forward = gaudi_qwen3_rmsnorm_forward # Optimization for stablelm on Gaudi transformers.models.stablelm.modeling_stablelm.StableLmAttention = GaudiStableLmAttention diff --git a/optimum/habana/transformers/models/__init__.py b/optimum/habana/transformers/models/__init__.py index 5c611492bd..9523c76eb4 100644 --- a/optimum/habana/transformers/models/__init__.py +++ b/optimum/habana/transformers/models/__init__.py @@ -270,6 +270,14 @@ GaudiQwen2VLVisionBlock, GaudiVisionSdpaAttention, ) +from .qwen3 import ( + GaudiQwen3Attention, + GaudiQwen3DecoderLayer, + 
GaudiQwen3ForCausalLM, + GaudiQwen3MLP, + GaudiQwen3Model, + gaudi_qwen3_rmsnorm_forward, +) from .seamless_m4t import ( gaudi_SeamlessM4TAttention_forward, gaudi_SeamlessM4TCodeHifiGan_get_output_hifigan_lengths, diff --git a/optimum/habana/transformers/models/qwen3/__init__.py b/optimum/habana/transformers/models/qwen3/__init__.py new file mode 100644 index 0000000000..156f298a5a --- /dev/null +++ b/optimum/habana/transformers/models/qwen3/__init__.py @@ -0,0 +1,8 @@ +from .modeling_qwen3 import ( + GaudiQwen3Attention, + GaudiQwen3DecoderLayer, + GaudiQwen3ForCausalLM, + GaudiQwen3MLP, + GaudiQwen3Model, + gaudi_qwen3_rmsnorm_forward, +) diff --git a/optimum/habana/transformers/models/qwen3/modeling_qwen3.py b/optimum/habana/transformers/models/qwen3/modeling_qwen3.py new file mode 100644 index 0000000000..80c9e8eb8f --- /dev/null +++ b/optimum/habana/transformers/models/qwen3/modeling_qwen3.py @@ -0,0 +1,1166 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +############################################################################### +# Copyright (C) 2022-2024 Habana Labs, Ltd. an Intel Company +############################################################################### + +from typing import List, Optional, Tuple, Union + +import torch +from transformers.cache_utils import Cache, DynamicCache, StaticCache +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from transformers.models.qwen3.configuration_qwen3 import Qwen3Config +from transformers.models.qwen3.modeling_qwen3 import ( + KwargsForCausalLM, + Qwen3Attention, + Qwen3DecoderLayer, + Qwen3ForCausalLM, + Qwen3MLP, + Qwen3Model, + Qwen3RMSNorm, + apply_rotary_pos_emb, + logger, +) +from transformers.processing_utils import Unpack + +from ....distributed import parallel_state +from ...modeling_attn_mask_utils import ( + _gaudi_prepare_4d_causal_attention_mask, +) +from ...modeling_rope_utils import GaudiRotaryEmbedding +from ..modeling_all_models import KVCache, Matmul, apply_customized_rope_module + + +try: + from habana_frameworks.torch.hpex.kernels import RotaryPosEmbeddingHelperV2 as FusedRoPE # noqa + + has_fused_rope = True +except ImportError: + has_fused_rope = False + print("Not using HPU fused kernel for apply_rotary_pos_emb") + +try: + from habana_frameworks.torch.hpex.normalization import FusedRMSNorm as FusedRMSNorm + + has_fused_rms_norm = True +except ImportError: + has_fused_rms_norm = False + print("Not using HPU fused kernel for RMSNorm") + +try: + from habana_frameworks.torch.hpex.kernels import FusedSDPA +except ImportError: + print("Not using HPU fused scaled dot-product attention kernel.") + FusedSDPA = None + +import habana_frameworks.torch.core as htcore + + +def gaudi_qwen3_rmsnorm_forward(self, hidden_states): + if hidden_states.device.type == "hpu" and has_fused_rms_norm: + # mixed dtypes are not good for FusedRMSNorm, both inputs need to have same dtype + if hidden_states.dtype != 
self.weight.dtype: + orig_dtype = hidden_states.dtype + hidden_states = FusedRMSNorm.apply(hidden_states.to(self.weight.dtype), self.weight, self.variance_epsilon) + return hidden_states.to(orig_dtype) + else: + hidden_states = FusedRMSNorm.apply(hidden_states, self.weight, self.variance_epsilon) + return hidden_states + else: + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +class GaudiQwen3MLP(Qwen3MLP): + def pre_mlp_forward(self, x): + inputs = self.act_fn(self.gate_proj(x)) * self.up_proj(x) + output = self.down_proj(inputs) + return output + + def mlp_all_reduce(self, x): + if hasattr(self.down_proj, "all_reduce"): + self.down_proj.all_reduce(x) + + def post_mlp_forward(self, x): + if hasattr(self.down_proj, "post_all_reduce"): + return self.down_proj.post_all_reduce(x) + return x + + +def gaudi_qwen3_repeat_kv( + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + attention_mask: torch.Tensor, + n_rep: int, +): + batch, num_key_value_heads, kv_len, head_dim = key_states.shape + if n_rep == 1 or num_key_value_heads == 1: + return query_states, key_states, value_states, attention_mask + + new_kv_shape = (batch, num_key_value_heads, 1, kv_len, head_dim) + key_states = key_states.reshape(new_kv_shape) + value_states = value_states.reshape(new_kv_shape) + + batch, _, q_len, head_dim = query_states.shape + new_q_shape = (batch, num_key_value_heads, n_rep, q_len, head_dim) + query_states = query_states.reshape(new_q_shape) + + if attention_mask is not None: + # Add groups dim and set to 1 + attention_mask = attention_mask.unsqueeze(1) + + return query_states, key_states, value_states, attention_mask + + +# FusedScaledDotProductAttention +class ModuleFusedSDPA(torch.nn.Module): + def __init__(self, fusedSDPA, scale, attention_dropout, enable_recompute, flash_attention_fp8): + super().__init__() + self._hpu_kernel_fsdpa = fusedSDPA + self.scale = scale + self.attention_dropout = attention_dropout + self.enable_recompute = enable_recompute + self.flash_attention_fp8 = flash_attention_fp8 + + def forward( + self, + query, + key, + value, + attn_mask, + dropout_p, + is_casual, + scale, + softmax_mode, + recompute_mode, + valid_sequence_lengths, + padding_side="left", + ): + return self._hpu_kernel_fsdpa.apply( + query, + key, + value, + attn_mask, + dropout_p, + is_casual, + scale, + softmax_mode, + recompute_mode, + valid_sequence_lengths, + padding_side, + ) + + +def gaudi_eager_attention_forward( + module: torch.nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + attn_softmax_bf16: bool = False, + **kwargs, +): + bsz, q_len = kwargs["input_shape"] + query_states, key_states, value_states, attention_mask = gaudi_qwen3_repeat_kv( + query, key, value, attention_mask, module.num_key_value_groups + ) + + query_states = query_states * scaling + attn_weights = module.matmul_qk(query_states, key_states.transpose(-2, -1)).float() + htcore.mark_step() + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + if attn_softmax_bf16: + attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=query_states.dtype) + else: + # upcast 
attention to fp32 + attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = torch.nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = module.matmul_av(attn_weights, value_states) + attn_output = attn_output.reshape(bsz, -1, q_len, module.head_dim) + + return attn_output, attn_weights + + +class GaudiDistributedAttention(torch.nn.Module): + def __init__( + self, hpu_module_fsdpa: ModuleFusedSDPA, scale, attention_dropout, enable_recompute, flash_attention_fp8 + ): + super().__init__() + self._hpu_module_fsdpa = hpu_module_fsdpa + if parallel_state.sequence_parallel_is_initialized() and parallel_state.get_sequence_parallel_world_size() > 1: + from deepspeed.sequence.layer import DistributedAttention + + self._hpu_module_fsdpa_distributed = DistributedAttention( + self._hpu_module_fsdpa, parallel_state.get_sequence_parallel_group(), 1, 2 + ) + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: torch.Tensor, + dropout_p: float, + is_casual, + scale, + softmax_mode, + recompute_mode, + valid_sequence_lengths, + padding_side="left", + ): + if parallel_state.sequence_parallel_is_initialized() and parallel_state.get_sequence_parallel_world_size() > 1: + return self._hpu_module_fsdpa_distributed( + query, + key, + value, + 0, # As the shape for inputs is [B, N, S, H] + None, + attn_mask, + dropout_p, + is_casual, + scale, + softmax_mode, + recompute_mode, + valid_sequence_lengths, + padding_side, + ) + else: + return self._hpu_module_fsdpa( + query, + key, + value, + attn_mask, + dropout_p, + is_casual, + scale, + softmax_mode, + recompute_mode, + valid_sequence_lengths, + padding_side, + ) + + +def get_gaudi_distributed_attention( + fused_scaled_dot_product_attention, fused_scaled_dot_product_attention_distributed +): + if parallel_state.sequence_parallel_is_initialized() and parallel_state.get_sequence_parallel_world_size() > 1: + return fused_scaled_dot_product_attention_distributed + else: + return fused_scaled_dot_product_attention + + +class GaudiQwen3Attention(Qwen3Attention): + def __init__(self, config: Qwen3Config, layer_idx: Optional[int] = None): + super().__init__(config, layer_idx) + + self.matmul_qk = Matmul() + self.matmul_av = Matmul() + self.k_cache = KVCache() + self.v_cache = KVCache() + + self.inp_seq_len = -1 + + self.rotary_emb = GaudiRotaryEmbedding(config=self.config) + + self.fused_scaled_dot_product_attention = ( + ModuleFusedSDPA( + FusedSDPA, + scale=self.scaling, + attention_dropout=self.attention_dropout, + enable_recompute=False, + flash_attention_fp8=getattr(config, "flash_attention_fp8", False), + ) + if FusedSDPA + else None + ) + # for all2all comm, Distributed Attention cares about sequence (s) and number of heads (h) dimensions. 
In HPU, they are at 1 and 2 indices + self.fused_scaled_dot_product_attention_distributed = None + if parallel_state.sequence_parallel_is_initialized() and parallel_state.get_sequence_parallel_world_size() > 1: + self.fused_scaled_dot_product_attention_distributed = ( + GaudiDistributedAttention( + self.fused_scaled_dot_product_attention, + scale=self.scaling, + attention_dropout=self.attention_dropout, + enable_recompute=False, + flash_attention_fp8=getattr(config, "flash_attention_fp8", False), + ) + if FusedSDPA + else None + ) + + self.num_key_value_heads = config.num_key_value_heads + + def get_k_proj_weight(self): + """4bit quantization in GPTQ replaces the k_proj.weight with qweight.""" + if hasattr(self.k_proj, "qweight"): + return self.k_proj.qweight + return self.k_proj.weight + + def get_k_proj_weight_dtype(self): + """4bit quantization in GPTQ replaces the k_proj.weight with qweight. + Scales tensor gets the weight dtype.""" + if hasattr(self.k_proj, "qweight"): + return self.k_proj.scales.dtype + return self.k_proj.weight.dtype + + def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): + cache_shape = (batch_size, self.num_key_value_heads, max_seq_len, self.head_dim) + device = self.get_k_proj_weight().device + dtype = self.config.torch_dtype + self.k_cache.allocate(inp_seq_len, dtype, device, cache_shape) + self.v_cache.allocate(inp_seq_len, dtype, device, cache_shape) + + def update_sincos_cache(self, seq_len): + # Call rotary emb forward() to update cos/sin cache when infering more than self.max_position_embeddings + # This helps in avoiding creation of these caches during actual model forward pass and + # reduce memory consumption and improve performance. + if seq_len > self.max_position_embeddings: + self.max_position_embeddings = seq_len + _, _ = self.rotary_emb(self.get_k_proj_weight(), seq_len=seq_len) + + def reorder(self, tensor, beam_idx, dim_a, dim_b): + updated = tensor.index_select(0, beam_idx) + tensor.copy_(updated) + + def reorder_kv_cache(self, beam_idx: torch.LongTensor): + if self.k_cache.cache is None: + return (None, None) + + head_dim = self.k_cache.cache.size(-1) + seq_length = self.k_cache.cache.size(-2) + self.reorder(self.k_cache.cache, beam_idx, seq_length, head_dim) + self.reorder(self.v_cache.cache, beam_idx, seq_length, head_dim) + return (self.k_cache.cache.shape, self.v_cache.cache.shape) + + def pre_attn_forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_value: Optional[Cache] = None, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + token_idx: Optional[torch.Tensor] = None, + attn_softmax_bf16: Optional[bool] = False, + reuse_cache: Optional[bool] = False, + use_flash_attention: Optional[bool] = False, + flash_attention_recompute: Optional[bool] = False, + flash_attention_causal_mask: Optional[bool] = False, + flash_attention_fast_softmax: Optional[bool] = False, + valid_sequence_lengths: Optional[torch.Tensor] = None, + cache_idx: int = None, + num_virtual_tokens: int = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """ + The only differences are: + - add new args token_idx + - optimize KV cache + - add new args attn_softmax_bf16 + - add new args reuse_cache + - add new args use_flash_attention + - add new arg flash_attention_recompute + - add new arg flash_attention_causal_mask + - add new arg flash_attention_fast_softmax + - add new 
arg num_virtual_tokens + """ + input_shape = hidden_states.shape[:-1] + q_len = input_shape[1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2) + key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + if token_idx is None: + if hasattr(past_key_value, "get_usable_length"): + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + else: + kv_seq_len += past_key_value[0].shape[-2] + else: + if reuse_cache and not isinstance(past_key_value[0], torch.Tensor): + kv_seq_len = past_key_value[0][-2] + else: + if num_virtual_tokens is not None and num_virtual_tokens == past_key_value[0].shape[-2]: + kv_seq_len = past_key_value[0].shape[-2] + kv_seq_len + else: + kv_seq_len = past_key_value[0].shape[-2] + + seq_len = kv_seq_len + if parallel_state.sequence_parallel_is_initialized(): + seq_len = kv_seq_len * parallel_state.get_sequence_parallel_world_size() + + cos, sin = self.rotary_emb(value_states, seq_len=seq_len) + # If sequence parallel in enabled, position_ids should be based on which part of the sequence is present in the rank + # As we divide the inputs based on ranks, position_ids are generated to suit that part of the sequence + if parallel_state.sequence_parallel_is_initialized() and parallel_state.get_sequence_parallel_rank() > 0: + position_ids = torch.arange( + kv_seq_len * parallel_state.get_sequence_parallel_rank(), + kv_seq_len * (parallel_state.get_sequence_parallel_rank() + 1), + dtype=torch.long, + device=query_states.device, + ) + position_ids = position_ids.unsqueeze(0) + + query_states, key_states = apply_customized_rope( + query_states, key_states, cos, sin, kwargs["position_ids"], self.training + ) + + if use_cache: + # reuse k, v, self_attention + if reuse_cache: + if past_key_value is not None and isinstance(past_key_value[0], torch.Tensor): + # prefix tuning case. attach past_key_value to generate first token. + key_states = torch.cat((past_key_value[0], key_states), -2) + value_states = torch.cat((past_key_value[1], value_states), -2) + key_states = self.k_cache(key_states, 2, token_idx) + value_states = self.v_cache(value_states, 2, token_idx) + past_key_value = (self.k_cache.get_shape(), self.v_cache.get_shape()) + else: + if past_key_value is None: + past_key = torch.zeros( + key_states.shape, dtype=self.get_k_proj_weight_dtype(), device=key_states.device + ) + past_value = torch.zeros( + key_states.shape, dtype=self.get_k_proj_weight_dtype(), device=key_states.device + ) + # Return list instead of tuple + past_key_value = [past_key, past_value] + if ( + token_idx is not None + and num_virtual_tokens is not None + and num_virtual_tokens == past_key_value[0].shape[-2] + ): + # prefix tuning case. attach past_key_value to generate first token. 
+ key_states = torch.cat((past_key_value[0], key_states), -2) + value_states = torch.cat((past_key_value[1], value_states), -2) + past_key_value = (key_states, value_states) + else: + key_states = self.k_cache.update(past_key_value[0], key_states, 2, token_idx, self.inp_seq_len) + value_states = self.v_cache.update(past_key_value[1], value_states, 2, token_idx, self.inp_seq_len) + + if token_idx is None: + past_key_value = (key_states, value_states) + + if cache_idx is not None and q_len == 1: + key_states = key_states[:, :, :cache_idx, :] + value_states = value_states[:, :, :cache_idx, :] + if attention_mask is not None: + attention_mask = attention_mask[:, :, :, :cache_idx] + kv_seq_len = key_states.shape[-2] + else: + past_key_value = None + fused_scaled_dot_product_attention = get_gaudi_distributed_attention( + self.fused_scaled_dot_product_attention, self.fused_scaled_dot_product_attention_distributed + ) + sliding_window = None + if ( + self.config.use_sliding_window + and getattr(self.config, "sliding_window", None) is not None + and self.layer_idx >= self.config.max_window_layers + ): + sliding_window = self.config.sliding_window + + if use_flash_attention and FusedSDPA is not None: + attn_weights = None + if self.training: + attn_output = fused_scaled_dot_product_attention( + query_states, + key_states, + value_states, + attention_mask, + 0.0, + False, + None, + "None", + False, + None, + "None", + ) + elif q_len == 1: + # next token + attn_output = fused_scaled_dot_product_attention( + query_states, + key_states, + value_states, + attention_mask, + 0.0, + False, + None, + "fp32", + False, + None, + "None", + ) + else: + # first token + softmax_mode = "fp32" + if flash_attention_causal_mask: + attn_output = fused_scaled_dot_product_attention( + query_states, + key_states, + value_states, + None, + 0.0, + True, + None, + softmax_mode, + flash_attention_recompute, + valid_sequence_lengths, + "left", + ) + else: + attn_output = fused_scaled_dot_product_attention( + query_states, + key_states, + value_states, + attention_mask, + 0.0, + False, + None, + softmax_mode, + flash_attention_recompute, + None, + "None", + ) + + else: + attn_output, attn_weights = gaudi_eager_attention_forward( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + sliding_window=sliding_window, # main diff with Llama + attn_softmax_bf16=attn_softmax_bf16, + input_shape=input_shape, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + + if not reuse_cache and token_idx is not None and cache_idx is not None and q_len == 1: + # Return only past key value shapes and not the tensors during decode phase (q len is 1) + # to avoid making past key values as persistent output tensors of HPU graphs. 
+ past_key_value = (past_key_value[0].shape, past_key_value[1].shape) + + return attn_output, attn_weights, past_key_value + + def attention_all_reduce(self, attn_output): + if hasattr(self.o_proj, "all_reduce"): + self.o_proj.all_reduce(attn_output) + + def post_attn_forward(self, attn_output): + if hasattr(self.o_proj, "post_all_reduce"): + return self.o_proj.post_all_reduce(attn_output) + return attn_output + + +class GaudiQwen3DecoderLayer(Qwen3DecoderLayer): + def __init__(self, config: Qwen3Config, layer_idx: int): + super(Qwen3DecoderLayer, self).__init__() + self.hidden_size = config.hidden_size + + self.self_attn = GaudiQwen3Attention(config, layer_idx) + + self.mlp = GaudiQwen3MLP(config) + self.input_layernorm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): + self.self_attn.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len) + + def reorder_kv_cache(self, beam_idx: torch.LongTensor): + return self.self_attn.reorder_kv_cache(beam_idx) + + def update_sincos_cache(self, seq_len): + self.self_attn.update_sincos_cache(seq_len) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + token_idx: Optional[torch.Tensor] = None, + attn_softmax_bf16: Optional[bool] = False, + reuse_cache: Optional[bool] = False, + use_flash_attention: Optional[bool] = False, + flash_attention_recompute: Optional[bool] = False, + flash_attention_causal_mask: Optional[bool] = False, + flash_attention_fast_softmax: Optional[bool] = False, + valid_sequence_lengths: Optional[torch.Tensor] = None, + cache_idx: int = None, + num_virtual_tokens: int = None, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + residual = hidden_states + + hidden_states, self_attn_weights, present_key_value = self.pre_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + token_idx=token_idx, + attn_softmax_bf16=attn_softmax_bf16, + reuse_cache=reuse_cache, + use_flash_attention=use_flash_attention, + flash_attention_recompute=flash_attention_recompute, + flash_attention_causal_mask=flash_attention_causal_mask, + flash_attention_fast_softmax=flash_attention_fast_softmax, + valid_sequence_lengths=valid_sequence_lengths, + cache_idx=cache_idx, + num_virtual_tokens=num_virtual_tokens, + **kwargs, + ) + + self.self_attn.attention_all_reduce(hidden_states) + hidden_states, residual = self.post_attn_pre_mlp(hidden_states, residual) + + self.mlp.mlp_all_reduce(hidden_states) + hidden_states = self.post_mlp(hidden_states, residual) + + outputs = (hidden_states,) + if output_attentions: + outputs += (self_attn_weights,) + if use_cache: + outputs += (present_key_value,) + + return outputs + + def pre_attn( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: 
Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + token_idx: Optional[torch.Tensor] = None, + attn_softmax_bf16: Optional[bool] = False, + reuse_cache: Optional[bool] = False, + use_flash_attention: Optional[bool] = False, + flash_attention_recompute: Optional[bool] = False, + flash_attention_causal_mask: Optional[bool] = False, + flash_attention_fast_softmax: Optional[bool] = False, + valid_sequence_lengths: Optional[torch.Tensor] = None, + cache_idx: int = None, + num_virtual_tokens: int = None, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + hidden_states = self.input_layernorm(hidden_states) + hidden_states, attn_weights, present_key_value = self.self_attn.pre_attn_forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + token_idx=token_idx, + attn_softmax_bf16=attn_softmax_bf16, + reuse_cache=reuse_cache, + use_flash_attention=use_flash_attention, + flash_attention_recompute=flash_attention_recompute, + flash_attention_causal_mask=flash_attention_causal_mask, + flash_attention_fast_softmax=flash_attention_fast_softmax, + valid_sequence_lengths=valid_sequence_lengths, + cache_idx=cache_idx, + num_virtual_tokens=num_virtual_tokens, + **kwargs, + ) + return hidden_states, attn_weights, present_key_value + + def post_attn_pre_mlp(self, hidden_states, residual): + hidden_states = self.self_attn.post_attn_forward(hidden_states) + + if self.training: + hidden_states = hidden_states + residual + residual = hidden_states + else: + residual.add_(hidden_states) + hidden_states = residual + + hidden_states = self.post_attention_layernorm(hidden_states) + + hidden_states = self.mlp.pre_mlp_forward(hidden_states) + return hidden_states, residual + + def post_mlp(self, hidden_states, residual): + hidden_states = self.mlp.post_mlp_forward(hidden_states) + + if self.training: + hidden_states = hidden_states + residual + else: + residual.add_(hidden_states) + hidden_states = residual + + return hidden_states + + +class GaudiQwen3Model(Qwen3Model): + def __init__(self, config: Qwen3Config): + """ + Copied from https://github.com/huggingface/transformers/blob/v4.40-release/src/transformers/models/qwen3/modeling_qwen3.py#L920 + 1. set fill_value to 1 instead of True + 2. 
add device=self.device + """ + super(Qwen3Model, self).__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = torch.nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = torch.nn.ModuleList( + [GaudiQwen3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): + for layer in self.layers: + layer.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len) + + def reorder_kv_cache(self, beam_idx: torch.LongTensor): + return tuple(layer.reorder_kv_cache(beam_idx) for layer in self.layers) + + def update_sincos_cache(self, seq_len): + for layer in self.layers: + layer.update_sincos_cache(seq_len) + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + token_idx: Optional[torch.Tensor] = None, + attn_softmax_bf16: Optional[bool] = False, + reuse_cache: Optional[bool] = False, + use_flash_attention: Optional[bool] = False, + flash_attention_recompute: Optional[bool] = False, + flash_attention_causal_mask: Optional[bool] = False, + flash_attention_fast_softmax: Optional[bool] = False, + valid_sequence_lengths: torch.Tensor = None, + cache_idx: int = None, + lazy_mode: Optional[bool] = True, + num_virtual_tokens: int = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." 
+ ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + ignore_cache_position = True # Ignoring cache position for HPU + use_new_cache = False # Ignoring new Cache path for HPU + + past_seen_tokens = 0 + + if past_key_values is not None and use_cache: # kept for BC (cache positions) + if reuse_cache: + if isinstance(past_key_values[0][0], torch.Tensor): + past_seen_tokens = past_key_values[0][0].shape[2] + else: + past_seen_tokens = past_key_values[0][0][2] + else: + if use_new_cache: + if not isinstance(past_key_values, StaticCache): + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + past_seen_tokens = past_key_values.get_seq_length() + else: + past_seen_tokens = past_key_values[0][0].shape[2] + + if ignore_cache_position is False: + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None and cache_position: + position_ids = cache_position.unsqueeze(0) + else: + if position_ids is None: + position_ids = torch.arange( + past_seen_tokens, seq_length + past_seen_tokens, dtype=torch.long, device=inputs_embeds.device + ) + position_ids = position_ids.unsqueeze(0) + cache_position = None + + # HPU specific mask generation + if ignore_cache_position: + causal_mask = _gaudi_prepare_4d_causal_attention_mask( + attention_mask, + input_ids.shape if input_ids is not None else (batch_size, seq_length), + inputs_embeds, + past_seen_tokens, + ) + else: + causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_seen_tokens) + + # embed positions + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if not use_new_cache else None + + if lazy_mode: + htcore.mark_step() + + for layer_idx, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]): + if ( + lazy_mode + and not self.training + and (torch.distributed.is_initialized() is False or torch.distributed.get_world_size() == 1) + ): + htcore.mark_step() + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + None, + None, + attn_softmax_bf16, + False, + use_flash_attention, + flash_attention_recompute, + flash_attention_causal_mask, + flash_attention_fast_softmax, + valid_sequence_lengths, + None, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=None if past_key_values is None else past_key_values[layer_idx], + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + token_idx=token_idx, + attn_softmax_bf16=attn_softmax_bf16, + reuse_cache=reuse_cache, + use_flash_attention=use_flash_attention, + flash_attention_recompute=flash_attention_recompute, + flash_attention_causal_mask=flash_attention_causal_mask, + flash_attention_fast_softmax=flash_attention_fast_softmax, + valid_sequence_lengths=valid_sequence_lengths, + cache_idx=cache_idx, + num_virtual_tokens=num_virtual_tokens, + ) + + hidden_states = 
layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = None + if use_cache: + next_cache = ( + next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache + ) + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +class GaudiQwen3ForCausalLM(Qwen3ForCausalLM): + def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): + self.model.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len) + + def reorder_kv_cache(self, beam_idx: torch.LongTensor): + return self.model.reorder_kv_cache(beam_idx) + + def update_sincos_cache(self, seq_len): + self.model.update_sincos_cache(seq_len) + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + token_idx: Optional[torch.Tensor] = None, + trim_logits: Optional[bool] = False, + attn_softmax_bf16: Optional[bool] = False, + reuse_cache: Optional[bool] = False, + use_flash_attention: Optional[bool] = False, + flash_attention_recompute: Optional[bool] = False, + flash_attention_causal_mask: Optional[bool] = False, + flash_attention_fast_softmax: Optional[bool] = False, + valid_sequence_lengths: torch.Tensor = None, + cache_idx: int = None, + lazy_mode: Optional[bool] = True, + num_virtual_tokens: int = None, + **kwargs: Unpack[KwargsForCausalLM], + ) -> Union[Tuple, CausalLMOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if self.generation_config.use_fused_rope is False: + global has_fused_rope + has_fused_rope = False + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + token_idx=token_idx, + attn_softmax_bf16=attn_softmax_bf16, + reuse_cache=reuse_cache, + use_flash_attention=use_flash_attention, + flash_attention_recompute=flash_attention_recompute, + flash_attention_causal_mask=flash_attention_causal_mask, + flash_attention_fast_softmax=flash_attention_fast_softmax, + 
valid_sequence_lengths=valid_sequence_lengths, + cache_idx=cache_idx, + lazy_mode=lazy_mode, + num_virtual_tokens=num_virtual_tokens, + ) + + hidden_states = outputs[0] + _, seq_len, _ = hidden_states.shape + if seq_len > 1 and trim_logits and not self.training: + if token_idx is not None: + hidden_states = hidden_states.index_select(1, token_idx - 1) + else: + hidden_states = hidden_states[:, -1, :] + + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]).float() + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + @staticmethod + def _reorder_cache( + past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor + ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]: + """ + This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or + [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct + beam_idx at every generation step. + + Output shares the same memory storage as `past`. + """ + return tuple( + ( + layer_past[0].index_select(0, beam_idx.to(layer_past[0].device)), + layer_past[1].index_select(0, beam_idx.to(layer_past[1].device)), + ) + for layer_past in past + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + num_logits_to_keep=None, + token_idx=None, + **kwargs, + ): + reuse_cache = kwargs.get("reuse_cache") + bucket_internal = kwargs.get("bucket_internal") + if past_key_values is not None: + if token_idx is not None: + idx = token_idx + kwargs.get("inputs_embeds_offset", 0) - 1 + input_ids = torch.index_select(input_ids, 1, idx) + else: + if inputs_embeds is not None: # Exception 1 + input_ids = input_ids[:, -cache_position.shape[0] :] + elif ( + input_ids.shape[1] != cache_position.shape[0] + ): # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] + elif (reuse_cache or bucket_internal) and token_idx is not None: + # KV cache is pre allocated with reuse cache or will be padded with bucket internal + # hence for the 1st token we can slice the inputs till token idx for the fwd pass. + input_ids = input_ids[:, :token_idx] + attention_mask = attention_mask[:, :token_idx] + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + if token_idx is not None: + position_ids = torch.index_select(position_ids, 1, token_idx - 1) + else: + position_ids = position_ids[:, -input_ids.shape[1] :] + # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. 
Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture. + position_ids = position_ids.clone(memory_format=torch.contiguous_format) + + cache_position = None + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = { + "input_ids": input_ids.clone(memory_format=torch.contiguous_format) + } # `contiguous()` needed for compilation use cases + + if num_logits_to_keep is not None: + model_inputs["num_logits_to_keep"] = num_logits_to_keep + + model_inputs.update( + { + "position_ids": position_ids.contiguous(), + "cache_position": cache_position, + "past_key_values": past_key_values, + "use_cache": use_cache, + "attention_mask": attention_mask, + "token_idx": token_idx, + "trim_logits": kwargs.get("trim_logits"), + "attn_softmax_bf16": kwargs.get("attn_softmax_bf16"), + "reuse_cache": reuse_cache, + "use_flash_attention": kwargs.get("use_flash_attention"), + "flash_attention_recompute": kwargs.get("flash_attention_recompute"), + "flash_attention_causal_mask": kwargs.get("flash_attention_causal_mask"), + "flash_attention_fast_softmax": kwargs.get("flash_attention_fast_softmax"), + "valid_sequence_lengths": kwargs.get("valid_sequence_lengths"), + "cache_idx": kwargs.get("cache_idx"), + "lazy_mode": kwargs.get("lazy_mode"), + "num_virtual_tokens": kwargs.get("num_virtual_tokens"), + } + ) + return model_inputs + + +def apply_customized_rope(q, k, cos, sin, position_ids, training=True): + if q.device.type == "hpu" and has_fused_rope: + return apply_customized_rope_module(q, k, cos, sin, position_ids, training) + else: + # keep the same implementation as Transformers v4.37.2 + return apply_rotary_pos_emb(q, k, cos[position_ids], sin[position_ids]) + \ No newline at end of file From 9ff70d6ace965672ab9a491449e90615629ade86 Mon Sep 17 00:00:00 2001 From: tianyuan211 Date: Thu, 8 May 2025 17:27:43 +0800 Subject: [PATCH 02/14] fix qwen3 related files --- README.md | 2 +- docs/source/index.mdx | 2 +- tests/test_text_generation_example.py | 1 + 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index c5e52fd56d..2f4e8836de 100644 --- a/README.md +++ b/README.md @@ -250,7 +250,7 @@ The following model architectures, tasks and device distributions have been vali | Phi | :heavy_check_mark: |
<li>Single card</li> | <li>[language modeling](https://github.com/huggingface/optimum-habana/tree/main/examples/language-modeling)</li><li>[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)</li> |
| Mixtral | | <li>Single card</li> | <li>[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)</li> |
| Persimmon | | <li>Single card</li> | <li>[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)</li> |
-| Qwen2 | <li>Single card</li> | <li>Single card</li> | <li>[language modeling](https://github.com/huggingface/optimum-habana/tree/main/examples/language-modeling)</li><li>[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)</li> |
+| Qwen2 / Qwen3 | <li>Single card</li> | <li>Single card</li> | <li>[language modeling](https://github.com/huggingface/optimum-habana/tree/main/examples/language-modeling)</li><li>[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)</li> |
| Qwen2-MoE | | <li>Single card</li> | <li>[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)</li> |
| Gemma | :heavy_check_mark: | <li>Single card</li> | <li>[language modeling](https://github.com/huggingface/optimum-habana/tree/main/examples/language-modeling)</li><li>[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)</li> |
| Gemma2 | | :heavy_check_mark: | <li>[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)</li> |
diff --git a/docs/source/index.mdx b/docs/source/index.mdx
index 81c73b0c6a..0836ec2004 100644
--- a/docs/source/index.mdx
+++ b/docs/source/index.mdx
@@ -79,7 +79,7 @@ In the tables below, ✅ means single-card, multi-card and DeepSpeed have all be
| Mixtral | | <li>Single card</li> | <li>[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)</li> |
| Gemma | ✅ | <li>Single card</li> | <li>[language modeling](https://github.com/huggingface/optimum-habana/tree/main/examples/language-modeling)</li><li>[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)</li> |
| Gemma2 | | ✅ | <li>[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)</li> |
-| Qwen2 | <li>Single card</li> | <li>Single card</li> | <li>[language modeling](https://github.com/huggingface/optimum-habana/tree/main/examples/language-modeling)</li><li>[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)</li> |
+| Qwen2 / Qwen3 | <li>Single card</li> | <li>Single card</li> | <li>[language modeling](https://github.com/huggingface/optimum-habana/tree/main/examples/language-modeling)</li><li>[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)</li> |
| Qwen2-MoE | | <li>Single card</li> | <li>[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)</li> |
| Persimmon | | <li>Single card</li> | <li>[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)</li> |
| XGLM | | <li>Single card</li> | <li>[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)</li> |
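The tables above list Qwen2 / Qwen3 under text generation on Gaudi. For context only (this sketch is not part of the patch series), here is roughly how the new integration might be exercised once `adapt_transformers_to_gaudi()` has swapped in the `GaudiQwen3*` classes added by PATCH 01. The model id, prompt, and generation settings are assumptions, and Gaudi-specific flags such as `reuse_cache` or `use_flash_attention` depend on the installed optimum-habana version.

```python
# Illustrative sketch only -- not part of this patch. Assumes a Gaudi/HPU machine
# with optimum-habana installed; model id and generation settings are examples.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi

# Replace the stock transformers Qwen3 classes with the Gaudi-optimized ones.
adapt_transformers_to_gaudi()

model_id = "Qwen/Qwen3-8B"  # same checkpoint used by the test entry added below
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16).to("hpu")

inputs = tokenizer("DeepSpeed is a machine learning framework", return_tensors="pt").to("hpu")
# Gaudi-specific knobs (e.g. reuse_cache, use_flash_attention) are forwarded through
# generate() by the patched generation utils; whether they apply depends on the version.
outputs = model.generate(**inputs, max_new_tokens=32, use_cache=True)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```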
  • | diff --git a/tests/test_text_generation_example.py b/tests/test_text_generation_example.py index 701a5c49d5..52aeb94066 100644 --- a/tests/test_text_generation_example.py +++ b/tests/test_text_generation_example.py @@ -62,6 +62,7 @@ ("THUDM/chatglm3-6b", 1, True, False), ("Qwen/Qwen2.5-7B", 4, False, False), ("moonshotai/Moonlight-16B-A3B", 1, False, False), + ("Qwen/Qwen3-8B", 1, False, False), ], "fp8": [ pytest.param("tiiuae/falcon-180B", 4, 950, True, 128, 128, marks=pytest.mark.x4), From 1b8896882586ca806bc2904916cb1543f7f47a1e Mon Sep 17 00:00:00 2001 From: tianyuan211 Date: Thu, 22 May 2025 10:54:53 +0800 Subject: [PATCH 03/14] Update modeling_qwen3.py --- .../models/qwen3/modeling_qwen3.py | 49 +++++-------------- 1 file changed, 12 insertions(+), 37 deletions(-) diff --git a/optimum/habana/transformers/models/qwen3/modeling_qwen3.py b/optimum/habana/transformers/models/qwen3/modeling_qwen3.py index 80c9e8eb8f..e1437db7d3 100644 --- a/optimum/habana/transformers/models/qwen3/modeling_qwen3.py +++ b/optimum/habana/transformers/models/qwen3/modeling_qwen3.py @@ -16,6 +16,7 @@ # Copyright (C) 2022-2024 Habana Labs, Ltd. an Intel Company ############################################################################### +from functools import partial from typing import List, Optional, Tuple, Union import torch @@ -486,21 +487,7 @@ def pre_attn_forward( if use_flash_attention and FusedSDPA is not None: attn_weights = None - if self.training: - attn_output = fused_scaled_dot_product_attention( - query_states, - key_states, - value_states, - attention_mask, - 0.0, - False, - None, - "None", - False, - None, - "None", - ) - elif q_len == 1: + if q_len == 1: # next token attn_output = fused_scaled_dot_product_attention( query_states, @@ -510,14 +497,14 @@ def pre_attn_forward( 0.0, False, None, - "fp32", + "None", False, None, "None", ) else: # first token - softmax_mode = "fp32" + softmax_mode = "fast" if flash_attention_fast_softmax else "None" if flash_attention_causal_mask: attn_output = fused_scaled_dot_product_attention( query_states, @@ -769,7 +756,7 @@ def update_sincos_cache(self, seq_len): def forward( self, - input_ids: torch.LongTensor = None, + input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, @@ -777,7 +764,6 @@ def forward( use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, token_idx: Optional[torch.Tensor] = None, attn_softmax_bf16: Optional[bool] = False, @@ -790,15 +776,14 @@ def forward( cache_idx: int = None, lazy_mode: Optional[bool] = True, num_virtual_tokens: int = None, - ) -> Union[Tuple, BaseModelOutputWithPast]: + **kwargs, + ) -> BaseModelOutputWithPast: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - # retrieve input_ids and inputs_embeds if input_ids is not None and inputs_embeds is not None: raise ValueError("You must specify exactly one of input_ids or inputs_embeds") @@ -888,7 +873,7 @@ def forward( if 
self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, + partial(decoder_layer.__call__, **kwargs), hidden_states, causal_mask, position_ids, @@ -947,8 +932,6 @@ def forward( next_cache = ( next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache ) - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=next_cache, @@ -969,7 +952,7 @@ def update_sincos_cache(self, seq_len): def forward( self, - input_ids: torch.LongTensor = None, + input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, @@ -978,7 +961,6 @@ def forward( use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, token_idx: Optional[torch.Tensor] = None, @@ -994,18 +976,17 @@ def forward( lazy_mode: Optional[bool] = True, num_virtual_tokens: int = None, **kwargs: Unpack[KwargsForCausalLM], - ) -> Union[Tuple, CausalLMOutputWithPast]: + ) -> CausalLMOutputWithPast: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict if self.generation_config.use_fused_rope is False: global has_fused_rope has_fused_rope = False # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( + outputs: BaseModelOutputWithPast = self.model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -1014,7 +995,6 @@ def forward( use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, cache_position=cache_position, token_idx=token_idx, attn_softmax_bf16=attn_softmax_bf16, @@ -1029,7 +1009,7 @@ def forward( num_virtual_tokens=num_virtual_tokens, ) - hidden_states = outputs[0] + hidden_states = outputs.last_hidden_state _, seq_len, _ = hidden_states.shape if seq_len > 1 and trim_logits and not self.training: if token_idx is not None: @@ -1045,10 +1025,6 @@ def forward( if labels is not None: loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - return CausalLMOutputWithPast( loss=loss, logits=logits, @@ -1163,4 +1139,3 @@ def apply_customized_rope(q, k, cos, sin, position_ids, training=True): else: # keep the same implementation as Transformers v4.37.2 return apply_rotary_pos_emb(q, k, cos[position_ids], sin[position_ids]) - \ No newline at end of file From 0f93054927ea3520c54151cfbc7413e0363daea8 Mon Sep 17 00:00:00 2001 From: tianyuan211 Date: Wed, 28 May 2025 16:27:47 +0800 Subject: [PATCH 04/14] add qwen3 moe --- .../transformers/models/qwen3_moe/__init__.py | 9 + .../models/qwen3_moe/modeling_qwen3_moe.py | 1243 +++++++++++++++++ 2 files changed, 1252 insertions(+) create mode 100644 
optimum/habana/transformers/models/qwen3_moe/__init__.py create mode 100755 optimum/habana/transformers/models/qwen3_moe/modeling_qwen3_moe.py diff --git a/optimum/habana/transformers/models/qwen3_moe/__init__.py b/optimum/habana/transformers/models/qwen3_moe/__init__.py new file mode 100644 index 0000000000..208c4a5b94 --- /dev/null +++ b/optimum/habana/transformers/models/qwen3_moe/__init__.py @@ -0,0 +1,9 @@ +from .modeling_qwen3_moe import ( + GaudiQwen3MoeAttention, + GaudiQwen3MoeDecoderLayer, + GaudiQwen3MoeForCausalLM, + GaudiQwen3MoeMLP, + GaudiQwen3MoeModel, + gaudi_qwen3moe_block_sparse_moe_forward, + gaudi_qwen3moe_rmsnorm_forward, +) diff --git a/optimum/habana/transformers/models/qwen3_moe/modeling_qwen3_moe.py b/optimum/habana/transformers/models/qwen3_moe/modeling_qwen3_moe.py new file mode 100755 index 0000000000..f2325e5a80 --- /dev/null +++ b/optimum/habana/transformers/models/qwen3_moe/modeling_qwen3_moe.py @@ -0,0 +1,1243 @@ +# coding=utf-8 +# Copyright 2023 Mistral AI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""PyTorch Qwen3MoE model.""" + +import math +import warnings +from typing import List, Optional, Tuple, Union + +import habana_frameworks.torch.core as htcore +import torch +import torch.nn.functional as F +from transformers.cache_utils import Cache, DynamicCache, StaticCache +from transformers.integrations.deepspeed import is_deepspeed_available +from transformers.modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast +from transformers.models.qwen3_moe.configuration_qwen3_moe import Qwen3MoeConfig +from transformers.models.qwen3_moe.modeling_qwen3_moe import ( + Qwen3MoeAttention, + Qwen3MoeDecoderLayer, + Qwen3MoeForCausalLM, + Qwen3MoeMLP, + Qwen3MoeModel, + Qwen3MoeRMSNorm, + Qwen3MoeSparseMoeBlock, + apply_rotary_pos_emb, + load_balancing_loss_func, +) +from transformers.utils import logging + +from ...modeling_attn_mask_utils import ( + _gaudi_prepare_4d_causal_attention_mask, +) +from ...modeling_rope_utils import GaudiRotaryEmbedding + + +try: + from habana_frameworks.torch.hpex.kernels import RotaryPosEmbeddingHelperV2 as FusedRoPE + + has_fused_rope = True +except ImportError: + has_fused_rope = False + print("Not using HPU fused kernel for apply_rotary_pos_emb") + +try: + from habana_frameworks.torch.hpex.normalization import FusedRMSNorm as FusedRMSNorm + + has_fused_rms_norm = True +except ImportError: + has_fused_rms_norm = False + print("Not using HPU fused kernel for RMSNorm") + +try: + from habana_frameworks.torch.hpex.kernels import FusedSDPA +except ImportError: + print("Not using HPU fused scaled dot-product attention kernel.") + FusedSDPA = None + +logger = logging.get_logger(__name__) + + +def apply_customized_rope(q, k, cos, sin, position_ids): + if q.device.type == "hpu" and has_fused_rope: + # TODO: remove `.clone()` when it is fixed in SynapseAI + if k.dtype == torch.bfloat16: + return FusedRoPE.apply( + q, cos.unsqueeze(0).unsqueeze(0).clone(), sin.unsqueeze(0).unsqueeze(0).clone(), position_ids + ), FusedRoPE.apply( + k, + cos.unsqueeze(0).unsqueeze(0).clone().to(torch.bfloat16), + sin.unsqueeze(0).unsqueeze(0).clone().to(torch.bfloat16), + position_ids, + ) + return FusedRoPE.apply( + q, cos.unsqueeze(0).unsqueeze(0).clone(), sin.unsqueeze(0).unsqueeze(0).clone(), position_ids + ), FusedRoPE.apply( + k, cos.unsqueeze(0).unsqueeze(0).clone(), sin.unsqueeze(0).unsqueeze(0).clone(), position_ids + ) + else: + # keep the same implementation as Transformers v4.37.2 + return apply_rotary_pos_emb(q, k, cos[position_ids], sin[position_ids]) + + +def gaudi_qwen3moe_rmsnorm_forward(self, hidden_states): + """ + Copied from MixtralRMSNorm.forward: https://github.com/huggingface/transformers/blob/v4.37.0/src/transformers/models/mixtral/modeling_mixtral.py + The only differences are: + - override RMSNorm with Habana fused RMSNorm + """ + if hidden_states.device.type == "hpu" and has_fused_rms_norm: + # mixed dtypes are not good for FusedRMSNorm, both inputs need to have same dtype + if hidden_states.dtype != self.weight.dtype: + orig_dtype = hidden_states.dtype + hidden_states = FusedRMSNorm.apply(hidden_states.to(self.weight.dtype), self.weight, self.variance_epsilon) + return hidden_states.to(orig_dtype) + else: + hidden_states = FusedRMSNorm.apply(hidden_states, self.weight, self.variance_epsilon) + return hidden_states + else: + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + 
self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +class GaudiQwen3MoeMLP(Qwen3MoeMLP): + def pre_mlp_forward(self, x): + input = self.act_fn(self.gate_proj(x)) * self.up_proj(x) + output = self.down_proj(input) + return output + + def mlp_all_reduce(self, x): + if hasattr(self.down_proj, "all_reduce"): + self.down_proj.all_reduce(x) + + def post_mlp_forward(self, x): + if hasattr(self.down_proj, "post_all_reduce"): + return self.down_proj.post_all_reduce(x) + return x + + +def gaudi_qwen3moe_repeat_kv( + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + attention_mask: torch.Tensor, + n_rep: int, +): + """ + Copied from repeat_kv: https://github.com/huggingface/transformers/blob/v4.37.0/src/transformers/models/mixtral/modeling_mixtral.py + The only differences are: + - Append num_key_value_heads == 1 check as kv states can be broadcasted during matmuls so need to expand and reshape them. + - Add new args query_states, key_states, value_states and attention_mask and update the logic for expansion. + The query states go from (batch, num_heads, seqlen, head_dim) to (batch, num_key_value_heads, n_rep, seqlen, head_dim) + The key/value states go from (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_key_value_heads, 1, seqlen, head_dim) + """ + batch, num_key_value_heads, kv_len, head_dim = key_states.shape + if n_rep == 1 or num_key_value_heads == 1: + return query_states, key_states, value_states, attention_mask + + new_kv_shape = (batch, num_key_value_heads, 1, kv_len, head_dim) + key_states = key_states.reshape(new_kv_shape) + value_states = value_states.reshape(new_kv_shape) + + batch, q_heads, q_len, head_dim = query_states.shape + new_q_shape = (batch, num_key_value_heads, n_rep, q_len, head_dim) + query_states = query_states.reshape(new_q_shape) + + if attention_mask is not None: + # Add groups dim and set to 1 + attention_mask = attention_mask.unsqueeze(1) + + return query_states, key_states, value_states, attention_mask + + +class Matmul(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, y): + return torch.matmul(x, y) + + +class KVCache(torch.nn.Module): + def __init__(self): + super(KVCache, self).__init__() + self.cache = None + self.inp_seq_len = -1 + + def allocate(self, inp_seq_len, dtype, device, shape): + if self.cache is None or self.cache.shape != shape: + self.inp_seq_len = inp_seq_len + self.cache = torch.zeros(shape, dtype=dtype, device=device) + else: + assert self.inp_seq_len == inp_seq_len, ( + f"inp_seq_len must be the same. 
self.inp_seq_len:{self.inp_seq_len} inp_seq_len:{inp_seq_len}" + ) + self.cache.fill_(0) + + @staticmethod + def update(prev, cur, dim, idx, inp_seq_len): + orig_cur = cur + if prev.shape == cur.shape: + prev.copy_(cur) + return orig_cur + if idx is not None and cur.shape[2] > 1 and cur.shape[2] <= prev.shape[2]: + # Initialize + prev[:, :, :inp_seq_len, :].copy_(cur) + return orig_cur + if idx is not None: + prev.index_copy_(dim, idx - 1, cur) + return prev + else: + return torch.cat((prev, cur), dim=dim) + + def get_shape(self): + if self.cache is None: + return None + return self.cache.shape + + def forward(self, cur, dim, idx): + return self.update(self.cache, cur, dim, idx, self.inp_seq_len) + + +# FusedScaledDotProductAttention +class ModuleFusedSDPA(torch.nn.Module): + def __init__(self, fusedSDPA, scale, attention_dropout, enable_recompute, flash_attention_fp8): + super().__init__() + self._hpu_kernel_fsdpa = fusedSDPA + self.scale = scale + self.attention_dropout = attention_dropout + self.enable_recompute = enable_recompute + self.flash_attention_fp8 = flash_attention_fp8 + + def forward( + self, + query, + key, + value, + attn_mask, + dropout_p, + is_casual, + scale, + softmax_mode, + recompute_mode, + valid_sequence_lengths, + padding_side="left", + ): + return self._hpu_kernel_fsdpa.apply( + query, + key, + value, + attn_mask, + dropout_p, + is_casual, + scale, + softmax_mode, + recompute_mode, + valid_sequence_lengths, + padding_side, + ) + + +class GaudiQwen3MoeAttention(Qwen3MoeAttention): + def __init__(self, config: Qwen3MoeConfig, layer_idx: Optional[int] = None): + super().__init__(config, layer_idx) + + self.matmul_qk = Matmul() + self.matmul_av = Matmul() + self.k_cache = KVCache() + self.v_cache = KVCache() + + self.inp_seq_len = -1 + self.norm_factor = 1.0 / math.sqrt(self.head_dim) + + self.rotary_emb = GaudiRotaryEmbedding(config=self.config) + + self.fused_scaled_dot_product_attention = ( + ModuleFusedSDPA( + FusedSDPA, + scale=self.norm_factor, + attention_dropout=self.attention_dropout, + enable_recompute=False, + flash_attention_fp8=getattr(config, "flash_attention_fp8", False), + ) + if FusedSDPA + else None + ) + + def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): + cache_shape = (batch_size, self.num_key_value_heads, max_seq_len, self.head_dim) + device = self.k_proj.weight.device + dtype = self.config.torch_dtype + self.k_cache.allocate(inp_seq_len, dtype, device, cache_shape) + self.v_cache.allocate(inp_seq_len, dtype, device, cache_shape) + + def update_sincos_cache(self, seq_len): + # Call rotary emb forward() to update cos/sin cache when infering more than self.max_position_embeddings + # This helps in avoiding creation of these caches during actual model forward pass and + # reduce memory consumption and improve performance. 
+ if seq_len > self.max_position_embeddings: + self.max_position_embeddings = seq_len + _, _ = self.rotary_emb(self.k_proj.weight, seq_len=seq_len) + + def reorder(self, tensor, beam_idx, dim_a, dim_b): + updated = tensor.index_select(0, beam_idx) + tensor.copy_(updated) + + def reorder_kv_cache(self, beam_idx: torch.LongTensor): + if self.k_cache.cache is None: + return (None, None) + + head_dim = self.k_cache.cache.size(-1) + seq_length = self.k_cache.cache.size(-2) + self.reorder(self.k_cache.cache, beam_idx, seq_length, head_dim) + self.reorder(self.v_cache.cache, beam_idx, seq_length, head_dim) + return (self.k_cache.cache.shape, self.v_cache.cache.shape) + + def pre_attn_forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + token_idx: Optional[torch.Tensor] = None, + attn_softmax_bf16: Optional[bool] = False, + reuse_cache: Optional[bool] = False, + use_flash_attention: Optional[bool] = False, + flash_attention_recompute: Optional[bool] = False, + flash_attention_causal_mask: Optional[bool] = False, + flash_attention_fast_softmax: Optional[bool] = False, + valid_sequence_lengths: Optional[torch.Tensor] = None, + cache_idx: int = None, + num_virtual_tokens: int = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """ + Copied from LlamaAttention.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py + The only differences are: + - add new args token_idx + - optimize KV cache + - add new args attn_softmax_bf16 + - add new args reuse_cache + - add new args use_flash_attention + - add new arg flash_attention_recompute + - add new arg flash_attention_causal_mask + - add new arg flash_attention_fast_softmax + - add new arg num_virtual_tokens + """ + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_norm(self.q_proj(hidden_states)) + key_states = self.k_norm(self.k_proj(hidden_states)) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + if token_idx is None: + if hasattr(past_key_value, "get_usable_length"): + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + else: + kv_seq_len += past_key_value[0].shape[-2] + else: + if reuse_cache and not isinstance(past_key_value[0], torch.Tensor): + kv_seq_len = past_key_value[0][-2] + else: + if num_virtual_tokens is not None and num_virtual_tokens == past_key_value[0].shape[-2]: + kv_seq_len = past_key_value[0].shape[-2] + kv_seq_len + else: + kv_seq_len = past_key_value[0].shape[-2] + + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_customized_rope(query_states, key_states, cos, sin, position_ids) + + if use_cache: + # reuse k, v, self_attention + if reuse_cache: + if past_key_value is not None and isinstance(past_key_value[0], torch.Tensor): + # prefix tuning case. 
attach past_key_value to generate first token. + key_states = torch.cat((past_key_value[0], key_states), -2) + value_states = torch.cat((past_key_value[1], value_states), -2) + key_states = self.k_cache(key_states, 2, token_idx) + value_states = self.v_cache(value_states, 2, token_idx) + past_key_value = (self.k_cache.get_shape(), self.v_cache.get_shape()) + else: + if past_key_value is None: + past_key = torch.zeros(key_states.shape, dtype=self.k_proj.weight.dtype, device=key_states.device) + past_value = torch.zeros( + key_states.shape, dtype=self.k_proj.weight.dtype, device=key_states.device + ) + # Return list instead of tuple + past_key_value = [past_key, past_value] + if ( + token_idx is not None + and num_virtual_tokens is not None + and num_virtual_tokens == past_key_value[0].shape[-2] + ): + # prefix tuning case. attach past_key_value to generate first token. + key_states = torch.cat((past_key_value[0], key_states), -2) + value_states = torch.cat((past_key_value[1], value_states), -2) + past_key_value = (key_states, value_states) + else: + key_states = self.k_cache.update(past_key_value[0], key_states, 2, token_idx, self.inp_seq_len) + value_states = self.v_cache.update(past_key_value[1], value_states, 2, token_idx, self.inp_seq_len) + + if token_idx is None: + past_key_value = (key_states, value_states) + + if cache_idx is not None and q_len == 1: + key_states = key_states[:, :, :cache_idx, :] + value_states = value_states[:, :, :cache_idx, :] + if attention_mask is not None: + attention_mask = attention_mask[:, :, :, :cache_idx] + kv_seq_len = key_states.shape[-2] + else: + past_key_value = None + + if use_flash_attention and FusedSDPA is not None: + if q_len == 1: + # next token + attn_output = self.fused_scaled_dot_product_attention( + query_states, + key_states, + value_states, + attention_mask, + 0.0, + False, + None, + "None", + False, + None, + "None", + ) + else: + # first token + softmax_mode = "fast" if flash_attention_fast_softmax else "None" + if flash_attention_causal_mask: + attn_output = self.fused_scaled_dot_product_attention( + query_states, + key_states, + value_states, + None, + 0.0, + True, + None, + softmax_mode, + flash_attention_recompute, + valid_sequence_lengths, + "left", + ) + else: + attn_output = self.fused_scaled_dot_product_attention( + query_states, + key_states, + value_states, + attention_mask, + 0.0, + False, + None, + softmax_mode, + flash_attention_recompute, + None, + "None", + ) + + else: + query_states, key_states, value_states, attention_mask = gaudi_qwen3moe_repeat_kv( + query_states, key_states, value_states, attention_mask, self.num_key_value_groups + ) + + query_states = query_states * self.norm_factor + attn_weights = self.matmul_qk(query_states, key_states.transpose(-2, -1)).float() + htcore.mark_step() + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask + if cache_position is not None: + causal_mask = attention_mask[:, :, cache_position, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask.float() + + if attn_softmax_bf16: + attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=query_states.dtype) + else: + # upcast attention to fp32 + attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to( + query_states.dtype + ) + attn_weights = torch.nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = self.matmul_av(attn_weights, value_states) + attn_output = 
attn_output.reshape(bsz, -1, q_len, self.head_dim) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.reshape(bsz, q_len, -1) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + if not reuse_cache and token_idx is not None and cache_idx is not None and q_len == 1: + # Return only past key value shapes and not the tensors during decode phase (q len is 1) + # to avoid making past key values as persistent output tensors of HPU graphs. + past_key_value = (past_key_value[0].shape, past_key_value[1].shape) + + return attn_output, attn_weights, past_key_value + + def attention_all_reduce(self, attn_output): + if hasattr(self.o_proj, "all_reduce"): + self.o_proj.all_reduce(attn_output) + + def post_attn_forward(self, attn_output): + if hasattr(self.o_proj, "post_all_reduce"): + return self.o_proj.post_all_reduce(attn_output) + return attn_output + + +def gaudi_qwen3moe_block_sparse_moe_forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + - optimize expert forward, remove dynamic control and dynamic shape + """ + batch_size, sequence_length, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + + # router_logits: (batch * sequence_length, n_experts) + router_logits = self.gate(hidden_states) + + routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) + routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) + + if self.norm_topk_prob: + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + # we cast back to the input dtype + routing_weights = routing_weights.to(hidden_states.dtype) + + if self.training: + final_hidden_states = torch.zeros( + (batch_size, sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device + ) + + padded_weights = torch.zeros( + (batch_size * sequence_length, self.num_experts), dtype=hidden_states.dtype, device=hidden_states.device + ) + padded_weights.scatter_(-1, selected_experts, routing_weights) + padded_weights = padded_weights.reshape(-1, sequence_length, self.num_experts) + padded_weights = padded_weights.permute(2, 0, 1).unsqueeze(-1) + + # Loop over all available experts in the model and perform the computation on each expert + for expert_idx in range(self.num_experts): + expert_layer = self.experts[expert_idx] + padded_weight = padded_weights[expert_idx] + current_state_static = hidden_states.reshape(-1, hidden_dim) + current_hidden_states_static = ( + expert_layer.pre_mlp_forward(current_state_static).reshape(-1, sequence_length, hidden_dim) + * padded_weight + ) + final_hidden_states = final_hidden_states + current_hidden_states_static + else: + experts_range = range(self.num_experts) + w1_list = [self.experts[i].gate_proj.weight.squeeze() for i in experts_range] + w2_list = [self.experts[i].down_proj.weight.squeeze() for i in experts_range] + w3_list = [self.experts[i].up_proj.weight.squeeze() for i in experts_range] + + final_hidden_states = torch.ops.hpu.mixture_of_experts( + hidden_states=hidden_states, + expert_routing_table=selected_experts, + router_weights=routing_weights, + w1=w1_list, + w2=w3_list, # Note that there is a different naming convention of w1, w2, and w3 between optimum habana's mixtral model and dynamic 
MoE kernel. + w3=w2_list, + permuted_weights=True, + activation="silu", + experts_min=0, + experts_max=(self.num_experts - 1), + ) + final_hidden_states = final_hidden_states.reshape(-1, sequence_length, hidden_dim) + + if is_deepspeed_available(): + from deepspeed import comm as dist + + if dist.is_initialized(): + dist.all_reduce(final_hidden_states, op=dist.ReduceOp.SUM) + + shared_expert_output = self.shared_expert(hidden_states) + + shared_expert_output = F.sigmoid(self.shared_expert_gate(hidden_states)) * shared_expert_output + + shared_expert_output = shared_expert_output.reshape(-1, sequence_length, hidden_dim) + + final_hidden_states = final_hidden_states + shared_expert_output + + return final_hidden_states, router_logits + + +class GaudiQwen3MoeDecoderLayer(Qwen3MoeDecoderLayer): + def __init__(self, config: Qwen3MoeConfig, layer_idx: int): + super(Qwen3MoeDecoderLayer, self).__init__() + self.hidden_size = config.hidden_size + + self.self_attn = GaudiQwen3MoeAttention(config=config, layer_idx=layer_idx) + + if config.num_experts > 0 and (layer_idx + 1) % config.decoder_sparse_step == 0: + self.mlp = Qwen3MoeSparseMoeBlock(config) + else: + self.mlp = GaudiQwen3MoeMLP(config, intermediate_size=config.intermediate_size) + + self.input_layernorm = Qwen3MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Qwen3MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): + self.self_attn.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len) + + def reorder_kv_cache(self, beam_idx: torch.LongTensor): + return self.self_attn.reorder_kv_cache(beam_idx) + + def update_sincos_cache(self, seq_len): + self.self_attn.update_sincos_cache(seq_len) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + output_router_logits: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + token_idx: Optional[torch.Tensor] = None, + attn_softmax_bf16: Optional[bool] = False, + reuse_cache: Optional[bool] = False, + use_flash_attention: Optional[bool] = False, + flash_attention_recompute: Optional[bool] = False, + flash_attention_causal_mask: Optional[bool] = False, + flash_attention_fast_softmax: Optional[bool] = False, + valid_sequence_lengths: Optional[torch.Tensor] = None, + cache_idx: int = None, + num_virtual_tokens: int = None, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Copied from LlamaDecoderLayer.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py + The only differences are: + - add new args token_idx + - add new args attn_softmax_bf16 + - add new args reuse_cache + - add new args use_flash_attention + - add new arg flash_attention_recompute + - add new arg flash_attention_causal_mask + - add new arg flash_attention_fast_softmax + """ + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. 
Please make sure use `attention_mask` instead.`" + ) + residual = hidden_states + hidden_states, self_attn_weights, present_key_value = self.pre_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + token_idx=token_idx, + attn_softmax_bf16=attn_softmax_bf16, + reuse_cache=reuse_cache, + use_flash_attention=use_flash_attention, + flash_attention_recompute=flash_attention_recompute, + flash_attention_causal_mask=flash_attention_causal_mask, + flash_attention_fast_softmax=flash_attention_fast_softmax, + valid_sequence_lengths=valid_sequence_lengths, + cache_idx=cache_idx, + num_virtual_tokens=num_virtual_tokens, + **kwargs, + ) + + self.self_attn.attention_all_reduce(hidden_states) + hidden_states, residual, router_logits = self.post_attn_pre_mlp(hidden_states, residual) + + hidden_states = self.post_mlp(hidden_states, residual) + outputs = (hidden_states,) + if output_attentions: + outputs += (self_attn_weights,) + if use_cache: + outputs += (present_key_value,) + if output_router_logits: + outputs += (router_logits,) + + return outputs + + def pre_attn( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + token_idx: Optional[torch.Tensor] = None, + attn_softmax_bf16: Optional[bool] = False, + reuse_cache: Optional[bool] = False, + use_flash_attention: Optional[bool] = False, + flash_attention_recompute: Optional[bool] = False, + flash_attention_causal_mask: Optional[bool] = False, + flash_attention_fast_softmax: Optional[bool] = False, + valid_sequence_lengths: Optional[torch.Tensor] = None, + cache_idx: int = None, + num_virtual_tokens: int = None, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + hidden_states = self.input_layernorm(hidden_states) + hidden_states, attn_weights, present_key_value = self.self_attn.pre_attn_forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + token_idx=token_idx, + attn_softmax_bf16=attn_softmax_bf16, + reuse_cache=reuse_cache, + use_flash_attention=use_flash_attention, + flash_attention_recompute=flash_attention_recompute, + flash_attention_causal_mask=flash_attention_causal_mask, + flash_attention_fast_softmax=flash_attention_fast_softmax, + valid_sequence_lengths=valid_sequence_lengths, + cache_idx=cache_idx, + num_virtual_tokens=num_virtual_tokens, + **kwargs, + ) + return hidden_states, attn_weights, present_key_value + + def post_attn_pre_mlp(self, hidden_states, residual): + hidden_states = self.self_attn.post_attn_forward(hidden_states) + + if self.training: + hidden_states = hidden_states + residual + residual = hidden_states + else: + residual.add_(hidden_states) + hidden_states = residual + + hidden_states = self.post_attention_layernorm(hidden_states) + + hidden_states = self.mlp(hidden_states) + + if 
isinstance(hidden_states, tuple): + hidden_states, router_logits = hidden_states + else: + router_logits = None + + return hidden_states, residual, router_logits + + def post_mlp(self, hidden_states, residual): + if self.training: + hidden_states = hidden_states + residual + else: + residual.add_(hidden_states) + hidden_states = residual + + return hidden_states + + +class GaudiQwen3MoeModel(Qwen3MoeModel): + def __init__(self, config: Qwen3MoeConfig): + super(Qwen3MoeModel, self).__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = torch.nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = torch.nn.ModuleList( + [GaudiQwen3MoeDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = Qwen3MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): + for layer in self.layers: + layer.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len) + + def reorder_kv_cache(self, beam_idx: torch.LongTensor): + return tuple(layer.reorder_kv_cache(beam_idx) for layer in self.layers) + + def update_sincos_cache(self, seq_len): + for layer in self.layers: + layer.update_sincos_cache(seq_len) + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + token_idx: Optional[torch.Tensor] = None, + attn_softmax_bf16: Optional[bool] = False, + reuse_cache: Optional[bool] = False, + use_flash_attention: Optional[bool] = False, + flash_attention_recompute: Optional[bool] = False, + flash_attention_causal_mask: Optional[bool] = False, + flash_attention_fast_softmax: Optional[bool] = False, + valid_sequence_lengths: torch.Tensor = None, + cache_idx: int = None, + lazy_mode: Optional[bool] = True, + num_virtual_tokens: int = None, + ) -> Union[Tuple, MoeModelOutputWithPast]: + """ + Copied from LlamaModel.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py + The only differences are: + - add new args token_idx + - add new args attn_softmax_bf16 + - add new args reuse_cache + - add new args use_flash_attention + - add new arg flash_attention_recompute + - add new arg flash_attention_causal_mask + - add new arg flash_attention_fast_softmax + - add new arg lazy_mode + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_router_logits = ( + output_router_logits if output_router_logits is not None else self.config.output_router_logits + ) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You 
must specify exactly one of input_ids or inputs_embeds") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape[:2] + elif inputs_embeds is not None: + batch_size, seq_length = inputs_embeds.shape[:2] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + ignore_cache_position = True # Ignoring cache position for HPU + use_new_cache = False # Ignoring new Cache path for HPU + + past_seen_tokens = 0 + + if past_key_values is not None and use_cache: # kept for BC (cache positions) + if reuse_cache: + if isinstance(past_key_values[0][0], torch.Tensor): + past_seen_tokens = past_key_values[0][0].shape[2] + else: + past_seen_tokens = past_key_values[0][0][2] + else: + if use_new_cache: + if not isinstance(past_key_values, StaticCache): + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + past_seen_tokens = past_key_values.get_seq_length() + else: + if past_key_values[0] is not None: ##added for (None, None) + past_seen_tokens = past_key_values[0][0].shape[2] + + if ignore_cache_position is False: + if cache_position is None: + if isinstance(past_key_values, StaticCache): + raise ValueError("cache_position is a required argument when using StaticCache.") + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None and cache_position: + position_ids = cache_position.unsqueeze(0) + + else: + if position_ids is None: + position_ids = torch.arange( + past_seen_tokens, seq_length + past_seen_tokens, dtype=torch.long, device=inputs_embeds.device + ) + position_ids = position_ids.unsqueeze(0) + cache_position = None + + # HPU specific mask generation + if ignore_cache_position: + causal_mask = _gaudi_prepare_4d_causal_attention_mask( + attention_mask, + input_ids.shape if input_ids is not None else (batch_size, seq_length), + inputs_embeds, + past_seen_tokens, + ) + else: + causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_seen_tokens) + + # embed positions + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_router_logits = () if output_router_logits else None + next_decoder_cache = () if not use_new_cache else None + + if lazy_mode: + htcore.mark_step() + + for layer_idx, decoder_layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + output_router_logits, + use_cache, + cache_position, + None, + attn_softmax_bf16, + False, + use_flash_attention, + flash_attention_recompute, + flash_attention_causal_mask, + flash_attention_fast_softmax, + valid_sequence_lengths, + None, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=None if past_key_values is None else past_key_values[layer_idx], + output_attentions=output_attentions, + 
output_router_logits=output_router_logits, + use_cache=use_cache, + cache_position=cache_position, + token_idx=token_idx, + attn_softmax_bf16=attn_softmax_bf16, + reuse_cache=reuse_cache, + use_flash_attention=use_flash_attention, + flash_attention_recompute=flash_attention_recompute, + flash_attention_causal_mask=flash_attention_causal_mask, + flash_attention_fast_softmax=flash_attention_fast_softmax, + valid_sequence_lengths=valid_sequence_lengths, + cache_idx=cache_idx, + num_virtual_tokens=num_virtual_tokens, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if output_router_logits: + all_router_logits += (layer_outputs[-1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = None + if use_cache: + next_cache = ( + next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache + ) + if not return_dict: + return tuple( + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits] + if v is not None + ) + + return MoeModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + router_logits=all_router_logits, + ) + + +class GaudiQwen3MoeForCausalLM(Qwen3MoeForCausalLM): + """ + Inherits from Qwen3MoeForCausalLM: https://github.com/huggingface/transformers/blob/v4.37.0/src/transformers/models/mixtral/modeling_mixtral.py#L1231 + The only differences are: + - add new args token_idx + - add token_idx into model_inputs + - from step2 when enable KV cache, slice next_input_ids from input_ids base on the token_idx + - from step2 when enable KV cache, slice next_position_ids from position_ids base on the token_idx + """ + + def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): + self.model.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len) + self.kv_cache_len = max_seq_len + + def reorder_kv_cache(self, beam_idx: torch.LongTensor): + return self.model.reorder_kv_cache(beam_idx) + + def update_sincos_cache(self, seq_len): + self.model.update_sincos_cache(seq_len) + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + token_idx: Optional[torch.Tensor] = None, + trim_logits: Optional[bool] = False, + reuse_cache: Optional[bool] = None, + attn_softmax_bf16: Optional[bool] = False, + use_flash_attention: Optional[bool] = False, + flash_attention_recompute: Optional[bool] = False, + flash_attention_causal_mask: Optional[bool] = False, + flash_attention_fast_softmax: Optional[bool] = False, + valid_sequence_lengths: torch.Tensor = None, + cache_idx: int = None, + lazy_mode: Optional[bool] = True, + num_virtual_tokens: int = None, + **loss_kwargs, + ) -> Union[Tuple, 
MoeCausalLMOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_router_logits = ( + output_router_logits if output_router_logits is not None else self.config.output_router_logits + ) + + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if self.generation_config.use_fused_rope is False: + global has_fused_rope + has_fused_rope = False + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + output_router_logits=output_router_logits, + return_dict=return_dict, + cache_position=cache_position, + token_idx=token_idx, + attn_softmax_bf16=attn_softmax_bf16, + reuse_cache=reuse_cache, + use_flash_attention=use_flash_attention, + flash_attention_recompute=flash_attention_recompute, + flash_attention_causal_mask=flash_attention_causal_mask, + flash_attention_fast_softmax=flash_attention_fast_softmax, + valid_sequence_lengths=valid_sequence_lengths, + cache_idx=cache_idx, + lazy_mode=lazy_mode, + num_virtual_tokens=num_virtual_tokens, + ) + + hidden_states = outputs[0] + _, seq_len, _ = hidden_states.shape + if seq_len > 1 and trim_logits and not self.training: + if token_idx is not None: + hidden_states = hidden_states.index_select(1, token_idx - 1) + else: + hidden_states = hidden_states[:, -1, :] + + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs) + + aux_loss = None + if output_router_logits: + aux_loss = load_balancing_loss_func( + outputs.router_logits if return_dict else outputs[-1], + self.num_experts, + self.num_experts_per_tok, + attention_mask, + ) + if labels is not None: + loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device + + if not return_dict: + output = (logits,) + outputs[1:] + if output_router_logits: + output = (aux_loss,) + output + return (loss,) + output if loss is not None else output + + return MoeCausalLMOutputWithPast( + loss=loss, + aux_loss=aux_loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + router_logits=outputs.router_logits, + ) + + @staticmethod + def _reorder_cache( + past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor + ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]: + """ + This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or + [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct + beam_idx at every generation step. + + Output shares the same memory storage as `past`. 
+ """ + return tuple( + ( + layer_past[0].index_select(0, beam_idx.to(layer_past[0].device)), + layer_past[1].index_select(0, beam_idx.to(layer_past[1].device)), + ) + for layer_past in past + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + num_logits_to_keep=None, + token_idx=None, + **kwargs, + ): + reuse_cache = kwargs.get("reuse_cache") + bucket_internal = kwargs.get("bucket_internal") + + if past_key_values is not None: + if token_idx is not None: + idx = token_idx + kwargs.get("inputs_embeds_offset", 0) - 1 + input_ids = torch.index_select(input_ids, 1, idx) + else: + if inputs_embeds is not None: # Exception 1 + input_ids = input_ids[:, -cache_position.shape[0] :] + elif ( + input_ids.shape[1] != cache_position.shape[0] + ): # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] + elif (reuse_cache or bucket_internal) and token_idx is not None: + # KV cache is pre allocated with reuse cache or will be padded with bucket internal + # hence for the 1st token we can slice the inputs till token idx for the fwd pass. + input_ids = input_ids[:, :token_idx] + attention_mask = attention_mask[:, :token_idx] + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + if token_idx is not None: + position_ids = torch.index_select(position_ids, 1, token_idx - 1) + else: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # keep cache_position implementation as None for HPU + cache_position = None + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids.contiguous()} + + if num_logits_to_keep is not None: + model_inputs["num_logits_to_keep"] = num_logits_to_keep + + model_inputs.update( + { + "position_ids": position_ids, + "cache_position": cache_position, + "past_key_values": past_key_values, + "use_cache": use_cache, + "attention_mask": attention_mask, + "token_idx": token_idx, + "trim_logits": kwargs.get("trim_logits"), + "attn_softmax_bf16": kwargs.get("attn_softmax_bf16"), + "reuse_cache": reuse_cache, + "use_flash_attention": kwargs.get("use_flash_attention"), + "flash_attention_recompute": kwargs.get("flash_attention_recompute"), + "flash_attention_causal_mask": kwargs.get("flash_attention_causal_mask"), + "flash_attention_fast_softmax": kwargs.get("flash_attention_fast_softmax"), + "valid_sequence_lengths": kwargs.get("valid_sequence_lengths"), + "cache_idx": kwargs.get("cache_idx"), + "lazy_mode": kwargs.get("lazy_mode"), + "num_virtual_tokens": kwargs.get("num_virtual_tokens"), + } + ) + return model_inputs From 92166a0c8cfe159563ce3684cc1ed2ac58a4762e Mon Sep 17 00:00:00 2001 From: tianyuan211 Date: Wed, 28 May 2025 17:24:50 +0800 Subject: [PATCH 05/14] Update test_text_generation_example.json --- .../fixture/tests/test_text_generation_example.json | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/baselines/fixture/tests/test_text_generation_example.json b/tests/baselines/fixture/tests/test_text_generation_example.json index 299bf50686..ab58ce5405 100644 --- 
a/tests/baselines/fixture/tests/test_text_generation_example.json +++ b/tests/baselines/fixture/tests/test_text_generation_example.json @@ -99,6 +99,11 @@ "throughput": 633.0694674407139 } }, + "tests/test_text_generation_example.py::test_text_generation_bf16_1x[Qwen/Qwen3-8B-1-False-False]": { + "gaudi2": { + "throughput": 123.06282996640333 + } + }, "tests/test_text_generation_example.py::test_text_generation_bf16_1x[Salesforce/codegen2-1B-1-False-False]": { "gaudi1": { "throughput": 155.32071248826423 From 44fd14d15580f3d1d86936e59be633d8085aa725 Mon Sep 17 00:00:00 2001 From: tianyuan211 Date: Thu, 29 May 2025 17:36:11 +0800 Subject: [PATCH 06/14] Update utils.py --- optimum/habana/transformers/generation/utils.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py index 1a472cd9a9..f7baebde47 100755 --- a/optimum/habana/transformers/generation/utils.py +++ b/optimum/habana/transformers/generation/utils.py @@ -132,6 +132,7 @@ "chatglm", "qwen2_vl", "qwen3", + "qwen3_moe", ] # Initial generated token index is set to 1 to accomodate SOS (start of string) token. @@ -1305,8 +1306,9 @@ "deepseek_v2", "deepseek_v3", "qwen3", + "qwen3_moe", ], ( - "reuse_cache only supported by llama, mistral, falcon, mixtral, phi, qwen2, qwen2_moe, qwen3, gemma, gemma2, starcoder2, baichuan, chatglm and deepseek_v2 at the moment" + "reuse_cache only supported by llama, mistral, falcon, mixtral, phi, qwen2, qwen2_moe, qwen3, qwen3_moe, gemma, gemma2, starcoder2, baichuan, chatglm and deepseek_v2 at the moment" ) if not generation_config.bucket_internal: assert generation_config.bucket_size <= 0, ( @@ -1520,6 +1522,8 @@ "qwen2_moe", "baichuan", "deepseek_v2", + "qwen3", + "qwen3_moe", ]: if ( hasattr(self.config, "max_position_embeddings") From 73b835f7fdda83c0c3f112cb140082cb5a119a14 Mon Sep 17 00:00:00 2001 From: tianyuan211 Date: Tue, 10 Jun 2025 13:29:38 +0800 Subject: [PATCH 07/14] update qwen3 moe --- .../transformers/models/qwen3_moe/__init__.py | 2 +- .../models/qwen3_moe/modeling_qwen3_moe.py | 466 +++++++++--------- 2 files changed, 245 insertions(+), 223 deletions(-) mode change 100755 => 100644 optimum/habana/transformers/models/qwen3_moe/modeling_qwen3_moe.py diff --git a/optimum/habana/transformers/models/qwen3_moe/__init__.py b/optimum/habana/transformers/models/qwen3_moe/__init__.py index 208c4a5b94..c45313c231 100644 --- a/optimum/habana/transformers/models/qwen3_moe/__init__.py +++ b/optimum/habana/transformers/models/qwen3_moe/__init__.py @@ -4,6 +4,6 @@ GaudiQwen3MoeForCausalLM, GaudiQwen3MoeMLP, GaudiQwen3MoeModel, - gaudi_qwen3moe_block_sparse_moe_forward, + GaudiQwen3MoeSparseMoeBlock, gaudi_qwen3moe_rmsnorm_forward, ) diff --git a/optimum/habana/transformers/models/qwen3_moe/modeling_qwen3_moe.py b/optimum/habana/transformers/models/qwen3_moe/modeling_qwen3_moe.py old mode 100755 new mode 100644 index f2325e5a80..979b73d7be --- a/optimum/habana/transformers/models/qwen3_moe/modeling_qwen3_moe.py +++ b/optimum/habana/transformers/models/qwen3_moe/modeling_qwen3_moe.py @@ -19,16 +19,24 @@ # limitations under the License. 
"""PyTorch Qwen3MoE model.""" -import math -import warnings +from functools import partial from typing import List, Optional, Tuple, Union -import habana_frameworks.torch.core as htcore import torch import torch.nn.functional as F +from torch import nn + from transformers.cache_utils import Cache, DynamicCache, StaticCache from transformers.integrations.deepspeed import is_deepspeed_available -from transformers.modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + MoeCausalLMOutputWithPast, + MoeModelOutputWithPast, + QuestionAnsweringModelOutput, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) from transformers.models.qwen3_moe.configuration_qwen3_moe import Qwen3MoeConfig from transformers.models.qwen3_moe.modeling_qwen3_moe import ( Qwen3MoeAttention, @@ -40,13 +48,16 @@ Qwen3MoeSparseMoeBlock, apply_rotary_pos_emb, load_balancing_loss_func, + logger, ) -from transformers.utils import logging +from transformers.processing_utils import Unpack +from ....distributed import parallel_state from ...modeling_attn_mask_utils import ( _gaudi_prepare_4d_causal_attention_mask, ) from ...modeling_rope_utils import GaudiRotaryEmbedding +from ..modeling_all_models import KVCache, Matmul, apply_customized_rope_module try: @@ -71,37 +82,18 @@ print("Not using HPU fused scaled dot-product attention kernel.") FusedSDPA = None -logger = logging.get_logger(__name__) +import habana_frameworks.torch.core as htcore -def apply_customized_rope(q, k, cos, sin, position_ids): +def apply_customized_rope(q, k, cos, sin, position_ids, training=True): if q.device.type == "hpu" and has_fused_rope: - # TODO: remove `.clone()` when it is fixed in SynapseAI - if k.dtype == torch.bfloat16: - return FusedRoPE.apply( - q, cos.unsqueeze(0).unsqueeze(0).clone(), sin.unsqueeze(0).unsqueeze(0).clone(), position_ids - ), FusedRoPE.apply( - k, - cos.unsqueeze(0).unsqueeze(0).clone().to(torch.bfloat16), - sin.unsqueeze(0).unsqueeze(0).clone().to(torch.bfloat16), - position_ids, - ) - return FusedRoPE.apply( - q, cos.unsqueeze(0).unsqueeze(0).clone(), sin.unsqueeze(0).unsqueeze(0).clone(), position_ids - ), FusedRoPE.apply( - k, cos.unsqueeze(0).unsqueeze(0).clone(), sin.unsqueeze(0).unsqueeze(0).clone(), position_ids - ) + return apply_customized_rope_module(q, k, cos, sin, position_ids, training) else: # keep the same implementation as Transformers v4.37.2 return apply_rotary_pos_emb(q, k, cos[position_ids], sin[position_ids]) def gaudi_qwen3moe_rmsnorm_forward(self, hidden_states): - """ - Copied from MixtralRMSNorm.forward: https://github.com/huggingface/transformers/blob/v4.37.0/src/transformers/models/mixtral/modeling_mixtral.py - The only differences are: - - override RMSNorm with Habana fused RMSNorm - """ if hidden_states.device.type == "hpu" and has_fused_rms_norm: # mixed dtypes are not good for FusedRMSNorm, both inputs need to have same dtype if hidden_states.dtype != self.weight.dtype: @@ -142,14 +134,6 @@ def gaudi_qwen3moe_repeat_kv( attention_mask: torch.Tensor, n_rep: int, ): - """ - Copied from repeat_kv: https://github.com/huggingface/transformers/blob/v4.37.0/src/transformers/models/mixtral/modeling_mixtral.py - The only differences are: - - Append num_key_value_heads == 1 check as kv states can be broadcasted during matmuls so need to expand and reshape them. - - Add new args query_states, key_states, value_states and attention_mask and update the logic for expansion. 
- The query states go from (batch, num_heads, seqlen, head_dim) to (batch, num_key_value_heads, n_rep, seqlen, head_dim) - The key/value states go from (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_key_value_heads, 1, seqlen, head_dim) - """ batch, num_key_value_heads, kv_len, head_dim = key_states.shape if n_rep == 1 or num_key_value_heads == 1: return query_states, key_states, value_states, attention_mask @@ -169,55 +153,6 @@ def gaudi_qwen3moe_repeat_kv( return query_states, key_states, value_states, attention_mask -class Matmul(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x, y): - return torch.matmul(x, y) - - -class KVCache(torch.nn.Module): - def __init__(self): - super(KVCache, self).__init__() - self.cache = None - self.inp_seq_len = -1 - - def allocate(self, inp_seq_len, dtype, device, shape): - if self.cache is None or self.cache.shape != shape: - self.inp_seq_len = inp_seq_len - self.cache = torch.zeros(shape, dtype=dtype, device=device) - else: - assert self.inp_seq_len == inp_seq_len, ( - f"inp_seq_len must be the same. self.inp_seq_len:{self.inp_seq_len} inp_seq_len:{inp_seq_len}" - ) - self.cache.fill_(0) - - @staticmethod - def update(prev, cur, dim, idx, inp_seq_len): - orig_cur = cur - if prev.shape == cur.shape: - prev.copy_(cur) - return orig_cur - if idx is not None and cur.shape[2] > 1 and cur.shape[2] <= prev.shape[2]: - # Initialize - prev[:, :, :inp_seq_len, :].copy_(cur) - return orig_cur - if idx is not None: - prev.index_copy_(dim, idx - 1, cur) - return prev - else: - return torch.cat((prev, cur), dim=dim) - - def get_shape(self): - if self.cache is None: - return None - return self.cache.shape - - def forward(self, cur, dim, idx): - return self.update(self.cache, cur, dim, idx, self.inp_seq_len) - - # FusedScaledDotProductAttention class ModuleFusedSDPA(torch.nn.Module): def __init__(self, fusedSDPA, scale, attention_dropout, enable_recompute, flash_attention_fp8): @@ -257,6 +192,109 @@ def forward( ) +def gaudi_eager_attention_forward( + module: torch.nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + attn_softmax_bf16: bool = False, + **kwargs, +): + bsz, q_len = kwargs["input_shape"] + query_states, key_states, value_states, attention_mask = gaudi_qwen3moe_repeat_kv( + query, key, value, attention_mask, module.num_key_value_groups + ) + + query_states = query_states * scaling + attn_weights = module.matmul_qk(query_states, key_states.transpose(-2, -1)).float() + htcore.mark_step() + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + if attn_softmax_bf16: + attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=query_states.dtype) + else: + # upcast attention to fp32 + attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = torch.nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = module.matmul_av(attn_weights, value_states) + attn_output = attn_output.reshape(bsz, -1, q_len, module.head_dim) + + return attn_output, attn_weights + + +class GaudiDistributedAttention(torch.nn.Module): + def __init__( + self, hpu_module_fsdpa: ModuleFusedSDPA, scale, attention_dropout, enable_recompute, flash_attention_fp8 + ): + super().__init__() + self._hpu_module_fsdpa = hpu_module_fsdpa + if 
parallel_state.sequence_parallel_is_initialized() and parallel_state.get_sequence_parallel_world_size() > 1: + from deepspeed.sequence.layer import DistributedAttention + + self._hpu_module_fsdpa_distributed = DistributedAttention( + self._hpu_module_fsdpa, parallel_state.get_sequence_parallel_group(), 1, 2 + ) + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: torch.Tensor, + dropout_p: float, + is_casual, + scale, + softmax_mode, + recompute_mode, + valid_sequence_lengths, + padding_side="left", + ): + if parallel_state.sequence_parallel_is_initialized() and parallel_state.get_sequence_parallel_world_size() > 1: + return self._hpu_module_fsdpa_distributed( + query, + key, + value, + 0, # As the shape for inputs is [B, N, S, H] + None, + attn_mask, + dropout_p, + is_casual, + scale, + softmax_mode, + recompute_mode, + valid_sequence_lengths, + padding_side, + ) + else: + return self._hpu_module_fsdpa( + query, + key, + value, + attn_mask, + dropout_p, + is_casual, + scale, + softmax_mode, + recompute_mode, + valid_sequence_lengths, + padding_side, + ) + + +def get_gaudi_distributed_attention( + fused_scaled_dot_product_attention, fused_scaled_dot_product_attention_distributed +): + if parallel_state.sequence_parallel_is_initialized() and parallel_state.get_sequence_parallel_world_size() > 1: + return fused_scaled_dot_product_attention_distributed + else: + return fused_scaled_dot_product_attention + + class GaudiQwen3MoeAttention(Qwen3MoeAttention): def __init__(self, config: Qwen3MoeConfig, layer_idx: Optional[int] = None): super().__init__(config, layer_idx) @@ -267,14 +305,13 @@ def __init__(self, config: Qwen3MoeConfig, layer_idx: Optional[int] = None): self.v_cache = KVCache() self.inp_seq_len = -1 - self.norm_factor = 1.0 / math.sqrt(self.head_dim) self.rotary_emb = GaudiRotaryEmbedding(config=self.config) self.fused_scaled_dot_product_attention = ( ModuleFusedSDPA( FusedSDPA, - scale=self.norm_factor, + scale=self.scaling, attention_dropout=self.attention_dropout, enable_recompute=False, flash_attention_fp8=getattr(config, "flash_attention_fp8", False), @@ -282,10 +319,38 @@ def __init__(self, config: Qwen3MoeConfig, layer_idx: Optional[int] = None): if FusedSDPA else None ) + self.fused_scaled_dot_product_attention_distributed = None + if parallel_state.sequence_parallel_is_initialized() and parallel_state.get_sequence_parallel_world_size() > 1: + self.fused_scaled_dot_product_attention_distributed = ( + GaudiDistributedAttention( + self.fused_scaled_dot_product_attention, + scale=self.scaling, + attention_dropout=self.attention_dropout, + enable_recompute=False, + flash_attention_fp8=getattr(config, "flash_attention_fp8", False), + ) + if FusedSDPA + else None + ) + + self.num_key_value_heads = config.num_key_value_heads + + def get_k_proj_weight(self): + """4bit quantization in GPTQ replaces the k_proj.weight with qweight.""" + if hasattr(self.k_proj, "qweight"): + return self.k_proj.qweight + return self.k_proj.weight + + def get_k_proj_weight_dtype(self): + """4bit quantization in GPTQ replaces the k_proj.weight with qweight. 
+ Scales tensor gets the weight dtype.""" + if hasattr(self.k_proj, "qweight"): + return self.k_proj.scales.dtype + return self.k_proj.weight.dtype def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): cache_shape = (batch_size, self.num_key_value_heads, max_seq_len, self.head_dim) - device = self.k_proj.weight.device + device = self.get_k_proj_weight().device dtype = self.config.torch_dtype self.k_cache.allocate(inp_seq_len, dtype, device, cache_shape) self.v_cache.allocate(inp_seq_len, dtype, device, cache_shape) @@ -296,7 +361,7 @@ def update_sincos_cache(self, seq_len): # reduce memory consumption and improve performance. if seq_len > self.max_position_embeddings: self.max_position_embeddings = seq_len - _, _ = self.rotary_emb(self.k_proj.weight, seq_len=seq_len) + _, _ = self.rotary_emb(self.get_k_proj_weight(), seq_len=seq_len) def reorder(self, tensor, beam_idx, dim_a, dim_b): updated = tensor.index_select(0, beam_idx) @@ -315,13 +380,11 @@ def reorder_kv_cache(self, beam_idx: torch.LongTensor): def pre_attn_forward( self, hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], past_key_value: Optional[Cache] = None, - output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC token_idx: Optional[torch.Tensor] = None, attn_softmax_bf16: Optional[bool] = False, reuse_cache: Optional[bool] = False, @@ -335,7 +398,6 @@ def pre_attn_forward( **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """ - Copied from LlamaAttention.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py The only differences are: - add new args token_idx - optimize KV cache @@ -347,15 +409,13 @@ def pre_attn_forward( - add new arg flash_attention_fast_softmax - add new arg num_virtual_tokens """ - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_norm(self.q_proj(hidden_states)) - key_states = self.k_norm(self.k_proj(hidden_states)) - value_states = self.v_proj(hidden_states) + input_shape = hidden_states.shape[:-1] + q_len = input_shape[1] + hidden_shape = (*input_shape, -1, self.head_dim) - query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2) + key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) kv_seq_len = key_states.shape[-2] if past_key_value is not None: @@ -373,8 +433,25 @@ def pre_attn_forward( else: kv_seq_len = past_key_value[0].shape[-2] + seq_len = kv_seq_len + if parallel_state.sequence_parallel_is_initialized(): + seq_len = kv_seq_len * parallel_state.get_sequence_parallel_world_size() + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_customized_rope(query_states, key_states, cos, sin, position_ids) + # If sequence parallel in enabled, position_ids should be based on which part of the sequence is present in the rank 
+ # As we divide the inputs based on ranks, position_ids are generated to suit that part of the sequence + if parallel_state.sequence_parallel_is_initialized() and parallel_state.get_sequence_parallel_rank() > 0: + position_ids = torch.arange( + kv_seq_len * parallel_state.get_sequence_parallel_rank(), + kv_seq_len * (parallel_state.get_sequence_parallel_rank() + 1), + dtype=torch.long, + device=query_states.device, + ) + position_ids = position_ids.unsqueeze(0) + + query_states, key_states = apply_customized_rope( + query_states, key_states, cos, sin, kwargs["position_ids"], self.training + ) if use_cache: # reuse k, v, self_attention @@ -418,8 +495,19 @@ def pre_attn_forward( kv_seq_len = key_states.shape[-2] else: past_key_value = None + fused_scaled_dot_product_attention = get_gaudi_distributed_attention( + self.fused_scaled_dot_product_attention, self.fused_scaled_dot_product_attention_distributed + ) + sliding_window = None + if ( + self.config.use_sliding_window + and getattr(self.config, "sliding_window", None) is not None + and self.layer_idx >= self.config.max_window_layers + ): + sliding_window = self.config.sliding_window if use_flash_attention and FusedSDPA is not None: + attn_weights = None if q_len == 1: # next token attn_output = self.fused_scaled_dot_product_attention( @@ -468,46 +556,23 @@ def pre_attn_forward( ) else: - query_states, key_states, value_states, attention_mask = gaudi_qwen3moe_repeat_kv( - query_states, key_states, value_states, attention_mask, self.num_key_value_groups - ) - - query_states = query_states * self.norm_factor - attn_weights = self.matmul_qk(query_states, key_states.transpose(-2, -1)).float() - htcore.mark_step() - - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask - if cache_position is not None: - causal_mask = attention_mask[:, :, cache_position, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask.float() - - if attn_softmax_bf16: - attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=query_states.dtype) - else: - # upcast attention to fp32 - attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to( - query_states.dtype - ) - attn_weights = torch.nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) - attn_output = self.matmul_av(attn_weights, value_states) - attn_output = attn_output.reshape(bsz, -1, q_len, self.head_dim) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" + attn_output, attn_weights = gaudi_eager_attention_forward( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + sliding_window=sliding_window, # main diff with Llama + attn_softmax_bf16=attn_softmax_bf16, + input_shape=input_shape, ) attn_output = attn_output.transpose(1, 2).contiguous() - - attn_output = attn_output.reshape(bsz, q_len, -1) - + attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.o_proj(attn_output) - if not output_attentions: - attn_weights = None - if not reuse_cache and token_idx is not None and cache_idx is not None and q_len == 1: # Return only past key value shapes and not the tensors during decode phase (q len is 1) # to avoid making past key values as persistent output tensors of HPU graphs. 
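Note on the decode path above: the cache write goes through the KVCache helper imported from modeling_all_models (the local copy is removed earlier in this diff), which keeps the cache tensor at a fixed shape and writes the new token at token_idx - 1 instead of concatenating. A condensed, standalone sketch of that update rule is shown below; the function name, shapes, and token_idx values are illustrative only.

import torch

def kv_cache_update(prev, cur, dim, idx, inp_seq_len):
    # Condensed form of KVCache.update: static-shape cache writes suited to HPU graphs.
    if prev.shape == cur.shape:
        prev.copy_(cur)                           # update covers the whole cache window
        return cur
    if idx is not None and cur.shape[2] > 1 and cur.shape[2] <= prev.shape[2]:
        prev[:, :, :inp_seq_len, :].copy_(cur)    # prefill: fill the first inp_seq_len slots
        return cur
    if idx is not None:
        prev.index_copy_(dim, idx - 1, cur)       # decode: write one token at token_idx - 1
        return prev
    return torch.cat((prev, cur), dim=dim)        # no token_idx: dynamic concat, not graph friendly

# Illustrative shapes: batch=2, kv_heads=4, max_seq_len=16, head_dim=8, 5-token prompt.
cache = torch.zeros(2, 4, 16, 8)
kv_cache_update(cache, torch.randn(2, 4, 5, 8), dim=2, idx=torch.tensor([5]), inp_seq_len=5)  # prefill
kv_cache_update(cache, torch.randn(2, 4, 1, 8), dim=2, idx=torch.tensor([6]), inp_seq_len=5)  # decode step

Because the cache never changes shape during decode, the same compiled HPU graph can be reused across generation steps, which is what the reuse_cache and bucketing options rely on.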
@@ -525,47 +590,38 @@ def post_attn_forward(self, attn_output): return attn_output -def gaudi_qwen3moe_block_sparse_moe_forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - """ - - optimize expert forward, remove dynamic control and dynamic shape - """ - batch_size, sequence_length, hidden_dim = hidden_states.shape - hidden_states = hidden_states.view(-1, hidden_dim) +class GaudiQwen3MoeSparseMoeBlock(Qwen3MoeSparseMoeBlock): + def __init__(self, config: Qwen3MoeConfig): + super().__init__(config) + self.moe_intermediate_size = config.moe_intermediate_size + self.num_experts = config.num_experts + self.top_k = config.num_experts_per_tok + self.norm_topk_prob = config.norm_topk_prob + + # gating + self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False) + self.experts = nn.ModuleList( + [Qwen3MoeMLP(config, intermediate_size=config.moe_intermediate_size) for _ in range(self.num_experts)] + ) - # router_logits: (batch * sequence_length, n_experts) - router_logits = self.gate(hidden_states) + def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + - optimize expert forward, remove dynamic control and dynamic shape + """ + batch_size, sequence_length, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) - routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) - routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) + # router_logits: (batch * sequence_length, n_experts) + router_logits = self.gate(hidden_states) - if self.norm_topk_prob: - routing_weights /= routing_weights.sum(dim=-1, keepdim=True) - # we cast back to the input dtype - routing_weights = routing_weights.to(hidden_states.dtype) + routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) + routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) - if self.training: - final_hidden_states = torch.zeros( - (batch_size, sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device - ) + if self.norm_topk_prob: + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + # we cast back to the input dtype + routing_weights = routing_weights.to(hidden_states.dtype) - padded_weights = torch.zeros( - (batch_size * sequence_length, self.num_experts), dtype=hidden_states.dtype, device=hidden_states.device - ) - padded_weights.scatter_(-1, selected_experts, routing_weights) - padded_weights = padded_weights.reshape(-1, sequence_length, self.num_experts) - padded_weights = padded_weights.permute(2, 0, 1).unsqueeze(-1) - - # Loop over all available experts in the model and perform the computation on each expert - for expert_idx in range(self.num_experts): - expert_layer = self.experts[expert_idx] - padded_weight = padded_weights[expert_idx] - current_state_static = hidden_states.reshape(-1, hidden_dim) - current_hidden_states_static = ( - expert_layer.pre_mlp_forward(current_state_static).reshape(-1, sequence_length, hidden_dim) - * padded_weight - ) - final_hidden_states = final_hidden_states + current_hidden_states_static - else: experts_range = range(self.num_experts) w1_list = [self.experts[i].gate_proj.weight.squeeze() for i in experts_range] w2_list = [self.experts[i].down_proj.weight.squeeze() for i in experts_range] @@ -576,30 +632,23 @@ def gaudi_qwen3moe_block_sparse_moe_forward(self, hidden_states: torch.Tensor) - expert_routing_table=selected_experts, router_weights=routing_weights, w1=w1_list, - w2=w3_list, # 
Note that there is a different naming convention of w1, w2, and w3 between optimum habana's mixtral model and dynamic MoE kernel. + w2=w3_list, w3=w2_list, permuted_weights=True, activation="silu", experts_min=0, experts_max=(self.num_experts - 1), ) - final_hidden_states = final_hidden_states.reshape(-1, sequence_length, hidden_dim) + htcore.mark_step() - if is_deepspeed_available(): + if not self.training and is_deepspeed_available() and self.moe_intermediate_size != w1_list[0].size(0): from deepspeed import comm as dist if dist.is_initialized(): dist.all_reduce(final_hidden_states, op=dist.ReduceOp.SUM) - shared_expert_output = self.shared_expert(hidden_states) - - shared_expert_output = F.sigmoid(self.shared_expert_gate(hidden_states)) * shared_expert_output - - shared_expert_output = shared_expert_output.reshape(-1, sequence_length, hidden_dim) - - final_hidden_states = final_hidden_states + shared_expert_output - - return final_hidden_states, router_logits + final_hidden_states = final_hidden_states.reshape(-1, sequence_length, hidden_dim) + return final_hidden_states, router_logits class GaudiQwen3MoeDecoderLayer(Qwen3MoeDecoderLayer): @@ -610,7 +659,7 @@ def __init__(self, config: Qwen3MoeConfig, layer_idx: int): self.self_attn = GaudiQwen3MoeAttention(config=config, layer_idx=layer_idx) if config.num_experts > 0 and (layer_idx + 1) % config.decoder_sparse_step == 0: - self.mlp = Qwen3MoeSparseMoeBlock(config) + self.mlp = GaudiQwen3MoeSparseMoeBlock(config) else: self.mlp = GaudiQwen3MoeMLP(config, intermediate_size=config.intermediate_size) @@ -649,17 +698,6 @@ def forward( num_virtual_tokens: int = None, **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - """ - Copied from LlamaDecoderLayer.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py - The only differences are: - - add new args token_idx - - add new args attn_softmax_bf16 - - add new args reuse_cache - - add new args use_flash_attention - - add new arg flash_attention_recompute - - add new arg flash_attention_causal_mask - - add new arg flash_attention_fast_softmax - """ if "padding_mask" in kwargs: warnings.warn( "Passing `padding_mask` is deprecated and will be removed in v4.37. 
Please make sure use `attention_mask` instead.`" @@ -673,6 +711,7 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, + position_embeddings=position_embeddings, token_idx=token_idx, attn_softmax_bf16=attn_softmax_bf16, reuse_cache=reuse_cache, @@ -709,7 +748,7 @@ def pre_attn( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, token_idx: Optional[torch.Tensor] = None, attn_softmax_bf16: Optional[bool] = False, reuse_cache: Optional[bool] = False, @@ -806,7 +845,7 @@ def update_sincos_cache(self, seq_len): def forward( self, - input_ids: torch.LongTensor = None, + input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, @@ -815,7 +854,6 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, output_router_logits: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, token_idx: Optional[torch.Tensor] = None, attn_softmax_bf16: Optional[bool] = False, @@ -828,7 +866,7 @@ def forward( cache_idx: int = None, lazy_mode: Optional[bool] = True, num_virtual_tokens: int = None, - ) -> Union[Tuple, MoeModelOutputWithPast]: + ) -> MoeModelOutputWithPast: """ Copied from LlamaModel.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py The only differences are: @@ -849,7 +887,6 @@ def forward( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") @@ -998,12 +1035,6 @@ def forward( next_cache = ( next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache ) - if not return_dict: - return tuple( - v - for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits] - if v is not None - ) return MoeModelOutputWithPast( last_hidden_state=hidden_states, @@ -1036,7 +1067,7 @@ def update_sincos_cache(self, seq_len): def forward( self, - input_ids: torch.LongTensor = None, + input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, @@ -1046,7 +1077,6 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, output_router_logits: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, token_idx: Optional[torch.Tensor] = None, @@ -1062,7 +1092,7 @@ def forward( lazy_mode: Optional[bool] = True, num_virtual_tokens: int = None, **loss_kwargs, - ) -> Union[Tuple, MoeCausalLMOutputWithPast]: + ) -> MoeCausalLMOutputWithPast: output_attentions = output_attentions if output_attentions is not None else 
self.config.output_attentions output_router_logits = ( output_router_logits if output_router_logits is not None else self.config.output_router_logits @@ -1071,13 +1101,12 @@ def forward( output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict if self.generation_config.use_fused_rope is False: global has_fused_rope has_fused_rope = False # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( + outputs: MoeModelOutputWithPast = self.model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -1087,7 +1116,6 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, output_router_logits=output_router_logits, - return_dict=return_dict, cache_position=cache_position, token_idx=token_idx, attn_softmax_bf16=attn_softmax_bf16, @@ -1102,7 +1130,7 @@ def forward( num_virtual_tokens=num_virtual_tokens, ) - hidden_states = outputs[0] + hidden_states = outputs.last_hidden_state _, seq_len, _ = hidden_states.shape if seq_len > 1 and trim_logits and not self.training: if token_idx is not None: @@ -1121,7 +1149,7 @@ def forward( aux_loss = None if output_router_logits: aux_loss = load_balancing_loss_func( - outputs.router_logits if return_dict else outputs[-1], + outputs.router_logits, self.num_experts, self.num_experts_per_tok, attention_mask, @@ -1129,12 +1157,6 @@ def forward( if labels is not None: loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device - if not return_dict: - output = (logits,) + outputs[1:] - if output_router_logits: - output = (aux_loss,) + output - return (loss,) + output if loss is not None else output - return MoeCausalLMOutputWithPast( loss=loss, aux_loss=aux_loss, From c7e5c9854de70166e7115afd41eb35cd7f0e803d Mon Sep 17 00:00:00 2001 From: tianyuan211 Date: Mon, 30 Jun 2025 11:07:00 +0800 Subject: [PATCH 08/14] update qwen3moe related files --- optimum/habana/transformers/modeling_utils.py | 16 ++++++++++++++++ optimum/habana/transformers/models/__init__.py | 9 +++++++++ 2 files changed, 25 insertions(+) diff --git a/optimum/habana/transformers/modeling_utils.py b/optimum/habana/transformers/modeling_utils.py index de638ee4b2..6e5ea672de 100644 --- a/optimum/habana/transformers/modeling_utils.py +++ b/optimum/habana/transformers/modeling_utils.py @@ -170,6 +170,12 @@ GaudiQwen3ForCausalLM, GaudiQwen3MLP, GaudiQwen3Model, + GaudiQwen3MoeAttention, + GaudiQwen3MoeDecoderLayer, + GaudiQwen3MoeForCausalLM, + GaudiQwen3MoeMLP, + GaudiQwen3MoeModel, + GaudiQwen3MoeSparseMoeBlock, GaudiSiglipAttention, GaudiSiglipEncoder, GaudiSiglipEncoderLayer, @@ -269,6 +275,7 @@ gaudi_qwen2moe_block_sparse_moe_forward, gaudi_qwen2moe_rmsnorm_forward, gaudi_qwen3_rmsnorm_forward, + gaudi_qwen3moe_rmsnorm_forward, gaudi_rot_matmul, gaudi_rot_vec_mul, gaudi_SeamlessM4TAttention_forward, @@ -721,6 +728,15 @@ def adapt_transformers_to_gaudi(): transformers.models.qwen3.modeling_qwen3.Qwen3DecoderLayer = GaudiQwen3DecoderLayer transformers.models.qwen3.modeling_qwen3.Qwen3RMSNorm.forward = gaudi_qwen3_rmsnorm_forward + # Optimization for qwen3Moe on Gaudi + transformers.models.qwen3_moe.modeling_qwen3_moe.Qwen3MoeForCausalLM = GaudiQwen3MoeForCausalLM + transformers.models.qwen3_moe.modeling_qwen3_moe.Qwen3MoeModel = GaudiQwen3MoeModel + 
transformers.models.qwen3_moe.modeling_qwen3_moe.Qwen3MoeAttention = GaudiQwen3MoeAttention + transformers.models.qwen3_moe.modeling_qwen3_moe.Qwen3MoeMLP = GaudiQwen3MoeMLP + transformers.models.qwen3_moe.modeling_qwen3_moe.Qwen3MoeDecoderLayer = GaudiQwen3MoeDecoderLayer + transformers.models.qwen3_moe.modeling_qwen3_moe.Qwen3MoeSparseMoeBlock = GaudiQwen3MoeSparseMoeBlock + transformers.models.qwen3_moe.modeling_qwen3_moe.Qwen3MoeRMSNorm.forward = gaudi_qwen3moe_rmsnorm_forward + # Optimization for stablelm on Gaudi transformers.models.stablelm.modeling_stablelm.StableLmAttention = GaudiStableLmAttention transformers.models.stablelm.modeling_stablelm.StableLmDecoderLayer = GaudiStableLmDecoderLayer diff --git a/optimum/habana/transformers/models/__init__.py b/optimum/habana/transformers/models/__init__.py index 3222667990..23b029d29a 100644 --- a/optimum/habana/transformers/models/__init__.py +++ b/optimum/habana/transformers/models/__init__.py @@ -279,6 +279,15 @@ GaudiQwen3Model, gaudi_qwen3_rmsnorm_forward, ) +from .qwen3_moe import ( + GaudiQwen3MoeAttention, + GaudiQwen3MoeDecoderLayer, + GaudiQwen3MoeForCausalLM, + GaudiQwen3MoeMLP, + GaudiQwen3MoeModel, + GaudiQwen3MoeSparseMoeBlock, + gaudi_qwen3moe_rmsnorm_forward, +) from .seamless_m4t import ( gaudi_SeamlessM4TAttention_forward, gaudi_SeamlessM4TCodeHifiGan_get_output_hifigan_lengths, From c29f66fa5047ea4f6d343d78f841dd5be1b4981b Mon Sep 17 00:00:00 2001 From: tianyuan211 Date: Mon, 30 Jun 2025 12:22:50 +0800 Subject: [PATCH 09/14] update --- .../baselines/fixture/tests/test_text_generation_example.json | 4 ++-- tests/test_text_generation_example.py | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/baselines/fixture/tests/test_text_generation_example.json b/tests/baselines/fixture/tests/test_text_generation_example.json index 9485ef4cde..0ca9b03b3f 100644 --- a/tests/baselines/fixture/tests/test_text_generation_example.json +++ b/tests/baselines/fixture/tests/test_text_generation_example.json @@ -101,8 +101,8 @@ }, "tests/test_text_generation_example.py::test_text_generation_bf16_1x[Qwen/Qwen3-8B-1-False-False]": { "gaudi2": { - "throughput": 123.06282996640333 - }, + "throughput": 101.78595453711921 + } }, "tests/test_text_generation_example.py::test_text_generation_bf16_1x[Salesforce/codegen2-1B-1-False-False]": { "gaudi1": { diff --git a/tests/test_text_generation_example.py b/tests/test_text_generation_example.py index 52aeb94066..6537ebbb23 100644 --- a/tests/test_text_generation_example.py +++ b/tests/test_text_generation_example.py @@ -63,6 +63,7 @@ ("Qwen/Qwen2.5-7B", 4, False, False), ("moonshotai/Moonlight-16B-A3B", 1, False, False), ("Qwen/Qwen3-8B", 1, False, False), + ("Qwen/Qwen3-30B-A3B", 1, False, False), ], "fp8": [ pytest.param("tiiuae/falcon-180B", 4, 950, True, 128, 128, marks=pytest.mark.x4), From 576c5e9a344125ed72601660f9d8153daf8bd0e7 Mon Sep 17 00:00:00 2001 From: tianyuan211 Date: Mon, 30 Jun 2025 12:34:23 +0800 Subject: [PATCH 10/14] Update test_text_generation_example.json --- .../fixture/tests/test_text_generation_example.json | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/baselines/fixture/tests/test_text_generation_example.json b/tests/baselines/fixture/tests/test_text_generation_example.json index 0ca9b03b3f..fefcb33b8f 100644 --- a/tests/baselines/fixture/tests/test_text_generation_example.json +++ b/tests/baselines/fixture/tests/test_text_generation_example.json @@ -104,6 +104,11 @@ "throughput": 101.78595453711921 } }, + 
"tests/test_text_generation_example.py::test_text_generation_bf16_1x[Qwen/Qwen3-30B-A3B-1-False-False]": { + "gaudi2": { + "throughput": 23.27712445319976 + } + }, "tests/test_text_generation_example.py::test_text_generation_bf16_1x[Salesforce/codegen2-1B-1-False-False]": { "gaudi1": { "throughput": 155.32071248826423 From c6befe451233f1c36ab29e43d1330cbfe4e5c7ab Mon Sep 17 00:00:00 2001 From: tianyuan211 Date: Tue, 1 Jul 2025 20:18:34 +0800 Subject: [PATCH 11/14] Update modeling_utils.py --- optimum/habana/transformers/modeling_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/optimum/habana/transformers/modeling_utils.py b/optimum/habana/transformers/modeling_utils.py index 6e5ea672de..529d07947d 100644 --- a/optimum/habana/transformers/modeling_utils.py +++ b/optimum/habana/transformers/modeling_utils.py @@ -719,7 +719,7 @@ def adapt_transformers_to_gaudi(): transformers.models.qwen2_vl.modeling_qwen2_vl.Qwen2VLForConditionalGeneration = ( GaudiQwen2VLForConditionalGeneration ) - + # Optimization for qwen3 on Gaudi transformers.models.qwen3.modeling_qwen3.Qwen3ForCausalLM = GaudiQwen3ForCausalLM transformers.models.qwen3.modeling_qwen3.Qwen3Model = GaudiQwen3Model @@ -736,7 +736,7 @@ def adapt_transformers_to_gaudi(): transformers.models.qwen3_moe.modeling_qwen3_moe.Qwen3MoeDecoderLayer = GaudiQwen3MoeDecoderLayer transformers.models.qwen3_moe.modeling_qwen3_moe.Qwen3MoeSparseMoeBlock = GaudiQwen3MoeSparseMoeBlock transformers.models.qwen3_moe.modeling_qwen3_moe.Qwen3MoeRMSNorm.forward = gaudi_qwen3moe_rmsnorm_forward - + # Optimization for stablelm on Gaudi transformers.models.stablelm.modeling_stablelm.StableLmAttention = GaudiStableLmAttention transformers.models.stablelm.modeling_stablelm.StableLmDecoderLayer = GaudiStableLmDecoderLayer From 8c27ba14c47b79779d6bc6dfd5524a5148c01e3d Mon Sep 17 00:00:00 2001 From: tianyuan211 Date: Wed, 2 Jul 2025 10:26:57 +0800 Subject: [PATCH 12/14] Update modeling_qwen3_moe.py --- .../transformers/models/qwen3_moe/modeling_qwen3_moe.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/optimum/habana/transformers/models/qwen3_moe/modeling_qwen3_moe.py b/optimum/habana/transformers/models/qwen3_moe/modeling_qwen3_moe.py index 979b73d7be..3d71391466 100644 --- a/optimum/habana/transformers/models/qwen3_moe/modeling_qwen3_moe.py +++ b/optimum/habana/transformers/models/qwen3_moe/modeling_qwen3_moe.py @@ -19,23 +19,16 @@ # limitations under the License. 
"""PyTorch Qwen3MoE model.""" -from functools import partial from typing import List, Optional, Tuple, Union import torch import torch.nn.functional as F from torch import nn - from transformers.cache_utils import Cache, DynamicCache, StaticCache from transformers.integrations.deepspeed import is_deepspeed_available from transformers.modeling_outputs import ( - BaseModelOutputWithPast, - CausalLMOutputWithPast, MoeCausalLMOutputWithPast, MoeModelOutputWithPast, - QuestionAnsweringModelOutput, - SequenceClassifierOutputWithPast, - TokenClassifierOutput, ) from transformers.models.qwen3_moe.configuration_qwen3_moe import Qwen3MoeConfig from transformers.models.qwen3_moe.modeling_qwen3_moe import ( @@ -50,7 +43,6 @@ load_balancing_loss_func, logger, ) -from transformers.processing_utils import Unpack from ....distributed import parallel_state from ...modeling_attn_mask_utils import ( From 9ced9f3683d35631344bff1120a45c7ba04ffcc2 Mon Sep 17 00:00:00 2001 From: tianyuan211 Date: Wed, 2 Jul 2025 14:14:50 +0800 Subject: [PATCH 13/14] Update modeling_qwen3_moe.py make style related changes --- .../models/qwen3_moe/modeling_qwen3_moe.py | 23 +++++++++++-------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/optimum/habana/transformers/models/qwen3_moe/modeling_qwen3_moe.py b/optimum/habana/transformers/models/qwen3_moe/modeling_qwen3_moe.py index 3d71391466..dc95b0c4d9 100644 --- a/optimum/habana/transformers/models/qwen3_moe/modeling_qwen3_moe.py +++ b/optimum/habana/transformers/models/qwen3_moe/modeling_qwen3_moe.py @@ -18,7 +18,8 @@ # See the License for the specific language governing permissions and # limitations under the License. """PyTorch Qwen3MoE model.""" - + +import warnings from typing import List, Optional, Tuple, Union import torch @@ -53,7 +54,7 @@ try: - from habana_frameworks.torch.hpex.kernels import RotaryPosEmbeddingHelperV2 as FusedRoPE + from habana_frameworks.torch.hpex.kernels import RotaryPosEmbeddingHelperV2 as FusedRoPE # noqa has_fused_rope = True except ImportError: @@ -134,7 +135,7 @@ def gaudi_qwen3moe_repeat_kv( key_states = key_states.reshape(new_kv_shape) value_states = value_states.reshape(new_kv_shape) - batch, q_heads, q_len, head_dim = query_states.shape + batch, _, q_len, head_dim = query_states.shape new_q_shape = (batch, num_key_value_heads, n_rep, q_len, head_dim) query_states = query_states.reshape(new_q_shape) @@ -145,7 +146,7 @@ def gaudi_qwen3moe_repeat_kv( return query_states, key_states, value_states, attention_mask -# FusedScaledDotProductAttention +# FusedScaledDotProductAttention class ModuleFusedSDPA(torch.nn.Module): def __init__(self, fusedSDPA, scale, attention_dropout, enable_recompute, flash_attention_fp8): super().__init__() @@ -429,7 +430,7 @@ def pre_attn_forward( if parallel_state.sequence_parallel_is_initialized(): seq_len = kv_seq_len * parallel_state.get_sequence_parallel_world_size() - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + cos, sin = self.rotary_emb(value_states, seq_len=seq_len) # If sequence parallel in enabled, position_ids should be based on which part of the sequence is present in the rank # As we divide the inputs based on ranks, position_ids are generated to suit that part of the sequence if parallel_state.sequence_parallel_is_initialized() and parallel_state.get_sequence_parallel_rank() > 0: @@ -457,9 +458,11 @@ def pre_attn_forward( past_key_value = (self.k_cache.get_shape(), self.v_cache.get_shape()) else: if past_key_value is None: - past_key = torch.zeros(key_states.shape, 
dtype=self.k_proj.weight.dtype, device=key_states.device) + past_key = torch.zeros( + key_states.shape, dtype=self.get_k_proj_weight_dtype(), device=key_states.device + ) past_value = torch.zeros( - key_states.shape, dtype=self.k_proj.weight.dtype, device=key_states.device + key_states.shape, dtype=self.get_k_proj_weight_dtype(), device=key_states.device ) # Return list instead of tuple past_key_value = [past_key, past_value] @@ -502,7 +505,7 @@ def pre_attn_forward( attn_weights = None if q_len == 1: # next token - attn_output = self.fused_scaled_dot_product_attention( + attn_output = fused_scaled_dot_product_attention( query_states, key_states, value_states, @@ -519,7 +522,7 @@ def pre_attn_forward( # first token softmax_mode = "fast" if flash_attention_fast_softmax else "None" if flash_attention_causal_mask: - attn_output = self.fused_scaled_dot_product_attention( + attn_output = fused_scaled_dot_product_attention( query_states, key_states, value_states, @@ -533,7 +536,7 @@ def pre_attn_forward( "left", ) else: - attn_output = self.fused_scaled_dot_product_attention( + attn_output = fused_scaled_dot_product_attention( query_states, key_states, value_states, From 2f1b68716579cff2cde19089d70f0a25ff9a2747 Mon Sep 17 00:00:00 2001 From: tianyuan211 Date: Wed, 2 Jul 2025 14:19:24 +0800 Subject: [PATCH 14/14] Update modeling_qwen3_moe.py --- .../transformers/models/qwen3_moe/modeling_qwen3_moe.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/optimum/habana/transformers/models/qwen3_moe/modeling_qwen3_moe.py b/optimum/habana/transformers/models/qwen3_moe/modeling_qwen3_moe.py index dc95b0c4d9..86c703c738 100644 --- a/optimum/habana/transformers/models/qwen3_moe/modeling_qwen3_moe.py +++ b/optimum/habana/transformers/models/qwen3_moe/modeling_qwen3_moe.py @@ -18,7 +18,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """PyTorch Qwen3MoE model.""" - + import warnings from typing import List, Optional, Tuple, Union @@ -458,8 +458,8 @@ def pre_attn_forward( past_key_value = (self.k_cache.get_shape(), self.v_cache.get_shape()) else: if past_key_value is None: - past_key = torch.zeros( - key_states.shape, dtype=self.get_k_proj_weight_dtype(), device=key_states.device + past_key = torch.zeros( + key_states.shape, dtype=self.get_k_proj_weight_dtype(), device=key_states.device ) past_value = torch.zeros( key_states.shape, dtype=self.get_k_proj_weight_dtype(), device=key_states.device