Skip to content

Commit ae8cd44

Browse files
authored
ExLlamav2_HF: Convert logits to FP32 (oobabooga#4310)
1 parent c0ffb77 commit ae8cd44

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

modules/exllamav2_hf.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,10 +108,10 @@ def __call__(self, *args, **kwargs):
108108
if len(seq_tensor) > 1:
109109
self.ex_model.forward(seq_tensor[:-1].view(1, -1), ex_cache, preprocess_only=True, loras=self.loras)
110110

111-
logits = self.ex_model.forward(seq_tensor[-1:].view(1, -1), ex_cache, loras=self.loras).to(input_ids.device)
111+
logits = self.ex_model.forward(seq_tensor[-1:].view(1, -1), ex_cache, loras=self.loras).to(input_ids.device).float()
112112
else:
113113
ex_cache.current_seq_len = 0
114-
logits = self.ex_model.forward(seq_tensor.view(1, -1), ex_cache, last_id_only=False, loras=self.loras)
114+
logits = self.ex_model.forward(seq_tensor.view(1, -1), ex_cache, last_id_only=False, loras=self.loras).float()
115115

116116
if is_negative:
117117
self.past_seq_negative = seq_tensor

0 commit comments

Comments
 (0)