Commit 6dc0a84

Fix weight tying in TF-ESM (huggingface#22839)
Fix weight tying in ESM
1 parent: 3b61d28
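
In short, the change makes the TF-ESM masked-LM head share its output projection with the input word embeddings when config.tie_word_embeddings is set: the word-embedding layer is built eagerly so its weight exists at construction time, the LM head's decoder becomes that weight variable, and the vocabulary projection is computed as a transposed matmul plus a separate bias. The snippet below is a minimal, self-contained sketch of that tied-head pattern, not the actual transformers code; the TiedLMHead class and the sizes are invented for illustration.

import tensorflow as tf


class TiedLMHead(tf.keras.layers.Layer):
    """Sketch of an LM head whose decoder is a shared word-embedding matrix."""

    def __init__(self, embedding_matrix, vocab_size, **kwargs):
        super().__init__(**kwargs)
        # embedding_matrix is assumed to be a tf.Variable of shape (vocab_size, hidden_size)
        self.decoder = embedding_matrix
        # Keep the bias as its own weight, mirroring the commit's separate-bias approach
        self.bias = self.add_weight(name="bias", shape=(vocab_size,), initializer="zeros", trainable=True)

    def call(self, hidden_states):
        # (batch, seq, hidden) x (vocab, hidden)^T -> (batch, seq, vocab)
        return tf.matmul(hidden_states, self.decoder, transpose_b=True) + self.bias


# Illustrative usage with made-up sizes
embeddings = tf.keras.layers.Embedding(input_dim=33, output_dim=64)
embeddings.build((None, None))  # build first so the weight exists before tying, as the commit does
lm_head = TiedLMHead(embeddings.weights[0], vocab_size=33)

hidden = tf.random.normal((2, 10, 64))
logits = lm_head(hidden)  # shape (2, 10, 33); gradients flow into the shared embedding matrix

Because the decoder is then the embedding variable itself rather than a Dense layer, get_output_embeddings() can return the variable directly, which is what the new test added in this commit checks for TFEsmForMaskedLM.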

File tree

2 files changed (+35, -8 lines):

  src/transformers/models/esm/modeling_tf_esm.py
  tests/models/esm/test_modeling_tf_esm.py


src/transformers/models/esm/modeling_tf_esm.py

Lines changed: 17 additions & 8 deletions
@@ -14,6 +14,7 @@
 # limitations under the License.
 """ PyTorch ESM model."""

+import os
 from typing import Optional, Tuple, Union

 import numpy as np
@@ -1102,6 +1103,11 @@ def __init__(self, config):

         self.esm = TFEsmMainLayer(config, add_pooling_layer=False, name="esm")
         self.lm_head = TFEsmLMHead(config, name="lm_head")
+        if config.tie_word_embeddings:
+            # Ensure word embeddings are built so that we actually have something to tie
+            with tf.name_scope(os.path.join(self._name_scope(), "esm", "embeddings", "word_embeddings")):
+                self.esm.embeddings.word_embeddings.build((None, None))
+            self.lm_head.decoder = self.esm.embeddings.word_embeddings.weights[0]

     def get_output_embeddings(self):
         return self.lm_head.decoder
@@ -1211,18 +1217,22 @@ def __init__(self, config, name=None):

         self.layer_norm = LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm")

-        self.decoder = Dense(
-            config.vocab_size,
-            use_bias=False,
-            kernel_initializer=get_initializer(config.initializer_range),
-            name="decoder",
-        )
+        self.decoder = None
         self.config = config

     def build(self, input_shape):
         super().build(input_shape)
         # Separate bias to match the PT model and allow weight cross-loading to work
         # Put it in the build so it gets the right name when adding it as a weight
+        if not self.config.tie_word_embeddings:
+            if self.decoder is not None:
+                raise ValueError("Expected decoder not to be initialized before build when not tying weights!")
+            self.decoder = self.add_weight(
+                "decoder.weight",
+                shape=(self.config.hidden_size, self.config.vocab_size),
+                initializer=get_initializer(self.config.initializer_range),
+                trainable=True,
+            )
         self.bias = self.add_weight("bias", shape=(self.config.vocab_size,), initializer="zeros", trainable=True)

     def get_bias(self):
@@ -1234,8 +1244,7 @@ def call(self, features):
         x = self.layer_norm(x)

         # project back to size of vocabulary with bias
-        x = self.decoder(x)
-        x = x + self.bias
+        x = tf.matmul(x, self.decoder, transpose_b=True) + self.bias
         return x

tests/models/esm/test_modeling_tf_esm.py

Lines changed: 18 additions & 0 deletions
@@ -262,6 +262,24 @@ def test_resize_token_embeddings(self):
     def test_save_load_after_resize_token_embeddings(self):
         pass

+    def test_model_common_attributes(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_model_classes:
+            model = model_class(config)
+            assert isinstance(model.get_input_embeddings(), tf.keras.layers.Layer)
+            if model_class is TFEsmForMaskedLM:
+                # Output embedding test differs from the main test because they're a matrix, not a layer
+                name = model.get_bias()
+                assert isinstance(name, dict)
+                for k, v in name.items():
+                    assert isinstance(v, tf.Variable)
+            else:
+                x = model.get_output_embeddings()
+                assert x is None
+                name = model.get_bias()
+                assert name is None
+

 @require_tf
 class TFEsmModelIntegrationTest(unittest.TestCase):
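
For completeness, a quick checkpoint-free sanity check one could run after this change (not part of the commit; the config values are made up, and the pad/mask token ids follow the usual ESM alphabet convention):

from transformers import EsmConfig, TFEsmForMaskedLM

# Tiny illustrative config; tie_word_embeddings=True exercises the new tying path
config = EsmConfig(
    vocab_size=33,
    mask_token_id=32,
    pad_token_id=1,
    hidden_size=64,
    num_hidden_layers=2,
    num_attention_heads=4,
    intermediate_size=128,
    tie_word_embeddings=True,
)
model = TFEsmForMaskedLM(config)

# With tying enabled, the output embeddings are the word-embedding variable itself
assert model.get_output_embeddings() is model.esm.embeddings.word_embeddings.weights[0]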
