Skip to content

Commit 82e360b

Browse files
authored
Fixed the docstring and type hint for forced_decoder_ids option in Generate (#19640)
1 parent f2ecb9e commit 82e360b

File tree

4 files changed

+24
-18
lines changed

4 files changed

+24
-18
lines changed

src/transformers/generation_logits_process.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -735,10 +735,11 @@ def __call__(self, input_ids, scores):
735735

736736

737737
class ForceTokensLogitsProcessor(LogitsProcessor):
738-
r"""This processor can be used to force a list of tokens. The processor will set their log probs to `inf` so that they
739-
are sampled at their corresponding index."""
738+
r"""This processor takes a list of pairs of integers which indicates a mapping from generation indices to token
739+
indices that will be forced before sampling. The processor will set their log probs to `inf` so that they are
740+
sampled at their corresponding index."""
740741

741-
def __init__(self, force_token_map):
742+
def __init__(self, force_token_map: List[List[int]]):
742743
self.force_token_map = dict(force_token_map)
743744

744745
def __call__(self, input_ids, scores):

src/transformers/generation_tf_logits_process.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -547,10 +547,11 @@ def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.
547547

548548

549549
class TFForceTokensLogitsProcessor(TFLogitsProcessor):
550-
r"""This processor can be used to force a list of tokens. The processor will set their log probs to `0` and all
551-
other tokens to `-inf` so that they are sampled at their corresponding index."""
550+
r"""This processor takes a list of pairs of integers which indicates a mapping from generation indices to token
551+
indices that will be forced before sampling. The processor will set their log probs to `0` and all other tokens to
552+
`-inf` so that they are sampled at their corresponding index."""
552553

553-
def __init__(self, force_token_map):
554+
def __init__(self, force_token_map: List[List[int]]):
554555
force_token_map = dict(force_token_map)
555556
# Converts the dictionary of format {index: token} containing the tokens to be forced to an array, where the
556557
# index of the array corresponds to the index of the token to be forced, for XLA compatibility.

src/transformers/generation_tf_utils.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -406,7 +406,7 @@ def generate(
406406
forced_eos_token_id=None,
407407
suppress_tokens: Optional[List[int]] = None,
408408
begin_suppress_tokens: Optional[List[int]] = None,
409-
forced_decoder_ids: Optional[List[int]] = None,
409+
forced_decoder_ids: Optional[List[List[int]]] = None,
410410
**model_kwargs,
411411
) -> Union[TFGreedySearchOutput, TFSampleOutput, TFBeamSearchOutput, TFBeamSampleOutput, tf.Tensor]:
412412
r"""
@@ -506,8 +506,10 @@ def generate(
506506
begin_suppress_tokens (`List[int]`, *optional*, defaults to `model.config.begin_suppress_tokens`):
507507
A list of tokens that will be suppressed at the beginning of the generation. The `SupressBeginTokens`
508508
logit processor will set their log probs to `-inf` so that they are not sampled.
509-
forced_decoder_ids (`List[int]`, *optional*, defaults to `model.config.forced_decoder_ids`):
510-
A list of tokens that will be forced as beginning tokens, before sampling.
509+
forced_decoder_ids (`List[List[int]]`, *optional*, defaults to `model.config.forced_decoder_ids`):
510+
A list of pairs of integers which indicates a mapping from generation indices to token indices that
511+
will be forced before sampling. For example, `[[1, 123]]` means the second generated token will always
512+
be a token of index 123.
511513
model_specific_kwargs:
512514
Additional model specific kwargs will be forwarded to the `forward` function of the model.
513515
@@ -1493,9 +1495,10 @@ def _generate(
14931495
begin_suppress_tokens (`List[int]`, *optional*, defaults to `model.config.begin_suppress_tokens`):
14941496
A list of tokens that will be suppressed at the beginning of the generation. The `SupressBeginTokens`
14951497
logit processor will set their log probs to `-inf` so that they are not sampled.
1496-
forced_decoder_ids (`List[int]`, *optional*, defaults to `model.config.forced_decoder_ids`):
1497-
A list of tokens that will be forced as beginning tokens.
1498-
1498+
forced_decoder_ids (`List[List[int]]`, *optional*, defaults to `model.config.forced_decoder_ids`):
1499+
A list of pairs of integers which indicates a mapping from generation indices to token indices that
1500+
will be forced before sampling. For example, `[[1, 123]]` means the second generated token will always
1501+
be a token of index 123.
14991502
model_kwargs:
15001503
Additional model specific kwargs will be forwarded to the `call` function of the model.
15011504
@@ -2147,7 +2150,7 @@ def _get_logits_processor(
21472150
forced_eos_token_id: int,
21482151
suppress_tokens: Optional[List[int]] = None,
21492152
begin_suppress_tokens: Optional[List[int]] = None,
2150-
forced_decoder_ids: Optional[List[int]] = None,
2153+
forced_decoder_ids: Optional[List[List[int]]] = None,
21512154
) -> TFLogitsProcessorList:
21522155
"""
21532156
This class returns a [`TFLogitsProcessorList`] list object that contains all relevant [`TFLogitsProcessor`]

src/transformers/generation_utils.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -696,7 +696,7 @@ def _get_logits_processor(
696696
renormalize_logits: Optional[bool],
697697
suppress_tokens: Optional[List[int]] = None,
698698
begin_suppress_tokens: Optional[List[int]] = None,
699-
forced_decoder_ids: Optional[List[int]] = None,
699+
forced_decoder_ids: Optional[List[List[int]]] = None,
700700
) -> LogitsProcessorList:
701701
"""
702702
This class returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsProcessor`]
@@ -956,7 +956,7 @@ def generate(
956956
exponential_decay_length_penalty: Optional[Tuple[int, float]] = None,
957957
suppress_tokens: Optional[List[int]] = None,
958958
begin_suppress_tokens: Optional[List[int]] = None,
959-
forced_decoder_ids: Optional[List[int]] = None,
959+
forced_decoder_ids: Optional[List[List[int]]] = None,
960960
**model_kwargs,
961961
) -> Union[GreedySearchOutput, SampleOutput, BeamSearchOutput, BeamSampleOutput, torch.LongTensor]:
962962
r"""
@@ -1121,9 +1121,10 @@ def generate(
11211121
begin_suppress_tokens (`List[int]`, *optional*, defaults to `model.config.begin_suppress_tokens`):
11221122
A list of tokens that will be suppressed at the beginning of the generation. The `SupressBeginTokens`
11231123
logit processor will set their log probs to `-inf` so that they are not sampled.
1124-
forced_decoder_ids (`List[int]`, *optional*, defaults to `model.config.forced_decoder_ids`):
1125-
A list of tokens that will be forced as beginning tokens, before sampling.
1126-
1124+
forced_decoder_ids (`List[List[int]]`, *optional*, defaults to `model.config.forced_decoder_ids`):
1125+
A list of pairs of integers which indicates a mapping from generation indices to token indices that
1126+
will be forced before sampling. For example, `[[1, 123]]` means the second generated token will always
1127+
be a token of index 123.
11271128
model_kwargs:
11281129
Additional model specific kwargs will be forwarded to the `forward` function of the model. If the model
11291130
is an encoder-decoder model, encoder specific kwargs should not be prefixed and decoder specific kwargs

0 commit comments

Comments
 (0)