Commit f098268

TF: T5 can now handle a padded past (i.e. XLA generation) (#17969)
* get the right slicing index for position_bias
1 parent e3139ad · commit f098268

2 files changed: +17 additions, -11 deletions

src/transformers/models/t5/modeling_tf_t5.py

13 additions, 3 deletions

@@ -23,6 +23,7 @@

 import numpy as np
 import tensorflow as tf
+from tensorflow.compiler.tf2xla.python.xla import dynamic_slice

 from ...activations_tf import get_tf_activation
 from ...modeling_tf_outputs import (
@@ -384,10 +385,19 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value):
             else:
                 position_bias = self.compute_bias(real_seq_length, key_length)

-            # if key and values are already calculated
-            # we want only the last query position bias
+            # if key and values are already calculated we want only the last query position bias
             if past_key_value is not None:
-                position_bias = position_bias[:, :, -seq_length:, :]
+                if not self.has_relative_attention_bias:
+                    position_bias = position_bias[:, :, -seq_length:, :]
+                else:
+                    # we might have a padded past structure, in which case we want to fetch the position bias slice
+                    # right after the most recently filled past index
+                    most_recently_filled_past_index = tf.reduce_max(tf.where(past_key_value[0][0, 0, :, 0] != 0.0))
+                    position_bias = dynamic_slice(
+                        position_bias,
+                        (0, 0, most_recently_filled_past_index + 1, 0),
+                        (1, self.n_heads, seq_length, real_seq_length),
+                    )

             if mask is not None:
                 position_bias = tf.cast(position_bias, dtype=mask.dtype)
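
For context, here is a minimal standalone sketch of the new slicing logic. It is not part of the commit: the toy shapes, the `past_key` tensor, and the `slice_position_bias` helper are illustrative assumptions. With an XLA-padded past, the last `seq_length` rows of `position_bias` belong to padding positions, so the commit instead locates the most recently filled past index and slices immediately after it:

import tensorflow as tf
from tensorflow.compiler.tf2xla.python.xla import dynamic_slice

n_heads, head_dim, seq_length, real_seq_length = 2, 4, 1, 8

# Toy past key of shape (batch, heads, padded_length, head_dim): only the
# first 3 positions are filled; the remaining 5 are XLA zero-padding.
past_key = tf.concat(
    [tf.ones((1, n_heads, 3, head_dim)), tf.zeros((1, n_heads, 5, head_dim))],
    axis=2,
)
position_bias = tf.random.normal((1, n_heads, real_seq_length, real_seq_length))


@tf.function(jit_compile=True)
def slice_position_bias(position_bias, past_key):
    # Index of the most recently filled past position (here: 2).
    last_filled = tf.reduce_max(tf.where(past_key[0, 0, :, 0] != 0.0))
    # `position_bias[:, :, -seq_length:, :]` would pick the bias of a
    # padding row; instead, start the slice right after the filled past.
    return dynamic_slice(
        position_bias,
        (0, 0, last_filled + 1, 0),
        (1, n_heads, seq_length, real_seq_length),
    )


print(slice_position_bias(position_bias, past_key).shape)  # (1, 2, 1, 8)

XLA needs statically known output shapes, and `dynamic_slice` provides exactly that: the start index may be a runtime tensor, but the slice sizes are fixed at trace time.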

tests/models/t5/test_modeling_tf_t5.py

4 additions, 8 deletions

@@ -590,21 +590,17 @@ def test_beam_search_xla_generate_simple(self):
         ]
         input_ids = tokenizer(sentences, return_tensors="tf", padding=True).input_ids

-        # xla_generate = tf.function(model.generate, jit_compile=True)
-        xla_generate = tf.function(model.generate)
+        xla_generate = tf.function(model.generate, jit_compile=True)

-        # TODO (joao): there is something not quite right with XLA T5 -- as we increase `max_length` the two outputs
-        # drift appart, where the XLA version clearly degrades its quality. XLA-related variables look fine (they are
-        # being padded and filled in the right places). This also happens in other generation modes. Investigate.
-        output_ids = model.generate(input_ids, num_beams=2, max_length=9)
-        output_ids_xla = xla_generate(input_ids, num_beams=2, max_length=9)
+        output_ids = model.generate(input_ids, num_beams=2)
+        output_ids_xla = xla_generate(input_ids, num_beams=2)

         output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
         output_strings_xla = tokenizer.batch_decode(output_ids_xla, skip_special_tokens=True)

         expected_output_string = [
             "Aujourd'hui est une belle journée.",
-            "J'ai quatre chats,",
+            "J'ai quatre chats, trois chiens, deux oiseaux et un cheval.",
         ]

         self.assertListEqual(expected_output_string, output_strings)
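
Outside the test suite, the same pattern is how a user would run XLA generation with TF T5. A minimal usage sketch, assuming the `t5-small` checkpoint and a translation prompt in the style of the test above:

import tensorflow as tf
from transformers import T5Tokenizer, TFT5ForConditionalGeneration

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = TFT5ForConditionalGeneration.from_pretrained("t5-small")

# XLA wants fixed tensor shapes, so pad the batch to a common length.
input_ids = tokenizer(
    ["translate English to French: I have four cats."],
    return_tensors="tf",
    padding=True,
).input_ids

# Compile generate once; calls with the same input shapes reuse the trace.
xla_generate = tf.function(model.generate, jit_compile=True)
output_ids = xla_generate(input_ids, num_beams=2)
print(tokenizer.batch_decode(output_ids, skip_special_tokens=True))

Note that XLA recompiles for every new input shape, so padding inputs to a fixed length keeps compilation cost to a single trace across calls.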
