Commit c9d8c70
stas00 authored and amyeroberts committed
[fsmt] deal with -100 indices in decoder ids (huggingface#18592)
* [fsmt] deal with -100 indices in decoder ids

Fixes: huggingface#17945

Decoder ids get the default index -100, which breaks the model. Like t5 and many other models, add a fix to replace -100 with the correct pad index.

For some reason this use case hasn't been used with this model until recently, so it seems this issue has been there since the beginning.

Any suggestions on how to add a simple test here? Or perhaps we have something similar already? The user's script is quite massive.

* style
1 parent b2fe78b commit c9d8c70

1 file changed, 4 additions and 0 deletions

src/transformers/models/fsmt/modeling_fsmt.py

Lines changed: 4 additions & 0 deletions
@@ -372,6 +372,10 @@ def _check_shapes(shape_1, shape2):
 
 def shift_tokens_right(input_ids, pad_token_id):
     """Shift input ids one token to the right, and wrap the last non pad token (usually <eos>)."""
+
+    # replace possible -100 values in labels by `pad_token_id`
+    input_ids.masked_fill_(input_ids == -100, pad_token_id)
+
     prev_output_tokens = input_ids.clone()
     index_of_eos = (input_ids.ne(pad_token_id).sum(dim=1) - 1).unsqueeze(-1)
     prev_output_tokens[:, 0] = input_ids.gather(1, index_of_eos).squeeze()
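
The commit message asks how a simple test might look. As one possible starting point, the sketch below mirrors the patched logic with plain Python lists so it runs without torch. It is an illustration, not the library code: the real function works in place on tensors, the `PAD` and `EOS` ids are hypothetical, and the final right-shift step is taken from the docstring because it lies outside the hunk shown above.

```python
IGNORE_INDEX = -100  # label-padding value that the loss ignores


def shift_tokens_right(input_ids, pad_token_id):
    """List-based mirror of the patched function (illustrative only)."""
    shifted = []
    for row in input_ids:
        # Step added by this commit: map -100 label padding to the real pad id.
        row = [pad_token_id if tok == IGNORE_INDEX else tok for tok in row]
        # Index of the last non-pad token (usually <eos>).
        index_of_eos = sum(tok != pad_token_id for tok in row) - 1
        # Wrap that token to position 0 and shift the rest right by one.
        shifted.append([row[index_of_eos]] + row[:-1])
    return shifted


PAD, EOS = 1, 2  # hypothetical token ids for this example
labels = [[5, 6, EOS, IGNORE_INDEX, IGNORE_INDEX]]
decoder_input_ids = shift_tokens_right(labels, PAD)
print(decoder_input_ids)  # [[2, 5, 6, 2, 1]]
```

Without the replacement step, the two -100 entries would count as non-pad tokens, so the "last non pad token" lookup would land on a -100 instead of `<eos>`, and negative ids would reach the decoder. A test along these lines could assert that labels padded with -100 and labels padded with the pad id yield the same decoder input ids.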
