
Commit a272ed0

donebydan (Daniel Jones) authored and committed
Change BartLearnedPositionalEmbedding's forward method signature to support Opacus training (huggingface#18486)
Change BartLearnedPositionalEmbedding's forward method signature to support Opacus training (huggingface#18486)

* changing BartLearnedPositionalEmbedding forward signature and references to it
* removing debugging dead code (thanks style checker)
* blackened modeling_bart file
* removing copy inconsistencies via make fix-copies
* changing references to copied signatures in Bart variants
* make fix-copies once more
* using expand over repeat (thanks @michaelbenayoun)
* expand instead of repeat for all model copies

Co-authored-by: Daniel Jones <[email protected]>
1 parent ad4215f commit a272ed0
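Context for the change: Opacus computes per-sample gradients via hooks on `nn.Embedding` modules, so the positional embedding must receive the batched `input_ids` tensor rather than a `torch.Size`. The commit also builds the position indices with an explicit `.expand(bsz, -1)`. Below is a small demo (not part of the commit) of why `expand` was preferred over `repeat`: `expand` returns a broadcasted view that shares storage, while `repeat` allocates a copy on every forward pass.

```python
import torch

bsz, seq_len = 4, 16
positions = torch.arange(seq_len)

expanded = positions.expand(bsz, -1)  # stride-0 view along dim 0: no copy
repeated = positions.repeat(bsz, 1)   # materializes a new [bsz, seq_len] tensor

assert expanded.shape == repeated.shape == (bsz, seq_len)
assert expanded.data_ptr() == positions.data_ptr()   # shares storage with positions
assert repeated.data_ptr() != positions.data_ptr()   # fresh allocation
```

Both produce the same `[bsz, seq_len]` indices; `expand` simply avoids the extra allocation.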

File tree: 5 files changed, +70 -46 lines changed


src/transformers/models/bart/modeling_bart.py

Lines changed: 15 additions & 11 deletions
```diff
@@ -128,12 +128,14 @@ def __init__(self, num_embeddings: int, embedding_dim: int):
         self.offset = 2
         super().__init__(num_embeddings + self.offset, embedding_dim)
 
-    def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0):
-        """`input_ids_shape` is expected to be [bsz x seqlen]."""
-        bsz, seq_len = input_ids_shape[:2]
+    def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0):
+        """`input_ids' shape is expected to be [bsz x seqlen]."""
+
+        bsz, seq_len = input_ids.shape[:2]
         positions = torch.arange(
             past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
-        )
+        ).expand(bsz, -1)
+
         return super().forward(positions + self.offset)
 
 
@@ -788,17 +790,17 @@ def forward(
         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
         elif input_ids is not None:
-            input_shape = input_ids.size()
-            input_ids = input_ids.view(-1, input_shape[-1])
+            input = input_ids
+            input_ids = input_ids.view(-1, input_ids.shape[-1])
         elif inputs_embeds is not None:
-            input_shape = inputs_embeds.size()[:-1]
+            input = inputs_embeds[:, :, -1]
         else:
             raise ValueError("You have to specify either input_ids or inputs_embeds")
 
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
 
-        embed_pos = self.embed_positions(input_shape)
+        embed_pos = self.embed_positions(input)
 
         hidden_states = inputs_embeds + embed_pos
         hidden_states = self.layernorm_embedding(hidden_states)
@@ -1015,18 +1017,20 @@ def forward(
         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
         elif input_ids is not None:
-            input_shape = input_ids.size()
+            input = input_ids
+            input_shape = input.shape
             input_ids = input_ids.view(-1, input_shape[-1])
         elif inputs_embeds is not None:
             input_shape = inputs_embeds.size()[:-1]
+            input = inputs_embeds[:, :, -1]
         else:
             raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
 
         # past_key_values_length
         past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
 
         if inputs_embeds is None:
-            inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
+            inputs_embeds = self.embed_tokens(input) * self.embed_scale
 
         attention_mask = self._prepare_decoder_attention_mask(
             attention_mask, input_shape, inputs_embeds, past_key_values_length
@@ -1038,7 +1042,7 @@ def forward(
         encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
 
         # embed positions
-        positions = self.embed_positions(input_shape, past_key_values_length)
+        positions = self.embed_positions(input, past_key_values_length)
 
         hidden_states = inputs_embeds + positions
         hidden_states = self.layernorm_embedding(hidden_states)
```
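The hunks above are the source of truth; the identical change is propagated to the Bart copies below (mbart, mvp, plbart, trocr) by `make fix-copies`. For reference, here is a minimal standalone sketch of the resulting module (illustrative only, not the `transformers` class itself). Passing the `input_ids` tensor, rather than its `torch.Size`, means hook-based tooling such as Opacus's `GradSampleModule` sees a real batched input to the embedding.

```python
import torch
from torch import nn

class LearnedPositionalEmbedding(nn.Embedding):
    """Illustrative sketch of the pattern this diff applies."""

    def __init__(self, num_embeddings: int, embedding_dim: int):
        # Bart reserves the first two positions, hence the offset of 2.
        self.offset = 2
        super().__init__(num_embeddings + self.offset, embedding_dim)

    def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0) -> torch.Tensor:
        bsz, seq_len = input_ids.shape[:2]
        positions = torch.arange(
            past_key_values_length,
            past_key_values_length + seq_len,
            dtype=torch.long,
            device=self.weight.device,
        ).expand(bsz, -1)  # [bsz, seq_len] view: one row of indices per sample
        return super().forward(positions + self.offset)

# Quick shape check: a batch of 4 sequences of length 16, embedding dim 8.
emb = LearnedPositionalEmbedding(num_embeddings=1024, embedding_dim=8)
out = emb(torch.zeros(4, 16, dtype=torch.long))
assert out.shape == (4, 16, 8)
```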

src/transformers/models/mbart/modeling_mbart.py

Lines changed: 14 additions & 9 deletions
```diff
@@ -134,12 +134,14 @@ def __init__(self, num_embeddings: int, embedding_dim: int):
         self.offset = 2
         super().__init__(num_embeddings + self.offset, embedding_dim)
 
-    def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0):
-        """`input_ids_shape` is expected to be [bsz x seqlen]."""
-        bsz, seq_len = input_ids_shape[:2]
+    def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0):
+        """`input_ids' shape is expected to be [bsz x seqlen]."""
+
+        bsz, seq_len = input_ids.shape[:2]
         positions = torch.arange(
             past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
-        )
+        ).expand(bsz, -1)
+
         return super().forward(positions + self.offset)
 
 
@@ -783,17 +785,18 @@ def forward(
         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
         elif input_ids is not None:
-            input_shape = input_ids.size()
+            input = input_ids
+            input_shape = input.shape
             input_ids = input_ids.view(-1, input_shape[-1])
         elif inputs_embeds is not None:
-            input_shape = inputs_embeds.size()[:-1]
+            input = inputs_embeds[:, :, -1]
         else:
             raise ValueError("You have to specify either input_ids or inputs_embeds")
 
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
 
-        embed_pos = self.embed_positions(input_shape)
+        embed_pos = self.embed_positions(input)
 
         hidden_states = inputs_embeds + embed_pos
         hidden_states = self.layernorm_embedding(hidden_states)
@@ -1013,10 +1016,12 @@ def forward(
         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
         elif input_ids is not None:
-            input_shape = input_ids.size()
+            input = input_ids
+            input_shape = input.size()
             input_ids = input_ids.view(-1, input_shape[-1])
         elif inputs_embeds is not None:
             input_shape = inputs_embeds.size()[:-1]
+            input = inputs_embeds[:, :, -1]
         else:
             raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
 
@@ -1036,7 +1041,7 @@ def forward(
         encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
 
         # embed positions
-        positions = self.embed_positions(input_shape, past_key_values_length)
+        positions = self.embed_positions(input, past_key_values_length)
 
         hidden_states = inputs_embeds + positions
         hidden_states = self.layernorm_embedding(hidden_states)
```

src/transformers/models/mvp/modeling_mvp.py

Lines changed: 14 additions & 8 deletions
```diff
@@ -134,12 +134,14 @@ def __init__(self, num_embeddings: int, embedding_dim: int):
         self.offset = 2
         super().__init__(num_embeddings + self.offset, embedding_dim)
 
-    def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0):
-        """`input_ids_shape` is expected to be [bsz x seqlen]."""
-        bsz, seq_len = input_ids_shape[:2]
+    def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0):
+        """`input_ids' shape is expected to be [bsz x seqlen]."""
+
+        bsz, seq_len = input_ids.shape[:2]
         positions = torch.arange(
             past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
-        )
+        ).expand(bsz, -1)
+
         return super().forward(positions + self.offset)
 
 
@@ -895,17 +897,19 @@ def forward(
         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
         elif input_ids is not None:
-            input_shape = input_ids.size()
+            input = input_ids
+            input_shape = input.shape
             input_ids = input_ids.view(-1, input_shape[-1])
         elif inputs_embeds is not None:
             input_shape = inputs_embeds.size()[:-1]
+            input = inputs_embeds[:, :, -1]
         else:
             raise ValueError("You have to specify either input_ids or inputs_embeds")
 
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
 
-        embed_pos = self.embed_positions(input_shape)
+        embed_pos = self.embed_positions(input)
 
         hidden_states = inputs_embeds + embed_pos
         hidden_states = self.layernorm_embedding(hidden_states)
@@ -1144,10 +1148,12 @@ def forward(
         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
         elif input_ids is not None:
-            input_shape = input_ids.size()
+            input = input_ids
+            input_shape = input_ids.shape
             input_ids = input_ids.view(-1, input_shape[-1])
         elif inputs_embeds is not None:
             input_shape = inputs_embeds.size()[:-1]
+            input = inputs_embeds[:, :, -1]
         else:
             raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
 
@@ -1167,7 +1173,7 @@ def forward(
         encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
 
         # embed positions
-        positions = self.embed_positions(input_shape, past_key_values_length)
+        positions = self.embed_positions(input, past_key_values_length)
 
         hidden_states = inputs_embeds + positions
         hidden_states = self.layernorm_embedding(hidden_states)
```

src/transformers/models/plbart/modeling_plbart.py

Lines changed: 15 additions & 11 deletions
```diff
@@ -131,12 +131,14 @@ def __init__(self, num_embeddings: int, embedding_dim: int):
         self.offset = 2
         super().__init__(num_embeddings + self.offset, embedding_dim)
 
-    def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0):
-        """`input_ids_shape` is expected to be [bsz x seqlen]."""
-        bsz, seq_len = input_ids_shape[:2]
+    def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0):
+        """`input_ids' shape is expected to be [bsz x seqlen]."""
+
+        bsz, seq_len = input_ids.shape[:2]
         positions = torch.arange(
             past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
-        )
+        ).expand(bsz, -1)
+
         return super().forward(positions + self.offset)
 
 
@@ -759,17 +761,17 @@ def forward(
         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
         elif input_ids is not None:
-            input_shape = input_ids.size()
-            input_ids = input_ids.view(-1, input_shape[-1])
+            input = input_ids
+            input_ids = input_ids.view(-1, input_ids.shape[-1])
         elif inputs_embeds is not None:
-            input_shape = inputs_embeds.size()[:-1]
+            input = inputs_embeds[:, :, -1]
         else:
             raise ValueError("You have to specify either input_ids or inputs_embeds")
 
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
 
-        embed_pos = self.embed_positions(input_shape)
+        embed_pos = self.embed_positions(input)
 
         hidden_states = inputs_embeds + embed_pos
         hidden_states = self.layernorm_embedding(hidden_states)
@@ -987,18 +989,20 @@ def forward(
         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
         elif input_ids is not None:
-            input_shape = input_ids.size()
+            input = input_ids
+            input_shape = input.shape
             input_ids = input_ids.view(-1, input_shape[-1])
         elif inputs_embeds is not None:
             input_shape = inputs_embeds.size()[:-1]
+            input = inputs_embeds[:, :, -1]
         else:
             raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
 
         # past_key_values_length
         past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
 
         if inputs_embeds is None:
-            inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
+            inputs_embeds = self.embed_tokens(input) * self.embed_scale
 
         attention_mask = self._prepare_decoder_attention_mask(
             attention_mask, input_shape, inputs_embeds, past_key_values_length
@@ -1010,7 +1014,7 @@ def forward(
         encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
 
         # embed positions
-        positions = self.embed_positions(input_shape, past_key_values_length)
+        positions = self.embed_positions(input, past_key_values_length)
 
         hidden_states = inputs_embeds + positions
         hidden_states = self.layernorm_embedding(hidden_states)
```

src/transformers/models/trocr/modeling_trocr.py

Lines changed: 12 additions & 7 deletions
```diff
@@ -87,12 +87,14 @@ def __init__(self, num_embeddings: int, embedding_dim: int):
         self.offset = 2
         super().__init__(num_embeddings + self.offset, embedding_dim)
 
-    def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0):
-        """`input_ids_shape` is expected to be [bsz x seqlen]."""
-        bsz, seq_len = input_ids_shape[:2]
+    def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0):
+        """`input_ids' shape is expected to be [bsz x seqlen]."""
+
+        bsz, seq_len = input_ids.shape[:2]
         positions = torch.arange(
             past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
-        )
+        ).expand(bsz, -1)
+
         return super().forward(positions + self.offset)
 
 
@@ -626,10 +628,11 @@ def forward(
         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
         elif input_ids is not None:
-            input_shape = input_ids.size()
-            input_ids = input_ids.view(-1, input_shape[-1])
+            input = input_ids
+            input_ids = input_ids.view(-1, input.shape[-1])
         elif inputs_embeds is not None:
             input_shape = inputs_embeds.size()[:-1]
+            input = inputs_embeds[:, :, -1]
         else:
             raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
 
@@ -640,7 +643,7 @@ def forward(
             inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
 
         if self.config.use_learned_position_embeddings:
-            embed_pos = self.embed_positions(input_shape, past_key_values_length=past_key_values_length)
+            embed_pos = self.embed_positions(input, past_key_values_length=past_key_values_length)
         else:
             embed_pos = self.embed_positions(input_ids, past_key_values_length=past_key_values_length)
 
@@ -651,6 +654,8 @@ def forward(
 
         hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
 
+        input_shape = input.shape
+
         attention_mask = self._prepare_decoder_attention_mask(
             attention_mask, input_shape, inputs_embeds, past_key_values_length
         )
```
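With all copies updated, a Bart checkpoint can be wrapped by Opacus without the positional embedding blocking per-sample gradient computation. Below is a hedged sketch of that workflow, assuming `opacus>=1.0` (`PrivacyEngine.make_private` is the Opacus 1.x entry point); the checkpoint name, random data, and hyperparameter values are placeholders, and end-to-end validator compatibility for any particular checkpoint is not claimed by the commit.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from opacus import PrivacyEngine
from transformers import BartForConditionalGeneration

# Placeholder setup: checkpoint, data, and hyperparameters are illustrative only.
model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")
model.train()
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

# Random token ids standing in for a real tokenized dataset: [num_samples, seq_len].
ids = torch.randint(0, model.config.vocab_size, (32, 16))
loader = DataLoader(TensorDataset(ids, ids), batch_size=8)

privacy_engine = PrivacyEngine()
model, optimizer, loader = privacy_engine.make_private(
    module=model,
    optimizer=optimizer,
    data_loader=loader,
    noise_multiplier=1.0,  # scale of DP noise added to clipped gradients (placeholder)
    max_grad_norm=1.0,     # per-sample gradient clipping bound (placeholder)
)

for input_ids, labels in loader:
    optimizer.zero_grad()
    loss = model(input_ids=input_ids, labels=labels).loss
    loss.backward()  # Opacus computes and clips per-sample gradients here
    optimizer.step()
```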
