Clean code in mistral.py #2535

Open · wants to merge 1 commit into main
49 changes: 24 additions & 25 deletions unsloth/models/mistral.py
@@ -34,7 +34,7 @@
except:
MistralSdpaAttention = MistralAttention
MistralFlashAttention2 = MistralAttention
pass

from unsloth_zoo.utils import Version, _get_dtype


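Note: the try/except above keeps patching simple when a transformers build no longer exports the SDPA / FlashAttention2 subclasses. A minimal sketch of that fallback pattern, assuming only that MistralAttention itself is importable:

```python
try:
    # Newer transformers releases may drop these subclasses entirely.
    from transformers.models.mistral.modeling_mistral import (
        MistralAttention,
        MistralSdpaAttention,
        MistralFlashAttention2,
    )
except ImportError:
    from transformers.models.mistral.modeling_mistral import MistralAttention
    # Alias the missing subclasses to the base class so later
    # `.forward` patching only has to target one implementation.
    MistralSdpaAttention   = MistralAttention
    MistralFlashAttention2 = MistralAttention
```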
@@ -52,7 +52,7 @@ def MistralAttention_fast_forward(
*args, **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:

# Clear inference
# Clear inference-related cached attributes.
if hasattr(self, "paged_attention"):
del self.paged_attention_K
del self.paged_attention_V
@@ -61,7 +61,7 @@ def MistralAttention_fast_forward(
del self.temp_KV
del self.RH_Q
del self.attention
pass


bsz, q_len, _ = hidden_states.size()

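Note: the `del self.paged_attention_*` block above drops buffers that were cached for fast inference so they do not linger into training. A hedged sketch of the same idea with a stand-in module (the attribute names mirror the diff; the class is illustrative):

```python
import torch

class AttentionWithInferenceCache(torch.nn.Module):
    """Toy module that clears inference-only caches at the top of forward()."""

    _cache_names = ("paged_attention_K", "paged_attention_V", "paged_attention",
                    "temp_QA", "temp_KV", "RH_Q", "attention")

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Mirrors the `if hasattr(self, "paged_attention"): del ...` block above.
        if hasattr(self, "paged_attention"):
            for name in self._cache_names:
                if hasattr(self, name):
                    delattr(self, name)
        return hidden_states
```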
@@ -90,12 +90,12 @@ def MistralAttention_fast_forward(
else:
cos, sin = self.rotary_emb(V, seq_len = kv_seq_len)
Q, K = inplace_rope_embedding(Q, K, cos, sin, position_ids)
pass


if past_key_value is not None:
K = torch.cat([past_key_value[0], K], dim = 2)
V = torch.cat([past_key_value[1], V], dim = 2)
pass

past_key_value = (K, V) if use_cache else None

# Attention module
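Note: the two `torch.cat` calls above are the KV cache update: keys and values from earlier steps are appended along the sequence axis (dim 2 of `[batch, heads, seq, head_dim]`). A small self-contained check with illustrative shapes:

```python
import torch

bsz, n_heads, head_dim = 1, 8, 64
past_K = torch.randn(bsz, n_heads, 10, head_dim)  # 10 cached positions
past_V = torch.randn(bsz, n_heads, 10, head_dim)
K_new  = torch.randn(bsz, n_heads,  1, head_dim)  # 1 freshly decoded position
V_new  = torch.randn(bsz, n_heads,  1, head_dim)

K = torch.cat([past_K, K_new], dim = 2)
V = torch.cat([past_V, V_new], dim = 2)
print(K.shape, V.shape)  # torch.Size([1, 8, 11, 64]) twice
```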
@@ -122,7 +122,7 @@ def MistralAttention_fast_forward(
Q = Q.view(1, Q_M, n_heads, head_dim)
K = K.view(1, K_M, n_heads, head_dim)
V = V.view(1, V_M, n_heads, head_dim)
pass

else:
# Xformers does support the forward pass though
Q = Q.view(bsz, q_len, n_kv_heads, n_groups, head_dim)
@@ -131,8 +131,8 @@ def MistralAttention_fast_forward(
Q = Q.view(1, Q_M, n_kv_heads, n_groups, head_dim)
K = K.view(1, K_M, n_kv_heads, n_groups, head_dim)
V = V.view(1, V_M, n_kv_heads, n_groups, head_dim)
pass
pass



A = xformers_attention(Q, K, V, attn_bias = causal_mask)
A = A.view(bsz, q_len, n_heads, head_dim)
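Note: the reshapes in the two hunks above are grouped-query attention bookkeeping for xformers: each of `n_kv_heads` key/value heads serves `n_groups` query heads, so K/V get a singleton group axis that is then expanded. A hedged sketch with illustrative sizes:

```python
import torch

bsz, q_len, kv_seq_len, head_dim = 2, 16, 16, 64
n_kv_heads, n_groups = 8, 4
n_heads = n_kv_heads * n_groups

Q = torch.randn(bsz, q_len,      n_heads,    head_dim)
K = torch.randn(bsz, kv_seq_len, n_kv_heads, head_dim)
V = torch.randn(bsz, kv_seq_len, n_kv_heads, head_dim)

# Give K/V a group axis and broadcast it so every query head has a KV head.
K = K.view(bsz, kv_seq_len, n_kv_heads, 1, head_dim).expand(bsz, kv_seq_len, n_kv_heads, n_groups, head_dim)
V = V.view(bsz, kv_seq_len, n_kv_heads, 1, head_dim).expand(bsz, kv_seq_len, n_kv_heads, n_groups, head_dim)
Q = Q.view(bsz, q_len, n_kv_heads, n_groups, head_dim)
print(Q.shape, K.shape)  # both end in (n_kv_heads, n_groups, head_dim)
```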
@@ -152,7 +152,7 @@ def MistralAttention_fast_forward(
V = V[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, kv_seq_len, head_dim)
K = K.reshape(bsz, n_heads, kv_seq_len, head_dim)
V = V.reshape(bsz, n_heads, kv_seq_len, head_dim)
# pass
#
# Must be contiguous or else results are False!
# https://github.com/pytorch/pytorch/issues/112577
Q, K, V = Q.contiguous(), K.contiguous(), V.contiguous()
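Note: when neither flash-attention nor xformers handles the case, the code falls back to PyTorch SDPA; K/V are expanded to the full head count and made contiguous (the linked PyTorch issue is why). A hedged, self-contained sketch of that path, using `is_causal = True` instead of the diff's explicit mask:

```python
import torch
from torch.nn.functional import scaled_dot_product_attention

bsz, q_len, head_dim = 2, 16, 64
n_kv_heads, n_groups = 8, 4
n_heads    = n_kv_heads * n_groups
kv_seq_len = q_len

Q = torch.randn(bsz, n_heads,    q_len,      head_dim)
K = torch.randn(bsz, n_kv_heads, kv_seq_len, head_dim)
V = torch.randn(bsz, n_kv_heads, kv_seq_len, head_dim)

# Expand KV heads to match the query heads, then flatten the group axis.
K = K[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, kv_seq_len, head_dim)
V = V[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, kv_seq_len, head_dim)
K = K.reshape(bsz, n_heads, kv_seq_len, head_dim).contiguous()
V = V.reshape(bsz, n_heads, kv_seq_len, head_dim).contiguous()

A = scaled_dot_product_attention(Q.contiguous(), K, V, is_causal = True)
A = A.transpose(1, 2).contiguous()                  # back to [batch, seq, heads, head_dim]
attn_output = A.reshape(bsz, q_len, n_heads * head_dim)
print(attn_output.shape)  # torch.Size([2, 16, 2048])
```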
@@ -161,13 +161,13 @@ def MistralAttention_fast_forward(
A = scaled_dot_product_attention(Q, K, V, attn_mask = attention_mask, is_causal = False)
# Go back to (batch_size, seq_len, n_heads, head_dim)
A = A.transpose(1, 2).contiguous()
pass


attn_output = A.reshape(bsz, q_len, n_heads*head_dim)
attn_output = self.apply_o(self, attn_output)
attn_weights = None
return attn_output, attn_weights, past_key_value
pass



def MistralForCausalLM_fast_forward(
@@ -199,7 +199,7 @@ def MistralForCausalLM_fast_forward(
causal_mask = xformers.attn_bias.BlockDiagonalCausalMask\
.from_seqlens([q_len]*bsz)\
.make_local_attention(window_size = sliding_window)
pass


output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
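Note: the chained call near the top of this hunk builds Mistral's sliding-window causal bias for xformers (`xformers` in unsloth is an alias for `xformers.ops.fmha`). A minimal sketch, assuming xformers is installed, with illustrative sizes:

```python
from xformers.ops.fmha import attn_bias

bsz, q_len, sliding_window = 4, 128, 32
causal_mask = (
    attn_bias.BlockDiagonalCausalMask
    .from_seqlens([q_len] * bsz)                          # one causal block per sequence
    .make_local_attention(window_size = sliding_window)   # restrict to the sliding window
)
print(type(causal_mask).__name__)
```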
@@ -231,7 +231,7 @@ def MistralForCausalLM_fast_forward(
output_hidden_states = output_hidden_states,
return_dict = return_dict,
)
pass


hidden_states = outputs[0]

@@ -255,7 +255,7 @@ def MistralForCausalLM_fast_forward(
hidden_states = outputs.hidden_states,
attentions = outputs.attentions,
)
pass


if bsz == 1 and q_len == 1:
logits = torch.mv(lm_head, hidden_states.ravel().to(lm_head.dtype))
@@ -290,9 +290,9 @@ def MistralForCausalLM_fast_forward(
attentions = outputs.attentions,
)
return output
pass

logits = self.lm_head(hidden_states.to(lm_head.dtype))
pass

logits = logits.to(_get_dtype(self.config.torch_dtype))

loss = None
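Note: the `bsz == 1 and q_len == 1` branch above replaces the full `lm_head` matmul with a matrix-vector product, which is all a single new token needs. A quick equivalence check with illustrative sizes:

```python
import torch

hidden_size, vocab_size = 4096, 32000
lm_head       = torch.randn(vocab_size, hidden_size)  # weight of nn.Linear(hidden, vocab, bias=False)
hidden_states = torch.randn(1, 1, hidden_size)        # [bsz = 1, q_len = 1, hidden]

logits_fast = torch.mv(lm_head, hidden_states.ravel().to(lm_head.dtype))
logits_full = (hidden_states @ lm_head.T).ravel()
print(torch.allclose(logits_fast, logits_full, atol = 1e-4))  # True up to float error
```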
@@ -301,7 +301,7 @@ def MistralForCausalLM_fast_forward(
# if not hasattr(self, "extra_ignored_labels"):
# # Fixes https://github.com/unslothai/unsloth/issues/10
# self.extra_ignored_labels = torch.full((self.max_seq_length, 1), -100, device = "cuda:0")
# pass
#
# shift_labels = torch.hstack((labels[..., 1:], self.extra_ignored_labels[:labels.shape[0]]))
shift_labels = torch.empty_like(labels)
shift_labels[..., :-1] = labels[..., 1:]
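Note: the `shift_labels` lines above build next-token targets in place of the older `torch.hstack` approach that is commented out. A hedged sketch of the same shift, with plain `F.cross_entropy` standing in for unsloth's fused loss; the `-100` fill for the last position is assumed from the surrounding code:

```python
import torch
import torch.nn.functional as F

labels = torch.tensor([[5, 9, 2, 7]])                 # [bsz, seq]

shift_labels = torch.empty_like(labels)
shift_labels[..., :-1] = labels[..., 1:]              # position i predicts token i + 1
shift_labels[..., -1]  = -100                         # nothing to predict after the last token

logits = torch.randn(1, 4, 32)                        # [bsz, seq, vocab], illustrative
loss = F.cross_entropy(
    logits.view(-1, logits.size(-1)),
    shift_labels.view(-1),
    ignore_index = -100,
)
print(loss.item())
```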
@@ -311,7 +311,7 @@ def MistralForCausalLM_fast_forward(
labels = shift_labels,
n_items = kwargs.get("num_items_in_batch", None) or kwargs.get("n_items", None),
)
pass


if not return_dict:
output = (logits,) + outputs[1:]
@@ -324,7 +324,7 @@ def MistralForCausalLM_fast_forward(
hidden_states = outputs.hidden_states,
attentions = outputs.attentions,
)
pass



# Transformers had to update for Mistral Nemo 12b since Attention is (5120, 4096) now.
@@ -342,7 +342,7 @@ def patch_mistral_nemo_attention(function):
"self.o_proj = nn.Linear(self.config.num_attention_heads * self.head_dim, self.config.hidden_size, bias=False)",
)
return function
pass



class FastMistralModel(FastLlamaModel):
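Note: `patch_mistral_nemo_attention` (whose tail is shown above) rewrites the attention `__init__` source as text so Mistral Nemo's (5120, 4096) `o_proj` shape is handled. A hedged, generic sketch of that source-patching technique; `patch_function_source` and the toy function are illustrative, not unsloth APIs:

```python
import inspect

def patch_function_source(fn, old: str, new: str):
    """Re-create `fn` after a textual edit of its source code."""
    source    = inspect.getsource(fn).replace(old, new)
    namespace = {}
    exec(source, fn.__globals__, namespace)   # recompile in the original globals
    return namespace[fn.__name__]

def greet():
    return "hello"

patched = patch_function_source(greet, '"hello"', '"hello, Nemo"')
print(patched())  # hello, Nemo
```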
@@ -361,7 +361,7 @@ def pre_patch():
# if True:#init_name is not None:
exec(function, globals())
MistralAttention.__init__ = eval(init_name)
pass

MistralAttention .forward = MistralAttention_fast_forward
MistralSdpaAttention .forward = MistralAttention_fast_forward
MistralFlashAttention2.forward = MistralAttention_fast_forward
@@ -373,13 +373,13 @@ def pre_patch():

# Solves https://github.com/unslothai/unsloth/issues/168
# Static KV Cache was introduced in 4.38.0, causing training to be much slower.
# Inferene can now be CUDAGraphed, but we shall retain the old rotary embeddings.
# Inference can now be CUDAGraphed, but we shall retain the old rotary embeddings.
# https://github.com/huggingface/transformers/pull/27931
# https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/llama/modeling_llama.py
import transformers.models.mistral.modeling_mistral
transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding = LlamaRotaryEmbedding
return
pass



@staticmethod
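Note: `pre_patch` above swaps the fast implementations in by assigning to the classes themselves (and replaces `MistralRotaryEmbedding` at module level). A toy sketch of why class-level assignment is enough; the names here are stand-ins, not unsloth code:

```python
class SlowAttention:
    def forward(self, x):
        return x                       # placeholder for the stock implementation

def fast_forward(self, x):
    return x * 2                       # stand-in for MistralAttention_fast_forward

existing = SlowAttention()
SlowAttention.forward = fast_forward   # patch the class, not an instance

print(existing.forward(3))             # 6 -- instances built earlier pick up the patch
print(SlowAttention().forward(3))      # 6 -- and so do new ones
```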
@@ -411,5 +411,4 @@ def from_pretrained(
trust_remote_code = trust_remote_code,
**kwargs,
)
pass
pass
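Note: for reference, a hedged usage sketch of how this class is normally reached through unsloth's public API; the model name and arguments are illustrative and not taken from this PR:

```python
from unsloth import FastLanguageModel

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name     = "unsloth/mistral-7b-v0.3-bnb-4bit",  # any Mistral checkpoint
    max_seq_length = 2048,
    load_in_4bit   = True,
)
```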