@@ -406,7 +406,7 @@ def generate(
406
406
forced_eos_token_id = None ,
407
407
suppress_tokens : Optional [List [int ]] = None ,
408
408
begin_suppress_tokens : Optional [List [int ]] = None ,
409
- forced_decoder_ids : Optional [List [int ]] = None ,
409
+ forced_decoder_ids : Optional [List [List [ int ] ]] = None ,
410
410
** model_kwargs ,
411
411
) -> Union [TFGreedySearchOutput , TFSampleOutput , TFBeamSearchOutput , TFBeamSampleOutput , tf .Tensor ]:
412
412
r"""
@@ -506,8 +506,10 @@ def generate(
506
506
begin_suppress_tokens (`List[int]`, *optional*, defaults to `model.config.begin_suppress_tokens`):
507
507
A list of tokens that will be supressed at the begining of the generation. The `SupressBeginTokens`
508
508
logit processor will set their log probs to `-inf` so that they are not sampled.
509
- forced_decoder_ids (`List[int]`, *optional*, defaults to `model.config.forced_decoder_ids`):
510
- A list of tokens that will be forced as beginning tokens, before sampling.
509
+ forced_decoder_ids (`List[List[int]]`, *optional*, defaults to `model.config.forced_decoder_ids`):
510
+ A list of pairs of integers which indicates a mapping from generation indices to token indices that
511
+ will be forced before sampling. For example, `[[1, 123]]` means the second generated token will always
512
+ be a token of index 123.
511
513
model_specific_kwargs:
512
514
Additional model specific kwargs will be forwarded to the `forward` function of the model.
513
515
@@ -1493,9 +1495,10 @@ def _generate(
1493
1495
begin_suppress_tokens (`List[int]`, *optional*, defaults to `model.config.begin_suppress_tokens`):
1494
1496
A list of tokens that will be supressed at the begining of the generation. The `SupressBeginTokens`
1495
1497
logit processor will set their log probs to `-inf` so that they are not sampled.
1496
- forced_decoder_ids (`List[int]`, *optional*, defaults to `model.config.forced_decoder_ids`):
1497
- A list of tokens that will be forced as beginning tokens.
1498
-
1498
+ forced_decoder_ids (`List[List[int]]`, *optional*, defaults to `model.config.forced_decoder_ids`):
1499
+ A list of pairs of integers which indicates a mapping from generation indices to token indices that
1500
+ will be forced before sampling. For example, `[[1, 123]]` means the second generated token will always
1501
+ be a token of index 123.
1499
1502
model_kwargs:
1500
1503
Additional model specific kwargs will be forwarded to the `call` function of the model.
1501
1504
@@ -2147,7 +2150,7 @@ def _get_logits_processor(
2147
2150
forced_eos_token_id : int ,
2148
2151
suppress_tokens : Optional [List [int ]] = None ,
2149
2152
begin_suppress_tokens : Optional [List [int ]] = None ,
2150
- forced_decoder_ids : Optional [List [int ]] = None ,
2153
+ forced_decoder_ids : Optional [List [List [ int ] ]] = None ,
2151
2154
) -> TFLogitsProcessorList :
2152
2155
"""
2153
2156
This class returns a [`TFLogitsProcessorList`] list object that contains all relevant [`TFLogitsProcessor`]
0 commit comments