
Commit 4aba114

Merge branch 'rep-pen' into unified
2 parents: 8082df1 + 45168f5

File tree

7 files changed: +30 −8 lines


extensions/api/util.py

Lines changed: 1 addition & 0 deletions
@@ -31,6 +31,7 @@ def build_parameters(body, chat=False):
         'tfs': float(body.get('tfs', 1)),
         'top_a': float(body.get('top_a', 0)),
         'repetition_penalty': float(body.get('repetition_penalty', body.get('rep_pen', 1.1))),
+        'additive_repetition_penalty': float(body.get('additive_repetition_penalty', body.get('additive_rep_pen', 0))),
         'repetition_penalty_range': int(body.get('repetition_penalty_range', 0)),
         'encoder_repetition_penalty': float(body.get('encoder_repetition_penalty', 1.0)),
         'top_k': int(body.get('top_k', 0)),
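
For reference, a minimal sketch of how a client could pass the new field through the API body; the endpoint URL, port, and surrounding fields are assumptions based on the extension's usual defaults, not part of this commit:

import requests

# Hypothetical call against the blocking API extension; URL and port are assumptions.
body = {
    "prompt": "Once upon a time",
    "max_new_tokens": 64,
    "repetition_penalty": 1.0,           # multiplicative penalty left neutral
    "additive_repetition_penalty": 0.5,  # new field parsed by build_parameters()
}
response = requests.post("http://127.0.0.1:5000/api/v1/generate", json=body)
print(response.json())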

modules/loaders.py

Lines changed: 5 additions & 0 deletions
@@ -150,6 +150,7 @@
         'tfs',
         'top_a',
         'repetition_penalty',
+        'additive_repetition_penalty',
         'repetition_penalty_range',
         'encoder_repetition_penalty',
         'no_repeat_ngram_size',

@@ -180,6 +181,7 @@
         'tfs',
         'top_a',
         'repetition_penalty',
+        'additive_repetition_penalty',
         'repetition_penalty_range',
         'encoder_repetition_penalty',
         'no_repeat_ngram_size',

@@ -219,6 +221,7 @@
         'tfs',
         'top_a',
         'repetition_penalty',
+        'additive_repetition_penalty',
         'repetition_penalty_range',
         'encoder_repetition_penalty',
         'no_repeat_ngram_size',

@@ -249,6 +252,7 @@
         'tfs',
         'top_a',
         'repetition_penalty',
+        'additive_repetition_penalty',
         'repetition_penalty_range',
         'encoder_repetition_penalty',
         'no_repeat_ngram_size',

@@ -290,6 +294,7 @@
         'tfs',
         'top_a',
         'repetition_penalty',
+        'additive_repetition_penalty',
         'repetition_penalty_range',
         'encoder_repetition_penalty',
         'no_repeat_ngram_size',

modules/presets.py

Lines changed: 1 addition & 0 deletions
@@ -16,6 +16,7 @@ def default_preset():
         'tfs': 1,
         'top_a': 0,
         'repetition_penalty': 1,
+        'additive_repetition_penalty': 0,
         'repetition_penalty_range': 0,
         'encoder_repetition_penalty': 1,
         'no_repeat_ngram_size': 0,

modules/sampler_hijack.py

Lines changed: 17 additions & 6 deletions
@@ -139,11 +139,12 @@ class RepetitionPenaltyLogitsProcessorWithRange(LogitsProcessor):
     Copied from the transformers library
     '''

-    def __init__(self, penalty: float, _range: int):
-        if not isinstance(penalty, float) or not (penalty > 0):
-            raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}")
+    def __init__(self, penalty: float, additive_penalty: float, _range: int):
+        if not (penalty > 0):
+            raise ValueError(f"`penalty` has to be strictly positive, but is {penalty}")

         self.penalty = penalty
+        self.additive_penalty = additive_penalty
         self._range = _range

     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
@@ -153,6 +154,7 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:

         # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
         score = torch.where(score < 0, score * self.penalty, score / self.penalty)
+        score -= self.additive_penalty

         scores.scatter_(1, input_ids, score)
         return scores
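
To see the combined effect in isolation, here is a minimal standalone sketch of the penalty arithmetic; the function name and example numbers are illustrative, while the two score transformations mirror the diff above:

import torch

def penalize_repeated(scores, input_ids, penalty=1.1, additive_penalty=0.5):
    # Gather the raw scores of tokens that already appeared in the context
    score = torch.gather(scores, 1, input_ids)
    # Multiplicative penalty: divide positive scores, multiply negative ones
    score = torch.where(score < 0, score * penalty, score / penalty)
    # Additive penalty: flat offset subtracted regardless of sign
    score -= additive_penalty
    scores.scatter_(1, input_ids, score)
    return scores

scores = torch.tensor([[2.0, -1.0, 0.5]])
input_ids = torch.tensor([[0, 1]])  # token ids 0 and 1 were already generated
print(penalize_repeated(scores, input_ids))
# token 0: 2.0 / 1.1 - 0.5 ≈ 1.3182; token 1: -1.0 * 1.1 - 0.5 = -1.6; token 2 untouched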
@@ -185,14 +187,22 @@ def get_logits_warper_patch(self, generation_config):


 def get_logits_processor_patch(self, **kwargs):
-    result = self._get_logits_processor_old(**kwargs)
     repetition_penalty_range = kwargs['generation_config'].repetition_penalty_range
     repetition_penalty = kwargs['generation_config'].repetition_penalty
+    additive_repetition_penalty = kwargs['generation_config'].additive_repetition_penalty
+    need_rep_pen_hijack = (repetition_penalty_range > 0) or (additive_repetition_penalty > 0)
+    if need_rep_pen_hijack:
+        # Make sure it always creates a RepetitionPenaltyLogitsProcessor
+        kwargs['generation_config'].repetition_penalty = 1.1  # must set to some value > 1
+    result = self._get_logits_processor_old(**kwargs)
+    if need_rep_pen_hijack:
+        # Now set the rep_pen back to the actual value (just in case)
+        kwargs['generation_config'].repetition_penalty = repetition_penalty

-    if repetition_penalty_range > 0:
+    if need_rep_pen_hijack:
         for i in range(len(result)):
             if result[i].__class__.__name__ == 'RepetitionPenaltyLogitsProcessor':
-                result[i] = RepetitionPenaltyLogitsProcessorWithRange(repetition_penalty, repetition_penalty_range)
+                result[i] = RepetitionPenaltyLogitsProcessorWithRange(repetition_penalty, additive_repetition_penalty, repetition_penalty_range)

     return result
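The temporary 1.1 is needed because transformers only instantiates a RepetitionPenaltyLogitsProcessor when the configured penalty differs from 1.0, so a purely additive penalty (multiplicative penalty of exactly 1) would otherwise leave nothing in the processor list to replace. A simplified sketch of that gating behavior, not the library's actual code:

# Simplified sketch (NOT the transformers source): the stock processor list only
# gains a RepetitionPenaltyLogitsProcessor when the penalty differs from 1.0.
def build_processor_names(repetition_penalty):
    processors = []
    if repetition_penalty is not None and repetition_penalty != 1.0:
        processors.append('RepetitionPenaltyLogitsProcessor')
    return processors

print(build_processor_names(1.0))  # [] -> nothing for the patch to swap out
print(build_processor_names(1.1))  # ['RepetitionPenaltyLogitsProcessor'] -> replaceable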

@@ -205,6 +215,7 @@ def generation_config_init_patch(self, **kwargs):
     self.mirostat_eta = kwargs.pop("mirostat_eta", 0.1)
     self.mirostat_tau = kwargs.pop("mirostat_tau", 5)
     self.repetition_penalty_range = kwargs.pop("repetition_penalty_range", 0)
+    self.additive_repetition_penalty = kwargs.pop("additive_repetition_penalty", 0)


 def hijack_samplers():

modules/text_generation.py

Lines changed: 1 addition & 1 deletion
@@ -240,7 +240,7 @@ def apply_stopping_strings(reply, all_stop_strings):

 def generate_reply_HF(question, original_question, seed, state, stopping_strings=None, is_chat=False):
     generate_params = {}
-    for k in ['max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'repetition_penalty_range', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'tfs', 'top_a', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'guidance_scale']:
+    for k in ['max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'additive_repetition_penalty', 'repetition_penalty_range', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'tfs', 'top_a', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'guidance_scale']:
         generate_params[k] = state[k]

     if state['negative_prompt'] != '':

modules/ui.py

Lines changed: 1 addition & 0 deletions
@@ -112,6 +112,7 @@ def list_interface_input_elements():
         'epsilon_cutoff',
         'eta_cutoff',
         'repetition_penalty',
+        'additive_repetition_penalty',
         'repetition_penalty_range',
         'encoder_repetition_penalty',
         'no_repeat_ngram_size',

modules/ui_parameters.py

Lines changed: 4 additions & 1 deletion
@@ -36,6 +36,7 @@ def create_ui(default_preset):

     with gr.Column():
         shared.gradio['repetition_penalty'] = gr.Slider(1.0, 1.5, value=generate_params['repetition_penalty'], step=0.01, label='repetition_penalty')
+        shared.gradio['additive_repetition_penalty'] = gr.Slider(0, 4, value=generate_params['additive_repetition_penalty'], step=0.05, label='additive_repetition_penalty')
         shared.gradio['repetition_penalty_range'] = gr.Slider(0, 4096, step=64, value=generate_params['repetition_penalty_range'], label='repetition_penalty_range')
         shared.gradio['encoder_repetition_penalty'] = gr.Slider(0.8, 1.5, value=generate_params['encoder_repetition_penalty'], step=0.01, label='encoder_repetition_penalty')
         shared.gradio['no_repeat_ngram_size'] = gr.Slider(0, 20, step=1, value=generate_params['no_repeat_ngram_size'], label='no_repeat_ngram_size')

@@ -79,7 +80,9 @@ def create_ui(default_preset):
     ### eta_cutoff
     In units of 1e-4; a reasonable value is 3. Should be used with top_p, top_k, and epsilon_cutoff set to 0.
     ### repetition_penalty
-    Exponential penalty factor for repeating prior tokens. 1 means no penalty, higher value = less repetition, lower value = more repetition.
+    Exponential penalty factor for repeating prior tokens. This is a multiplicative factor on the raw token scores. 1 means no penalty, higher value = less repetition, lower value = more repetition.
+    ### additive_repetition_penalty
+    Similar to repetition_penalty, but with an additive offset on the raw token scores instead of a multiplicative factor. 0 means no penalty, higher value = less repetition, lower value = more repetition.
     ### repetition_penalty_range
     The number of most recent tokens to consider for repetition penalty. 0 makes all tokens be used.
     ### encoder_repetition_penalty
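
A worked example of the two documented penalties together (the numbers are illustrative, not from the commit): with repetition_penalty = 1.1 and additive_repetition_penalty = 0.5, a repeated token with raw score s = 2.0 becomes s' = 2.0 / 1.1 − 0.5 ≈ 1.32, while a repeated token with s = −1.0 becomes s' = −1.0 × 1.1 − 0.5 = −1.6, matching the torch.where branches in modules/sampler_hijack.py above.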
