This repository was archived by the owner on Dec 16, 2022. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 2.2k
/
Copy patht5.py
1145 lines (988 loc) · 44.9 KB
/
t5.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
"""
An implementation of [T5](https://api.semanticscholar.org/CorpusID:204838007), adapted from [HuggingFace]
(https://github.com/huggingface/transformers/blob/4c32f9f26e6a84f0d9843fec8757e6ce640bb44e/src/transformers/models/t5/modeling_t5.py).
""" # noqa: E401
import logging
from typing import Optional, Tuple, List, Union, Dict, TYPE_CHECKING, NamedTuple, Callable
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from allennlp.common import FromParams, Params, Lazy, Registrable
from allennlp.common.checks import ConfigurationError
from allennlp.modules.transformer.transformer_module import TransformerModule
from allennlp.modules.transformer.attention_module import (
T5Attention,
AttentionOutput,
)
from allennlp.modules.transformer.util import (
get_extended_attention_mask,
FloatT,
IntT,
BoolT,
)
from allennlp.nn.beam_search import BeamSearch
from allennlp.nn.parallel import DdpAccelerator
from allennlp.nn.checkpoint import CheckpointWrapper
if TYPE_CHECKING:
    # Imported only for type checkers; used in string annotations such as the
    # `config: "PretrainedConfig"` parameter of `T5._from_config`.
    from transformers.configuration_utils import PretrainedConfig

# Module-level logger (used e.g. for the DdpAccelerator and vocabulary-shrinking messages).
logger = logging.getLogger(__name__)
class T5LayerNorm(TransformerModule, FromParams):
    """
    T5-style layer norm does not have bias and does not subtract the mean.

    Each feature vector is divided by the square root of its mean squared value
    (plus `eps`) and then multiplied by a learned per-feature `weight`.

    # Parameters

    hidden_size : `int`, optional (default = `512`)
        Size of the last dimension of the input, and of `weight`.
    eps : `float`, optional (default = `1e-6`)
        Added to the mean of squares before the reciprocal square root.
    """

    def __init__(self, hidden_size: int = 512, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states) -> FloatT:
        # layer norm should always be calculated in float32
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        # Multiplying by the float32 statistic promotes half-precision inputs to float32.
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        # Cast back when the module itself runs in half precision. This must cover
        # bfloat16 as well as float16; otherwise a bf16 model emits float32 here and
        # the next bf16 linear layer fails with a dtype mismatch.
        if self.weight.dtype in (torch.float16, torch.bfloat16):
            hidden_states = hidden_states.to(self.weight.dtype)
        return self.weight * hidden_states
class T5FeedForwardProjection(TransformerModule, Registrable):
    """
    Registrable base class for the position-wise projection inside `T5LayerFF`.

    Implementations are registered by name and must override `forward`, mapping
    hidden states to hidden states of the same size.
    """

    def forward(self, hidden_states) -> FloatT:
        # Concrete subclasses provide the projection; the base class has none.
        raise NotImplementedError
@T5FeedForwardProjection.register("relu")
class T5DenseReluDense(TransformerModule, FromParams):
    """
    Feed-forward projection `wo(dropout(relu(wi(x))))`, registered as `"relu"`.

    Both linear maps are bias-free. Their weights are re-drawn from a normal
    distribution whose standard deviation is `fan_in ** -0.5`.
    """

    def __init__(self, hidden_size: int = 512, ff_size: int = 2048, dropout: float = 0.1):
        super().__init__()
        # Each layer is created and then immediately re-initialized, in this order,
        # so random numbers are drawn in a fixed sequence.
        self.wi = nn.Linear(hidden_size, ff_size, bias=False)
        self.wi.weight.data.normal_(mean=0.0, std=hidden_size ** -0.5)
        self.wo = nn.Linear(ff_size, hidden_size, bias=False)
        self.wo.weight.data.normal_(mean=0.0, std=ff_size ** -0.5)
        self.dropout = nn.Dropout(dropout)

    def forward(self, hidden_states) -> FloatT:
        activated = F.relu(self.wi(hidden_states))
        return self.wo(self.dropout(activated))
@T5FeedForwardProjection.register("gated-gelu")
class T5DenseGatedGeluDense(TransformerModule, FromParams):
    """
    Gated feed-forward projection, registered as `"gated-gelu"`.

    Computes `wo(dropout(gelu_new(wi_0(x)) * wi_1(x)))`, where `gelu_new` is looked up
    from AllenNLP's `Activation` registry. All linear maps are bias-free and
    re-initialized from a normal distribution with std `fan_in ** -0.5`.
    """

    def __init__(self, hidden_size: int = 512, ff_size: int = 2048, dropout: float = 0.1):
        super().__init__()
        # Create-then-initialize each layer in turn to keep the random draw order fixed.
        self.wi_0 = nn.Linear(hidden_size, ff_size, bias=False)
        self.wi_0.weight.data.normal_(mean=0.0, std=hidden_size ** -0.5)
        self.wi_1 = nn.Linear(hidden_size, ff_size, bias=False)
        self.wi_1.weight.data.normal_(mean=0.0, std=hidden_size ** -0.5)
        self.wo = nn.Linear(ff_size, hidden_size, bias=False)
        self.wo.weight.data.normal_(mean=0.0, std=ff_size ** -0.5)
        self.dropout = nn.Dropout(dropout)

        # Imported here rather than at module level (possibly to avoid an import cycle
        # with allennlp.nn -- NOTE(review): confirm before hoisting).
        from allennlp.nn import Activation

        self.gelu_act = Activation.by_name("gelu_new")()

    def forward(self, hidden_states) -> FloatT:
        gate = self.gelu_act(self.wi_0(hidden_states))
        gated = gate * self.wi_1(hidden_states)
        return self.wo(self.dropout(gated))
class T5LayerFF(TransformerModule, FromParams):
    """
    Pre-norm residual feed-forward sublayer of a T5 block.

    Computes `x + dropout(ff_proj(layer_norm(x)))`.

    # Parameters

    ff_proj : `T5FeedForwardProjection`, optional
        Defaults to a `T5DenseReluDense` with its default sizes.
    layer_norm : `T5LayerNorm`, optional
        Defaults to a `T5LayerNorm` with its default size.
    dropout : `float`, optional (default = `0.1`)
        Dropout applied to the projection output before the residual addition.
    """

    _pretrained_mapping = {"DenseReluDense": "ff_proj"}

    def __init__(
        self,
        ff_proj: Optional[T5FeedForwardProjection] = None,
        layer_norm: Optional[T5LayerNorm] = None,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.ff_proj = ff_proj or T5DenseReluDense()
        self.layer_norm = layer_norm or T5LayerNorm()
        self.dropout = nn.Dropout(dropout)

    def forward(self, hidden_states) -> FloatT:
        residual = hidden_states
        update = self.ff_proj(self.layer_norm(hidden_states))
        return residual + self.dropout(update)
class T5LayerSelfAttentionOutput(NamedTuple):
    """Result of `T5LayerSelfAttention.forward`."""

    # Residual output of the sublayer: input + dropout(attention output).
    hidden_states: FloatT
    # `key_value_state` reported by the attention module; presumably only set when
    # caching is requested -- TODO confirm in T5Attention.
    attn_key_value_state: Optional[Tuple[FloatT, FloatT]]
    # `position_bias` reported by the attention module (shared with later blocks by T5Stack).
    attn_position_bias: FloatT
    # `attention_probs` reported by the attention module, if any.
    attn_weights: Optional[FloatT] = None
class T5LayerSelfAttention(TransformerModule, FromParams):
    """
    Pre-norm residual self-attention sublayer of a T5 block.

    Computes `hidden_states + dropout(self_attention(layer_norm(hidden_states), ...))`.

    # Parameters

    self_attention : `T5Attention`, optional
        Defaults to a `T5Attention` built with `has_relative_attention_bias`.
    layer_norm : `T5LayerNorm`, optional
        Defaults to a `T5LayerNorm` sized to the attention's `hidden_size`.
    dropout : `float`, optional (default = `0.1`)
        Dropout applied to the attention output before the residual addition.
    has_relative_attention_bias : `bool`, optional (default = `False`)
        Only used when `self_attention` is not given.
    """

    _pretrained_mapping = {"SelfAttention": "self_attention"}

    def __init__(
        self,
        self_attention: Optional[T5Attention] = None,
        layer_norm: Optional[T5LayerNorm] = None,
        dropout: float = 0.1,
        has_relative_attention_bias: bool = False,
    ):
        super().__init__()
        self.self_attention = self_attention or T5Attention(
            has_relative_attention_bias=has_relative_attention_bias
        )
        self.layer_norm = layer_norm or T5LayerNorm(hidden_size=self.self_attention.hidden_size)
        self.dropout = nn.Dropout(dropout)

    @property
    def hidden_size(self) -> int:
        return self.self_attention.hidden_size

    def forward(
        self,
        hidden_states: FloatT,
        attention_mask: Optional[torch.BoolTensor] = None,
        position_bias: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.BoolTensor] = None,
        past_key_value: Optional[Tuple[FloatT, FloatT]] = None,
        use_cache: bool = False,
        output_attentions: bool = False,
    ) -> T5LayerSelfAttentionOutput:
        """
        Apply the sublayer.

        `past_key_value` holds the two cached tensors for this sublayer. `T5Block.forward`
        passes the first two entries of a block's `KeyValueStates`, so the annotation is
        a 2-tuple rather than the previous `Tuple[FloatT]`.
        """
        normed_hidden_states = self.layer_norm(hidden_states)
        attention_output: AttentionOutput = self.self_attention(
            normed_hidden_states,
            mask=attention_mask,
            position_bias=position_bias,
            layer_head_mask=layer_head_mask,
            past_key_value=past_key_value,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )
        hidden_states = hidden_states + self.dropout(attention_output.hidden_states)
        return T5LayerSelfAttentionOutput(
            hidden_states,
            attention_output.key_value_state,
            attention_output.position_bias,
            attention_output.attention_probs,
        )
class T5LayerCrossAttentionOutput(NamedTuple):
    """Result of `T5LayerCrossAttention.forward`."""

    # Residual output of the sublayer: input + dropout(cross-attention output).
    hidden_states: FloatT
    # `key_value_state` reported by the cross-attention module; presumably only set when
    # caching is requested -- TODO confirm in T5Attention.
    attn_key_value_state: Optional[Tuple[FloatT, FloatT]]
    # `position_bias` reported by the cross-attention module (shared across blocks by T5Stack).
    attn_position_bias: FloatT
    # `attention_probs` reported by the cross-attention module, if any.
    attn_weights: Optional[FloatT] = None
class T5LayerCrossAttention(TransformerModule, FromParams):
    """
    Pre-norm residual encoder-decoder (cross) attention sublayer of a T5 decoder block.

    Computes `hidden_states + dropout(enc_dec_attention(layer_norm(hidden_states), ...))`,
    with keys/values taken from `key_value_states` (the encoder output).

    # Parameters

    enc_dec_attention : `T5Attention`, optional
        Defaults to a decoder-side cross-attention `T5Attention` without relative bias.
    layer_norm : `T5LayerNorm`, optional
        Defaults to a `T5LayerNorm` sized to the attention's `hidden_size`.
    dropout : `float`, optional (default = `0.1`)
        Dropout applied to the attention output before the residual addition.
    """

    _pretrained_mapping = {"EncDecAttention": "enc_dec_attention"}

    def __init__(
        self,
        enc_dec_attention: Optional[T5Attention] = None,
        layer_norm: Optional[T5LayerNorm] = None,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.enc_dec_attention = enc_dec_attention or T5Attention(
            is_decoder=True,
            has_relative_attention_bias=False,
            is_cross_attention=True,
        )
        self.layer_norm = layer_norm or T5LayerNorm(hidden_size=self.enc_dec_attention.hidden_size)
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        hidden_states: FloatT,
        key_value_states: Optional[FloatT],
        attention_mask: Optional[torch.BoolTensor] = None,
        position_bias: Optional[FloatT] = None,
        layer_head_mask: Optional[torch.BoolTensor] = None,
        past_key_value: Optional[Tuple[FloatT, FloatT]] = None,
        use_cache: bool = False,
        query_length: Optional[int] = None,
        output_attentions: bool = False,
    ) -> T5LayerCrossAttentionOutput:
        """
        Apply the sublayer.

        `past_key_value` holds the two cached cross-attention tensors. `T5Block.forward`
        passes the last two entries of a 4-entry `KeyValueStates`. `query_length` lets the
        caller supply the true query length when cached states are in use.
        """
        normed_hidden_states = self.layer_norm(hidden_states)
        attention_output: AttentionOutput = self.enc_dec_attention(
            normed_hidden_states,
            mask=attention_mask,
            key_value_states=key_value_states,
            position_bias=position_bias,
            layer_head_mask=layer_head_mask,
            past_key_value=past_key_value,
            use_cache=use_cache,
            query_length=query_length,
            output_attentions=output_attentions,
        )
        layer_output = hidden_states + self.dropout(attention_output.hidden_states)
        return T5LayerCrossAttentionOutput(
            layer_output,
            attention_output.key_value_state,
            attention_output.position_bias,
            attention_output.attention_probs,
        )
# Cached attention tensors for one block. There are 2 entries for self-attention only,
# or 4 when cross-attention is also cached. T5Block.forward slices `[:2]` for its
# self-attention sublayer and `[2:]` for its cross-attention sublayer.
KeyValueStates = Union[
    Tuple[FloatT, FloatT],  # without cross attention
    Tuple[FloatT, FloatT, FloatT, FloatT],  # with cross attention
]
class T5BlockOutput(NamedTuple):
    """Result of `T5Block.forward`."""

    # Output of the block after the feed-forward sublayer.
    hidden_states: FloatT
    # Self-attention key/value state, with the cross-attention state appended when both exist.
    present_key_value_states: Optional[KeyValueStates]
    # Self-attention weights from the self-attention sublayer, if returned.
    self_attn_weights: Optional[FloatT]
    # Self-attention position bias; T5Stack feeds it to the next block.
    self_attn_position_bias: Optional[FloatT]
    # Cross-attention weights; None when no cross-attention ran.
    cross_attn_weights: Optional[FloatT] = None
    # Cross-attention position bias; None when no cross-attention ran.
    cross_attn_position_bias: Optional[FloatT] = None
class T5Block(TransformerModule, FromParams):
    """
    One T5 transformer block: self-attention, optional cross-attention, then feed-forward.

    Sublayers live in `self.layer` (an `nn.ModuleList`). Index 0 is self-attention.
    Index 1 is cross-attention when one is given. The last entry is the feed-forward
    sublayer. The block is a decoder block exactly when `cross_attention` is given.
    """

    def __init__(
        self,
        attention: Optional[T5LayerSelfAttention] = None,
        cross_attention: Optional[T5LayerCrossAttention] = None,
        ff: Optional[T5LayerFF] = None,
    ):
        super().__init__()
        self.layer = nn.ModuleList()
        self.layer.append(attention or T5LayerSelfAttention())
        if cross_attention is None:
            self.is_decoder = False
        else:
            self.layer.append(cross_attention)
            self.is_decoder = True
        self.layer.append(ff or T5LayerFF())

    @property
    def hidden_size(self) -> int:
        return self.layer[0].hidden_size

    @staticmethod
    def _clamp_inf_values(hidden_states: FloatT) -> FloatT:
        """
        Replace +/-inf entries with large finite values (to enable fp16 training).

        The input is returned unchanged when it contains no infinite value.
        """
        if torch.isinf(hidden_states).any():
            clamp_value = torch.finfo(hidden_states.dtype).max - 1000
            hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
        return hidden_states

    def forward(
        self,
        hidden_states: FloatT,
        attention_mask: Optional[torch.BoolTensor] = None,
        position_bias: Optional[FloatT] = None,
        encoder_hidden_states: Optional[FloatT] = None,
        encoder_attention_mask: Optional[torch.BoolTensor] = None,
        encoder_decoder_position_bias: Optional[FloatT] = None,
        layer_head_mask: Optional[torch.BoolTensor] = None,
        encoder_layer_head_mask: Optional[torch.BoolTensor] = None,
        past_key_value: Optional[KeyValueStates] = None,
        use_cache: bool = False,
        output_attentions: bool = False,
    ) -> T5BlockOutput:
        """
        Run self-attention, cross-attention (decoder blocks with `encoder_hidden_states`
        only), and the feed-forward sublayer.

        `past_key_value` must have 2 entries when `encoder_hidden_states` is None, else 4.
        Infinite values are clamped after every sublayer.
        """
        if past_key_value is not None:
            assert self.is_decoder, "Only decoder can use `past_key_values`"
            expected_num_past_key_values = 2 if encoder_hidden_states is None else 4
            # The message is only built if the assertion fails.
            assert len(past_key_value) == expected_num_past_key_values, (
                f"There should be {expected_num_past_key_values} past states. "
                "2 (key / value) for self attention. "
                + (
                    "2 (key / value) for cross attention. "
                    if expected_num_past_key_values == 4
                    else ""
                )
                + f"Got {len(past_key_value)} past key / value states"
            )

        self_attention_outputs: T5LayerSelfAttentionOutput = self.layer[0](
            hidden_states,
            attention_mask=attention_mask,
            position_bias=position_bias,
            layer_head_mask=layer_head_mask,
            past_key_value=None if past_key_value is None else past_key_value[:2],
            use_cache=use_cache,
            output_attentions=output_attentions,
        )
        hidden_states = self._clamp_inf_values(self_attention_outputs.hidden_states)
        present_key_value_state: Optional[KeyValueStates] = (
            self_attention_outputs.attn_key_value_state
        )

        do_cross_attention = self.is_decoder and encoder_hidden_states is not None
        cross_attention_outputs: Optional[T5LayerCrossAttentionOutput] = None
        if do_cross_attention:
            # The actual query length is unknown to cross attention when past key/value
            # states are used, so inject it from the self-attention state.
            if present_key_value_state is not None:
                query_length = present_key_value_state[0].shape[2]
            else:
                query_length = None
            cross_attention_outputs = self.layer[1](
                hidden_states,
                key_value_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                position_bias=encoder_decoder_position_bias,
                layer_head_mask=encoder_layer_head_mask,
                past_key_value=None if past_key_value is None else past_key_value[2:],
                query_length=query_length,
                use_cache=use_cache,
                output_attentions=output_attentions,
            )
            hidden_states = self._clamp_inf_values(cross_attention_outputs.hidden_states)
            # Combine self attn and cross attn key value states.
            if (
                present_key_value_state is not None
                and cross_attention_outputs.attn_key_value_state is not None
            ):
                present_key_value_state = (
                    present_key_value_state + cross_attention_outputs.attn_key_value_state
                )

        # Apply Feed Forward layer
        hidden_states = self._clamp_inf_values(self.layer[-1](hidden_states))

        return T5BlockOutput(
            hidden_states,
            present_key_value_state,
            self_attention_outputs.attn_weights,
            self_attention_outputs.attn_position_bias,
            cross_attn_weights=(
                None if cross_attention_outputs is None else cross_attention_outputs.attn_weights
            ),
            cross_attn_position_bias=(
                None
                if cross_attention_outputs is None
                else cross_attention_outputs.attn_position_bias
            ),
        )
class T5StackOutput(NamedTuple):
    """Result of `T5Stack.forward`."""

    # Output of the final layer norm (and dropout).
    last_hidden_state: FloatT
    # One `KeyValueStates` per block when `use_cache` is set, else None.
    past_key_values: Optional[List[KeyValueStates]] = None
    # Input to every block, plus the final output, when `output_all_hidden_states` is set.
    all_hidden_states: Optional[List[FloatT]] = None
    # Self-attention weights per block when `output_attentions` is set.
    attentions: Optional[List[FloatT]] = None
    # Cross-attention weights per block for decoder stacks when `output_attentions` is set.
    cross_attentions: Optional[List[FloatT]] = None
class T5Stack(TransformerModule, FromParams):
    """
    A sequence of `T5Block`s over a token embedding, followed by a final layer norm.

    All blocks must agree on `is_decoder`; the stack takes its own flag from them.

    # Parameters

    token_embeddings : `nn.Embedding`
        Used to embed `input_ids` when `inputs_embeds` is not given.
    blocks : `List[T5Block]`
        The blocks, applied in order.
    final_layer_norm : `T5LayerNorm`, optional
        Applied after the last block. Defaults to a `T5LayerNorm` of the blocks' size.
    dropout : `float`, optional (default = `0.1`)
        Dropout on the input embeddings and on the final output.
    """

    _pretrained_mapping = {"embed_tokens": "token_embeddings", "block": "blocks"}

    def __init__(
        self,
        token_embeddings: nn.Embedding,
        blocks: List[T5Block],
        final_layer_norm: Optional[T5LayerNorm] = None,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.is_decoder = blocks[0].is_decoder
        if not all(b.is_decoder == self.is_decoder for b in blocks):
            raise ConfigurationError("Found mismatched blocks in stack.")
        self.blocks = nn.ModuleList(blocks)
        self.token_embeddings = token_embeddings
        self.final_layer_norm = final_layer_norm or T5LayerNorm(hidden_size=self.hidden_size)
        self.dropout = nn.Dropout(dropout)

    @property
    def num_blocks(self) -> int:
        return len(self.blocks)

    @property
    def hidden_size(self) -> int:
        return self.blocks[0].hidden_size

    @staticmethod
    def get_head_mask(
        head_mask: Optional[torch.BoolTensor], num_hidden_layers: int
    ) -> Union[BoolT, List[None]]:
        """
        Normalize `head_mask` so that it can be indexed once per layer.

        A 1-D mask is broadcast to every layer. A 2-D mask supplies one row per layer.
        Both become 5-D. Without a mask, a list of `num_hidden_layers` `None`s is
        returned.
        """
        if head_mask is not None:
            # -> [num_hidden_layers x batch x num_heads x seq_length x seq_length]
            if head_mask.dim() == 1:
                head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
                head_mask = head_mask.expand(num_hidden_layers, -1, -1, -1, -1)
            elif head_mask.dim() == 2:
                # We can specify head_mask for each layer
                head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
            assert head_mask.dim() == 5, f"head_mask.dim != 5, instead {head_mask.dim()}"
        else:
            head_mask = [None] * num_hidden_layers
        return head_mask

    def resize_token_embeddings(
        self, new_size: int, *, init_fn: Callable = torch.nn.init.normal_
    ) -> None:
        """
        Replace `token_embeddings` with a new `nn.Embedding` that has `new_size` rows.

        Existing rows are copied over. When growing, `init_fn` is called on the new rows
        (in place). Note that a new module is assigned: other holders of the old
        embedding module are not updated by this method.
        """
        old_size, embedding_dim = tuple(self.token_embeddings.weight.shape)
        if old_size == new_size:
            return
        if old_size > new_size:
            logger.warning(
                "Shrinking vocabulary from size %d to size %d. This is probably not what you want?",
                old_size,
                new_size,
            )

        result = torch.nn.Embedding(
            new_size,
            embedding_dim,
            self.token_embeddings.padding_idx,
            self.token_embeddings.max_norm,
            self.token_embeddings.norm_type,
            self.token_embeddings.scale_grad_by_freq,
            self.token_embeddings.sparse,
            device=self.token_embeddings.weight.device,
            dtype=self.token_embeddings.weight.dtype,
        )
        copy_size = min(old_size, new_size)
        result.weight.data[:copy_size, ...] = self.token_embeddings.weight.data[:copy_size, ...]
        if new_size > old_size:
            init_fn(result.weight.data[copy_size:, ...])
        self.token_embeddings = result

    def forward(
        self,
        input_ids: Optional[torch.IntTensor] = None,
        attention_mask: Optional[torch.BoolTensor] = None,
        encoder_hidden_states: Optional[FloatT] = None,
        encoder_attention_mask: Optional[torch.BoolTensor] = None,
        inputs_embeds: Optional[FloatT] = None,
        head_mask: Optional[torch.BoolTensor] = None,
        encoder_head_mask: Optional[torch.BoolTensor] = None,
        past_key_values: Optional[List[KeyValueStates]] = None,
        use_cache: bool = False,
        output_attentions: bool = False,
        output_all_hidden_states: bool = False,
    ) -> T5StackOutput:
        """
        Run the embedded inputs through every block.

        Exactly one of `input_ids` and `inputs_embeds` must be given. `past_key_values`,
        when given, holds one `KeyValueStates` per block, and its first entry's cached
        length is added to the mask length. `use_cache` is only allowed on a decoder
        stack. The optional lists in the returned `T5StackOutput` are filled only when
        the corresponding flag is set.
        """
        if input_ids is not None and inputs_embeds is not None:
            err_msg_prefix = "decoder_" if self.is_decoder else ""
            raise ValueError(
                f"You cannot specify both {err_msg_prefix}inputs "
                f"and {err_msg_prefix}inputs_embeds at the same time"
            )
        elif input_ids is not None:
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            err_msg_prefix = "decoder_" if self.is_decoder else ""
            raise ValueError(
                f"You have to specify either {err_msg_prefix}inputs or {err_msg_prefix}inputs_embeds"
            )

        if inputs_embeds is None:
            assert (
                self.token_embeddings is not None
            ), "You have to initialize the model with valid token embeddings"
            inputs_embeds = self.token_embeddings(input_ids)

        batch_size, seq_length = input_shape

        # required mask seq length can be calculated via length of past
        mask_seq_length = (
            seq_length if past_key_values is None else past_key_values[0][0].shape[2] + seq_length
        )

        if use_cache:
            assert (
                self.is_decoder
            ), ":obj:`use_cache` can only be set to `True` if {} is used as a decoder".format(self)

        if attention_mask is None:
            attention_mask = torch.ones(
                batch_size, mask_seq_length, dtype=torch.bool, device=inputs_embeds.device
            )
        if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None:
            encoder_seq_length = encoder_hidden_states.shape[1]
            encoder_attention_mask = torch.ones(
                batch_size, encoder_seq_length, device=inputs_embeds.device, dtype=torch.bool
            )

        extended_attention_mask = get_extended_attention_mask(
            attention_mask, input_shape, inputs_embeds.dtype, is_decoder=self.is_decoder
        )

        # Prepare head mask if needed
        head_mask = self.get_head_mask(head_mask, self.num_blocks)
        encoder_head_mask = self.get_head_mask(encoder_head_mask, self.num_blocks)
        present_key_value_states: Optional[List[KeyValueStates]] = [] if use_cache else None
        all_hidden_states: Optional[List[FloatT]] = [] if output_all_hidden_states else None
        all_attentions: Optional[List[FloatT]] = [] if output_attentions else None
        all_cross_attentions: Optional[List[FloatT]] = (
            [] if (output_attentions and self.is_decoder) else None
        )
        position_bias: Optional[FloatT] = None
        encoder_decoder_position_bias: Optional[FloatT] = None

        hidden_states = self.dropout(inputs_embeds)

        for i, (layer_module, past_key_value) in enumerate(
            zip(self.blocks, past_key_values or [None] * self.num_blocks)
        ):
            layer_head_mask = head_mask[i]
            encoder_layer_head_mask = encoder_head_mask[i]
            if output_all_hidden_states:
                all_hidden_states.append(hidden_states)  # type: ignore[union-attr]

            layer_outputs: T5BlockOutput = layer_module(
                hidden_states,
                attention_mask=extended_attention_mask,
                position_bias=position_bias,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                encoder_decoder_position_bias=encoder_decoder_position_bias,
                layer_head_mask=layer_head_mask,
                encoder_layer_head_mask=encoder_layer_head_mask,
                past_key_value=past_key_value,
                use_cache=use_cache,
                output_attentions=output_attentions,
            )
            # If the blocks were wrapped with a `CheckpointWrapper`, the output
            # may just be a raw tuple, not the NamedTuple that we want.
            if not isinstance(layer_outputs, T5BlockOutput):
                layer_outputs = T5BlockOutput(*layer_outputs)
            hidden_states = layer_outputs.hidden_states

            # The position biases are shared between the layers: the first layer
            # computes them and every later layer reuses them.
            position_bias = layer_outputs.self_attn_position_bias
            if self.is_decoder and encoder_hidden_states is not None:
                encoder_decoder_position_bias = layer_outputs.cross_attn_position_bias
            if use_cache:
                present_key_value_states.append(layer_outputs.present_key_value_states)  # type: ignore
            if output_attentions:
                all_attentions.append(layer_outputs.self_attn_weights)  # type: ignore[union-attr]
                if self.is_decoder:
                    all_cross_attentions.append(layer_outputs.cross_attn_weights)  # type: ignore[union-attr]

        hidden_states = self.final_layer_norm(hidden_states)
        hidden_states = self.dropout(hidden_states)

        # Add last layer
        if output_all_hidden_states:
            all_hidden_states.append(hidden_states)  # type: ignore[union-attr]

        return T5StackOutput(
            last_hidden_state=hidden_states,
            past_key_values=present_key_value_states,
            all_hidden_states=all_hidden_states,
            attentions=all_attentions,
            cross_attentions=all_cross_attentions,
        )
class T5EncoderStack(T5Stack, FromParams):
    """
    A `T5Stack` made only of encoder blocks (blocks without cross-attention).
    """

    def __init__(
        self,
        token_embeddings: nn.Embedding,
        blocks: List[T5Block],
        final_layer_norm: Optional[T5LayerNorm] = None,
        dropout: float = 0.1,
    ):
        if any(b.is_decoder for b in blocks):
            raise ConfigurationError("Found a decoder block in an encoder stack. This won't work.")

        super().__init__(
            token_embeddings,
            blocks,
            final_layer_norm=final_layer_norm,
            dropout=dropout,
        )

    @classmethod
    def basic_encoder(
        cls,
        token_embeddings: nn.Embedding,
        num_blocks: int = 6,
        block_self_attention: Lazy[T5Attention] = Lazy(T5Attention),
        final_layer_norm: Optional[T5LayerNorm] = None,
        block_ff: Lazy[T5LayerFF] = Lazy(T5LayerFF),
        dropout: float = 0.1,
        ddp_accelerator: Optional[DdpAccelerator] = None,
        checkpoint_wrapper: Optional[CheckpointWrapper] = None,
    ) -> "T5EncoderStack":
        """
        Build an encoder stack of `num_blocks` blocks.

        Only the first block's self-attention gets a relative attention bias. Each block
        is optionally wrapped by `checkpoint_wrapper` and then by `ddp_accelerator`.

        `dropout` is used for the stack (embeddings and final output) and for the
        residual dropout of every block's self-attention sublayer. Previously that
        sublayer silently kept its 0.1 default. The feed-forward sublayer's dropout is
        configured through `block_ff`.
        """
        if ddp_accelerator is not None:
            logger.info("Initializing T5 encoder with DdpAccelerator %s", ddp_accelerator)
        blocks: List[T5Block] = []
        for i in range(num_blocks):
            block = T5Block(
                attention=T5LayerSelfAttention(
                    self_attention=block_self_attention.construct(
                        is_decoder=False, has_relative_attention_bias=(i == 0)
                    ),
                    dropout=dropout,
                ),
                cross_attention=None,
                ff=block_ff.construct(),
            )
            if checkpoint_wrapper is not None:
                block = checkpoint_wrapper.wrap_module(block)
            if ddp_accelerator is not None:
                block = ddp_accelerator.wrap_module(block)
            blocks.append(block)
        return cls(token_embeddings, blocks, final_layer_norm=final_layer_norm, dropout=dropout)
class T5DecoderStack(T5Stack, FromParams):
    """
    A `T5Stack` made only of decoder blocks (blocks with cross-attention).
    """

    def __init__(
        self,
        token_embeddings: nn.Embedding,
        blocks: List[T5Block],
        final_layer_norm: Optional[T5LayerNorm] = None,
        dropout: float = 0.1,
    ):
        if not all(b.is_decoder for b in blocks):
            raise ConfigurationError("Found an encoder block in a decoder stack. This won't work.")

        super().__init__(
            token_embeddings,
            blocks,
            final_layer_norm=final_layer_norm,
            dropout=dropout,
        )

    @classmethod
    def basic_decoder(
        cls,
        token_embeddings: nn.Embedding,
        num_blocks: int = 6,
        block_self_attention: Lazy[T5Attention] = Lazy(T5Attention),
        block_cross_attention: Lazy[T5Attention] = Lazy(T5Attention),
        final_layer_norm: Optional[T5LayerNorm] = None,
        block_ff: Lazy[T5LayerFF] = Lazy(T5LayerFF),
        dropout: float = 0.1,
        ddp_accelerator: Optional[DdpAccelerator] = None,
        checkpoint_wrapper: Optional[CheckpointWrapper] = None,
    ) -> "T5DecoderStack":
        """
        Build a decoder stack of `num_blocks` blocks.

        Only the first block's self-attention gets a relative attention bias. Cross-
        attention never does. Each block is optionally wrapped by `checkpoint_wrapper`
        and then by `ddp_accelerator`.

        `dropout` is used for the stack (embeddings and final output) and for the
        residual dropout of every block's self-attention and cross-attention sublayers.
        Previously those sublayers silently kept their 0.1 default. The feed-forward
        sublayer's dropout is configured through `block_ff`.
        """
        if ddp_accelerator is not None:
            logger.info("Initializing T5 decoder with DdpAccelerator %s", ddp_accelerator)
        blocks: List[T5Block] = []
        for i in range(num_blocks):
            block = T5Block(
                attention=T5LayerSelfAttention(
                    self_attention=block_self_attention.construct(
                        is_decoder=True, has_relative_attention_bias=(i == 0)
                    ),
                    dropout=dropout,
                ),
                cross_attention=T5LayerCrossAttention(
                    enc_dec_attention=block_cross_attention.construct(
                        is_decoder=True,
                        has_relative_attention_bias=False,
                    ),
                    dropout=dropout,
                ),
                ff=block_ff.construct(),
            )
            if checkpoint_wrapper is not None:
                block = checkpoint_wrapper.wrap_module(block)
            if ddp_accelerator is not None:
                block = ddp_accelerator.wrap_module(block)
            blocks.append(block)
        return cls(token_embeddings, blocks, final_layer_norm=final_layer_norm, dropout=dropout)
class T5Output(NamedTuple):
    """
    Defines the output from the `T5` model.
    """

    encoder_last_hidden_state: FloatT
    """
    Final hidden states from the encoder.

    Shape: `(batch_size, source_length, hidden_dim)`, where `source_length` is the
    length of the encoder's `input_ids`.
    """

    encoder_all_hidden_states: Optional[List[FloatT]] = None
    """
    All hidden states from the encoder.

    Shape (each): `(batch_size, source_length, hidden_dim)`
    """

    decoder_last_hidden_state: Optional[FloatT] = None
    """
    Final hidden states from the decoder. Only present when `labels` is given.

    Shape: `(batch_size, target_length, hidden_dim)`
    """

    decoder_all_hidden_states: Optional[List[FloatT]] = None
    """
    All hidden states from the decoder. Only present when `labels` is given
    and `output_all_hidden_states` is `True`.

    Shape (each): `(batch_size, target_length, hidden_dim)`
    """

    encoder_attentions: Optional[List[FloatT]] = None
    """
    Attention values from the encoder. Only present when `output_attentions` is `True`.
    """

    decoder_attentions: Optional[List[FloatT]] = None
    """
    Attention values from the decoder. Only present when `labels` is given
    and `output_attentions` is `True`.
    """

    cross_attentions: Optional[List[FloatT]] = None
    """
    Cross-attention values from the decoder. Only present when `labels` is given
    and `output_attentions` is `True`.
    """

    loss: Optional[FloatT] = None
    """
    The loss calculated with respect to `labels`.
    """

    logits: Optional[FloatT] = None
    """
    The logits that are used to calculate the loss with respect to `labels`.
    """

    predictions: Optional[IntT] = None
    """
    Predicted token IDs from beam search.

    Shape: `(batch_size, beam_size, max_decoding_steps)`.
    """

    predicted_log_probs: Optional[FloatT] = None
    """
    Log probabilities corresponding to `predictions`.

    Shape: `(batch_size, beam_size,)`.
    """
class T5(TransformerModule, Registrable):
_pretrained_mapping = {"shared": "token_embeddings"}
# Don't know why HF has this param in their state_dict. It's not used in their model.
_pretrained_ignore = [
r"^decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight$"
]
default_implementation = "default"
def __init__(
    self,
    token_embeddings: Optional[nn.Embedding] = None,
    encoder: Lazy[T5EncoderStack] = Lazy(T5EncoderStack.basic_encoder),
    decoder: Lazy[T5DecoderStack] = Lazy(T5DecoderStack.basic_decoder),
    decoder_start_token_id: int = 0,
    pad_token_id: int = 0,  # These are both 0 in t5-(small|base|large). Go figure.
    eos_token_id: int = 1,
    vocab_size: int = 32128,
    model_dim: int = 512,
    output_attentions: bool = False,
    output_all_hidden_states: bool = False,
    beam_search: Lazy[BeamSearch] = Lazy(BeamSearch, beam_size=3, max_steps=100),
    ddp_accelerator: Optional[DdpAccelerator] = None,
    checkpoint_wrapper: Optional[CheckpointWrapper] = None,
    tie_word_embeddings: bool = True,
):
    """
    Assemble the encoder, decoder, LM head, loss, and beam search.

    When `token_embeddings` is not given, a `vocab_size x model_dim` embedding is
    created and drawn from N(0, 1). The same embedding module is handed to both the
    encoder and the decoder. When `tie_word_embeddings` is true, the LM head reuses the
    embedding weight. Beam search stops at `eos_token_id`.
    """
    super().__init__()

    # Plain (non-module, non-random) configuration.
    self._tie_word_embeddings = tie_word_embeddings
    self.model_dim = model_dim
    self.decoder_start_token_id = decoder_start_token_id
    self.pad_token_id = pad_token_id
    self.eos_token_id = eos_token_id
    self.output_attentions = output_attentions
    self.output_all_hidden_states = output_all_hidden_states

    # Submodules are created in a fixed order so random draws happen in that order.
    if token_embeddings is None:
        token_embeddings = nn.Embedding(vocab_size, model_dim)
        token_embeddings.weight.data.normal_(mean=0.0, std=1.0)
    self.token_embeddings = token_embeddings

    shared_stack_kwargs = dict(
        token_embeddings=self.token_embeddings,
        ddp_accelerator=ddp_accelerator,
        checkpoint_wrapper=checkpoint_wrapper,
    )
    self.encoder: T5EncoderStack = encoder.construct(**shared_stack_kwargs)
    self.decoder: T5DecoderStack = decoder.construct(**shared_stack_kwargs)

    self.lm_head = nn.Linear(
        self.decoder.hidden_size, self.token_embeddings.num_embeddings, bias=False
    )
    if self._tie_word_embeddings:
        self.lm_head.weight = self.token_embeddings.weight

    self.loss_fct = CrossEntropyLoss(ignore_index=-100)
    self.beam_search = beam_search.construct(end_index=self.eos_token_id)
def resize_token_embeddings(
    self, new_size: int, *, init_fn: Callable = torch.nn.init.normal_
) -> None:
    """
    Resizes the token embeddings in the model.

    This takes care of the token embeddings for the encoder, the decoder, and the LM head.
    Embedding modules that were shared before the call (the model's `token_embeddings`,
    the encoder's, and the decoder's) are still one shared module afterwards. When word
    embeddings are tied, the new LM head is tied to the resized embedding matrix.

    new_size : `int`
        The new size of the token embeddings
    init_fn : `Callable`
        The function to use to initialize new embeddings. This function will be called with a
        single argument, the tensor to initialize, and it is expected to initialize the tensor
        in place. Many of the functions from `torch.nn.init` fit.
    """
    old_encoder_embeddings = self.encoder.token_embeddings
    old_decoder_embeddings = self.decoder.token_embeddings

    # `T5Stack.resize_token_embeddings` assigns a brand-new `nn.Embedding`. Resizing the
    # decoder separately afterwards would therefore create a second, independently
    # initialized matrix. Re-share the encoder's new module instead.
    self.encoder.resize_token_embeddings(new_size, init_fn=init_fn)
    if old_decoder_embeddings is old_encoder_embeddings:
        self.decoder.token_embeddings = self.encoder.token_embeddings
    else:
        self.decoder.resize_token_embeddings(new_size, init_fn=init_fn)

    # Keep the model-level reference pointing at the current shared module.
    if self.token_embeddings is old_encoder_embeddings:
        self.token_embeddings = self.encoder.token_embeddings
    elif self.token_embeddings is old_decoder_embeddings:
        self.token_embeddings = self.decoder.token_embeddings

    # resize lm head
    old_lm_head = self.lm_head
    old_size = old_lm_head.out_features
    if old_size == new_size:
        return

    has_bias = old_lm_head.bias is not None  # never test a tensor for truthiness
    new_lm_head = torch.nn.Linear(
        old_lm_head.in_features,
        new_size,
        bias=has_bias,
        device=old_lm_head.weight.device,
        dtype=old_lm_head.weight.dtype,
    )
    copy_size = min(old_size, new_size)
    if self._tie_word_embeddings:
        # The embedding matrix was already resized (and new rows initialized) above.
        new_lm_head.weight = self.token_embeddings.weight
    else:
        new_lm_head.weight.data[:copy_size, ...] = old_lm_head.weight.data[:copy_size, ...]
        if new_size > old_size:
            init_fn(new_lm_head.weight.data[copy_size:, ...])
    if has_bias:
        new_lm_head.bias.data[:copy_size] = old_lm_head.bias.data[:copy_size]
        if new_size > old_size:
            init_fn(new_lm_head.bias.data[copy_size:])

    self.lm_head = new_lm_head
def _post_load_state_dict(
    self, missing_keys: List[str], unexpected_keys: List[str]
) -> Tuple[List[str], List[str]]:
    """
    Drop state-dict keys that are legitimately absent from pretrained checkpoints.

    The encoder and decoder are built from the model's own `token_embeddings` module.
    A tied LM head reuses that module's weight. `missing_keys` is edited in place
    (first occurrence of each ignorable key removed) and returned alongside
    `unexpected_keys`.
    """
    ignorable = ["encoder.token_embeddings.weight", "decoder.token_embeddings.weight"]
    if self._tie_word_embeddings:
        ignorable.append("lm_head.weight")
    for key in ignorable:
        try:
            missing_keys.remove(key)
        except ValueError:
            # Key was not reported missing; nothing to drop.
            pass
    return missing_keys, unexpected_keys
@classmethod
def _from_config(cls, config: "PretrainedConfig", **kwargs):
    """
    Build a `T5` from a HuggingFace `PretrainedConfig`.

    Attention, layer-norm, and feed-forward sizes, dropout, layer counts, and special
    token IDs are all read from `config`. `tie_word_embeddings` may be overridden through
    `kwargs`. Remaining `kwargs` are passed to the constructor unchanged.
    """
    tie_word_embeddings = kwargs.pop("tie_word_embeddings", config.tie_word_embeddings)
    hidden_size = config.d_model
    dropout_rate = config.dropout_rate

    attention_kwargs = {
        "hidden_size": hidden_size,
        "key_value_proj_dim": config.d_kv,
        "num_heads": config.num_heads,
        "relative_attention_num_buckets": config.relative_attention_num_buckets,
        "dropout": dropout_rate,
    }
    layer_norm_kwargs = {
        "hidden_size": hidden_size,
        "eps": config.layer_norm_epsilon,
    }

    def attention_factory() -> Lazy:
        # A fresh Lazy per use site, all sharing the same keyword arguments.
        return Lazy(T5Attention, constructor_extras=attention_kwargs)

    ff_params = Params(
        {
            "ff_proj": {
                "type": config.feed_forward_proj,
                "hidden_size": hidden_size,
                "ff_size": config.d_ff,
                "dropout": dropout_rate,
            },
            "layer_norm": layer_norm_kwargs,
            "dropout": dropout_rate,
        }
    )
    block_ff = Lazy(T5LayerFF, params=ff_params)

    encoder_lazy = Lazy(
        T5EncoderStack.basic_encoder,
        constructor_extras={
            "num_blocks": config.num_layers,
            "block_self_attention": attention_factory(),
            "final_layer_norm": T5LayerNorm(**layer_norm_kwargs),
            "block_ff": block_ff,
            "dropout": dropout_rate,
        },
    )
    decoder_lazy = Lazy(
        T5DecoderStack.basic_decoder,
        constructor_extras={
            "num_blocks": config.num_decoder_layers,
            "block_self_attention": attention_factory(),
            "block_cross_attention": attention_factory(),
            "final_layer_norm": T5LayerNorm(**layer_norm_kwargs),
            "block_ff": block_ff,
            "dropout": dropout_rate,
        },
    )

    return cls(
        encoder=encoder_lazy,
        decoder=decoder_lazy,
        decoder_start_token_id=config.decoder_start_token_id,
        pad_token_id=config.pad_token_id,
        eos_token_id=config.eos_token_id,
        vocab_size=config.vocab_size,
        model_dim=hidden_size,
        tie_word_embeddings=tie_word_embeddings,
        **kwargs,
    )
def _shift_right(self, input_ids, start_value: int):
    """
    Shift `input_ids` one step to the right along the last dimension.

    Position 0 becomes `start_value` and the final token is dropped. The result is a
    new tensor with the same shape, dtype, and device as `input_ids`.
    """
    shifted = input_ids.new_zeros(input_ids.shape)
    shifted[..., 0] = start_value
    shifted[..., 1:] = input_ids[..., :-1]
    return shifted
def _get_lm_logits(self, decoder_last_hidden_state: FloatT) -> FloatT:
    """
    Project decoder hidden states onto the vocabulary.

    With tied word embeddings, the hidden states are first rescaled by
    `model_dim ** -0.5`, because the LM head shares the input embedding matrix (as in
    HuggingFace T5). Untied models (`tie_word_embeddings=False`, e.g. T5 v1.1 configs)
    are projected without rescaling. Previously they were always rescaled.
    """
    # Shape: (batch_size, target_length, model_dim)
    sequence_output = decoder_last_hidden_state
    if self._tie_word_embeddings:
        # Rescale output before projecting on vocab
        sequence_output = sequence_output * (self.model_dim**-0.5)
    # Shape: (batch_size, target_length, vocab_size)
    logits = self.lm_head(sequence_output)
    return logits
def forward(
self,
input_ids: IntT,
attention_mask: Optional[BoolT] = None,
labels: Optional[IntT] = None,
decoder_attention_mask: Optional[BoolT] = None,
) -> T5Output:
"""
Run forward pass of the model.
"""
if attention_mask is None:
attention_mask = ~(input_ids == self.pad_token_id)
# Encode inputs.
encoder_outputs: T5StackOutput = self.encoder(
input_ids=input_ids,
attention_mask=attention_mask,
output_attentions=self.output_attentions,
output_all_hidden_states=self.output_all_hidden_states,
)
logits: Optional[FloatT] = None
loss: Optional[FloatT] = None
decoder_outputs: Optional[T5StackOutput] = None
predictions: Optional[IntT] = None
predicted_log_probs: Optional[FloatT] = None
if labels is not None:
# Calculate loss against targets.