Skip to content

Commit ae82d4d

Browse files
authored
Merge branch 'oobabooga:main' into DualModel
2 parents c0a92a7 + ae8cd44 commit ae82d4d

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

modules/exllamav2_hf.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -112,10 +112,10 @@ def __call__(self, *args, **kwargs):
112 112
if len(seq_tensor) > 1:
113 113
self.ex_model.forward(seq_tensor[:-1].view(1, -1), ex_cache, preprocess_only=True, loras=self.loras)
114 114

115-
logits = self.ex_model.forward(seq_tensor[-1:].view(1, -1), ex_cache, loras=self.loras).to(input_ids.device)
115+
logits = self.ex_model.forward(seq_tensor[-1:].view(1, -1), ex_cache, loras=self.loras).to(input_ids.device).float()
116 116
else:
117 117
ex_cache.current_seq_len = 0
118-
logits = self.ex_model.forward(seq_tensor.view(1, -1), ex_cache, last_id_only=False, loras=self.loras)
118+
logits = self.ex_model.forward(seq_tensor.view(1, -1), ex_cache, last_id_only=False, loras=self.loras).float()
119 119

120 120
if is_negative:
121 121
self.past_seq_negative = seq_tensor

0 commit comments

Comments (0)