import torch.nn.functional as F
from typing import Tuple

+
def sequence_mask(length, max_length=None):
    if max_length is None:
        max_length = length.max()
@@ -40,7 +41,7 @@ def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:

# https://github.com/microsoft/unilm/blob/master/xtune/src/transformers/modeling_utils.py
def top_k_top_p_filtering(
-    logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1
+    logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1
):
    """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
    Args:
@@ -100,66 +101,67 @@ def topk_sampling(logits, top_k=10, top_p=1.0, temperature=1.0):


def multinomial_sample_one_no_sync(
-    probs_sort,
+    probs_sort,
):  # Does multinomial sampling without a cuda synchronization
    q = torch.empty_like(probs_sort).exponential_(1)
    return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)


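For context on `multinomial_sample_one_no_sync` above: dividing the probabilities by i.i.d. Exponential(1) noise and taking the argmax draws from the same categorical distribution as `torch.multinomial`, but avoids the host/device synchronization that `torch.multinomial` incurs (it is the Gumbel-max trick in another form). A minimal self-check, illustration only and not part of this diff:

```python
import torch

# Compare the exponential-noise trick against torch.multinomial on a toy distribution.
probs = torch.tensor([0.1, 0.2, 0.3, 0.4])
n = 200_000

q = torch.empty(n, probs.numel()).exponential_(1)            # i.i.d. Exp(1) noise
samples_trick = torch.argmax(probs / q, dim=-1)              # no sync needed
samples_ref = torch.multinomial(probs, n, replacement=True)  # reference sampler

print(torch.bincount(samples_trick, minlength=4) / n)  # ~ [0.10, 0.20, 0.30, 0.40]
print(torch.bincount(samples_ref, minlength=4) / n)    # ~ [0.10, 0.20, 0.30, 0.40]
```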
def logits_to_probs(
-    logits,
-    previous_tokens: Optional[torch.Tensor] = None,
-    temperature: float = 1.0,
-    top_k: Optional[int] = None,
-    top_p: Optional[float] = None,
-    repetition_penalty: float = 1.0,
+    logits,
+    previous_tokens: Optional[torch.Tensor] = None,
+    temperature: float = 1.0,
+    top_k: Optional[int] = None,
+    top_p: Optional[int] = None,
+    repetition_penalty: float = 1.0,
):
-    if previous_tokens is not None:
-        previous_tokens = previous_tokens.squeeze()
+    # if previous_tokens is not None:
+    #     previous_tokens = previous_tokens.squeeze()
    # print(logits.shape,previous_tokens.shape)
    # pdb.set_trace()
    if previous_tokens is not None and repetition_penalty != 1.0:
        previous_tokens = previous_tokens.long()
-        score = torch.gather(logits, dim=0, index=previous_tokens)
+        score = torch.gather(logits, dim=1, index=previous_tokens)
        score = torch.where(
            score < 0, score * repetition_penalty, score / repetition_penalty
        )
-        logits.scatter_(dim=0, index=previous_tokens, src=score)
+        logits.scatter_(dim=1, index=previous_tokens, src=score)

    if top_p is not None and top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cum_probs = torch.cumsum(
            torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
        )
        sorted_indices_to_remove = cum_probs > top_p
-        sorted_indices_to_remove[0] = False  # keep at least one option
+        sorted_indices_to_remove[:, 0] = False  # keep at least one option
        indices_to_remove = sorted_indices_to_remove.scatter(
-            dim=0, index=sorted_indices, src=sorted_indices_to_remove
+            dim=1, index=sorted_indices, src=sorted_indices_to_remove
        )
        logits = logits.masked_fill(indices_to_remove, -float("Inf"))

    logits = logits / max(temperature, 1e-5)

    if top_k is not None:
        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
-        pivot = v.select(-1, -1).unsqueeze(-1)
+        pivot = v[:, -1].unsqueeze(-1)
        logits = torch.where(logits < pivot, -float("Inf"), logits)

    probs = torch.nn.functional.softmax(logits, dim=-1)
    return probs


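The `dim=0` to `dim=1` edits in this hunk switch `logits_to_probs` from operating on a single 1-D logits vector to operating row-wise on batched logits of shape `(batch, vocab_size)`, with `previous_tokens` of shape `(batch, seq)`. A rough usage sketch under that assumption (shapes and hyperparameters are invented for illustration):

```python
import torch

batch, vocab = 2, 8
logits = torch.randn(batch, vocab)                     # (batch, vocab) raw scores
previous_tokens = torch.randint(0, vocab, (batch, 5))  # (batch, seq) decoded history

probs = logits_to_probs(
    logits,
    previous_tokens=previous_tokens,
    temperature=0.8,
    top_k=4,
    top_p=0.9,
    repetition_penalty=1.35,
)
idx_next = multinomial_sample_one_no_sync(probs)  # (batch, 1) sampled token ids
```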
def sample(
-    logits,
-    previous_tokens: Optional[torch.Tensor] = None,
-    **sampling_kwargs,
+    logits,
+    previous_tokens: Optional[torch.Tensor] = None,
+    **sampling_kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
    probs = logits_to_probs(
        logits=logits, previous_tokens=previous_tokens, **sampling_kwargs
    )
    idx_next = multinomial_sample_one_no_sync(probs)
    return idx_next, probs

+
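`sample` simply chains the two helpers above and returns both the drawn token ids and the filtered distribution. In an autoregressive loop it would typically be fed the last-position logits plus everything generated so far; a hypothetical sketch (the `model` callable and shapes are stand-ins, not part of this codebase):

```python
import torch

def decode(model, prompt_ids, max_new_tokens=16, **sampling_kwargs):
    # `model` is assumed to map (batch, seq) token ids to (batch, seq, vocab) logits.
    y = prompt_ids
    for _ in range(max_new_tokens):
        logits = model(y)[:, -1, :]                      # logits for the next position
        idx_next, _ = sample(logits, previous_tokens=y, **sampling_kwargs)
        y = torch.cat([y, idx_next.to(y.dtype)], dim=1)  # idx_next is int32; match dtype
    return y
```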
def dpo_loss(policy_chosen_logps: torch.FloatTensor,
             policy_rejected_logps: torch.FloatTensor,
             reference_chosen_logps: torch.FloatTensor,
@@ -180,15 +182,20 @@ def dpo_loss(policy_chosen_logps: torch.FloatTensor,

    return losses.mean(), chosen_rewards, rejected_rewards

-def get_batch_logps(logits_target: torch.FloatTensor, logits_reject: torch.FloatTensor, labels_target: torch.LongTensor, labels_reject: torch.LongTensor, average_log_prob: bool = False) -> Tuple[torch.FloatTensor, torch.FloatTensor]:

+def get_batch_logps(logits_target: torch.FloatTensor, logits_reject: torch.FloatTensor, labels_target: torch.LongTensor,
+                    labels_reject: torch.LongTensor, average_log_prob: bool = False) -> Tuple[
+    torch.FloatTensor, torch.FloatTensor]:
    # dummy token; we'll ignore the losses on these tokens later

-    per_token_logps_target = torch.gather(logits_target.log_softmax(-1), dim=2, index=labels_target.unsqueeze(2)).squeeze(2)
-    per_token_logps_reject = torch.gather(logits_reject.log_softmax(-1), dim=2, index=labels_reject.unsqueeze(2)).squeeze(2)
+    per_token_logps_target = torch.gather(logits_target.log_softmax(-1), dim=2,
+                                          index=labels_target.unsqueeze(2)).squeeze(2)
+    per_token_logps_reject = torch.gather(logits_reject.log_softmax(-1), dim=2,
+                                          index=labels_reject.unsqueeze(2)).squeeze(2)

    return per_token_logps_target.sum(-1), per_token_logps_reject.sum(-1)

+
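Only the signature and `return` of `dpo_loss` are visible in these hunks; the per-sequence log-probabilities it consumes come from `get_batch_logps` above. For orientation, the standard DPO objective (Rafailov et al., 2023) that a function with this return shape computes is sketched below as a textbook reference implementation, not a copy of the hidden function body:

```python
import torch
import torch.nn.functional as F

def dpo_loss_reference(policy_chosen_logps, policy_rejected_logps,
                       reference_chosen_logps, reference_rejected_logps,
                       beta=0.1):
    # Log-ratio of policy vs. frozen reference model, chosen minus rejected.
    pi_logratios = policy_chosen_logps - policy_rejected_logps
    ref_logratios = reference_chosen_logps - reference_rejected_logps
    # loss = -log sigmoid(beta * ((log pi_c - log ref_c) - (log pi_r - log ref_r)))
    losses = -F.logsigmoid(beta * (pi_logratios - ref_logratios))
    chosen_rewards = beta * (policy_chosen_logps - reference_chosen_logps).detach()
    rejected_rewards = beta * (policy_rejected_logps - reference_rejected_logps).detach()
    return losses.mean(), chosen_rewards, rejected_rewards
```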
def make_reject_y(y_o, y_lens):
    def repeat_P(y):
        range_idx, _ = torch.randint(0, len(y), size=(2,)).sort()
@@ -197,23 +204,25 @@ def repeat_P(y):
        range_text = y[range_idx[0]:range_idx[1]]
        new_y = torch.cat([pre, range_text, range_text, shf])
        return new_y
+
    def lost_P(y):
        range_idx, _ = torch.randint(0, len(y), size=(2,)).sort()
        pre = y[:range_idx[0]]
        shf = y[range_idx[1]:]
        range_text = y[range_idx[0]:range_idx[1]]
        new_y = torch.cat([pre, shf])
        return new_y
+
    bs = len(y_lens)
    reject_y = []
    reject_y_lens = []
    for b in range(bs):
-        process_item_idx = torch.randint(0, 1, size=(1, ))[0]
+        process_item_idx = torch.randint(0, 1, size=(1,))[0]
        if process_item_idx == 0:
            new_y = repeat_P(y_o[b])
            reject_y.append(new_y)
            reject_y_lens.append(len(new_y))
-        elif process_item_idx == 1:
+        elif process_item_idx == 1:
            new_y = lost_P(y_o[b])
            reject_y.append(new_y)
            reject_y_lens.append(len(new_y))
@@ -222,7 +231,7 @@ def lost_P(y):
        pad_length = max_length - reject_y_lens[b]
        reject_y[b] = torch.cat([reject_y[b], torch.zeros(pad_length, dtype=y_o.dtype, device=y_o.device)], dim=0)

-    reject_y = torch.stack(reject_y, dim=0)
+    reject_y = torch.stack(reject_y, dim=0)
    reject_y_lens = torch.tensor(reject_y_lens, device=y_lens.device)

    return reject_y, reject_y_lens
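`make_reject_y` builds the "rejected" branch for DPO by either duplicating a random span of the target token sequence (`repeat_P`) or deleting it (`lost_P`), then right-pads every corrupted sequence to a common length. Since `torch.randint(0, 1, ...)` samples from [0, 1) and so always returns 0, the `lost_P` branch appears unreachable as written. A toy call, purely for illustration:

```python
import torch

# Toy batch of two padded token sequences and their true lengths.
y_o = torch.tensor([[5, 6, 7, 8, 0, 0],
                    [1, 2, 3, 4, 5, 6]], dtype=torch.long)
y_lens = torch.tensor([4, 6])

reject_y, reject_y_lens = make_reject_y(y_o, y_lens)
print(reject_y.shape)   # (2, max(reject_y_lens)); rows are zero-padded
print(reject_y_lens)    # per-row lengths of the corrupted sequences
```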