 # limitations under the License.
 """ PyTorch ESM model."""
 
+import os
 from typing import Optional, Tuple, Union
 
 import numpy as np
@@ -1102,6 +1103,11 @@ def __init__(self, config):
 
         self.esm = TFEsmMainLayer(config, add_pooling_layer=False, name="esm")
         self.lm_head = TFEsmLMHead(config, name="lm_head")
+        if config.tie_word_embeddings:
+            # Ensure word embeddings are built so that we actually have something to tie
+            with tf.name_scope(os.path.join(self._name_scope(), "esm", "embeddings", "word_embeddings")):
+                self.esm.embeddings.word_embeddings.build((None, None))
+            self.lm_head.decoder = self.esm.embeddings.word_embeddings.weights[0]
 
     def get_output_embeddings(self):
         return self.lm_head.decoder
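The hunk above ties the LM head's output projection to the input embeddings: the word-embedding layer is built explicitly (inside the name scope it would normally be created under, so checkpoint weight names stay stable), and its (vocab_size, hidden_size) matrix is then reused as the decoder. A minimal sketch of the same pattern, with hypothetical sizes rather than the real ESM classes:

import tensorflow as tf

vocab_size, hidden_size = 33, 64  # hypothetical sizes for illustration

embeddings = tf.keras.layers.Embedding(vocab_size, hidden_size)
embeddings.build((None, None))  # materialize the (vocab_size, hidden_size) weight now

decoder = embeddings.weights[0]  # the live variable, not a copy
hidden_states = tf.random.normal((2, 10, hidden_size))

# Project hidden states back onto the vocabulary with the transposed embedding matrix
logits = tf.matmul(hidden_states, decoder, transpose_b=True)
print(logits.shape)  # (2, 10, 33)

Because weights[0] is the variable itself, any update to the embeddings also updates the decoder, which is exactly what tie_word_embeddings promises.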
@@ -1211,18 +1217,22 @@ def __init__(self, config, name=None):
 
         self.layer_norm = LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm")
 
-        self.decoder = Dense(
-            config.vocab_size,
-            use_bias=False,
-            kernel_initializer=get_initializer(config.initializer_range),
-            name="decoder",
-        )
+        self.decoder = None
         self.config = config
 
     def build(self, input_shape):
         super().build(input_shape)
         # Separate bias to match the PT model and allow weight cross-loading to work
         # Put it in the build so it gets the right name when adding it as a weight
+        if not self.config.tie_word_embeddings:
+            if self.decoder is not None:
+                raise ValueError("Expected decoder not to be initialized before build when not tying weights!")
+            self.decoder = self.add_weight(
+                "decoder.weight",
+                shape=(self.config.vocab_size, self.config.hidden_size),
+                initializer=get_initializer(self.config.initializer_range),
+                trainable=True,
+            )
         self.bias = self.add_weight("bias", shape=(self.config.vocab_size,), initializer="zeros", trainable=True)
 
     def get_bias(self):
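When the weights are not tied, the decoder is now a raw variable created in build() instead of a Dense layer. Creating it in build(), like the existing bias, registers it under the layer's own name scope, which name-based PyTorch-to-TF weight cross-loading depends on. A self-contained sketch of that pattern, with hypothetical names and sizes:

import tensorflow as tf

class TinyHead(tf.keras.layers.Layer):
    """Hypothetical head showing build()-time weight creation."""

    def __init__(self, hidden_size, vocab_size, **kwargs):
        super().__init__(**kwargs)
        self.hidden_size = hidden_size
        self.vocab_size = vocab_size

    def build(self, input_shape):
        super().build(input_shape)
        # (vocab_size, hidden_size) so call() can use transpose_b, matching the tied path
        self.decoder = self.add_weight(
            "decoder.weight", shape=(self.vocab_size, self.hidden_size), trainable=True
        )
        self.bias = self.add_weight("bias", shape=(self.vocab_size,), initializer="zeros", trainable=True)

    def call(self, x):
        return tf.matmul(x, self.decoder, transpose_b=True) + self.bias

head = TinyHead(hidden_size=64, vocab_size=33, name="lm_head")
print(head(tf.zeros((2, 64))).shape)   # (2, 33)
print([w.name for w in head.weights])  # both names are prefixed with "lm_head/"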
@@ -1234,8 +1244,7 @@ def call(self, features):
         x = self.layer_norm(x)
 
         # project back to size of vocabulary with bias
-        x = self.decoder(x)
-        x = x + self.bias
+        x = tf.matmul(x, self.decoder, transpose_b=True) + self.bias
         return x
 
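With the decoder reduced to a plain weight in both configurations, call() can use a single tf.matmul(..., transpose_b=True) for the tied and untied cases alike. A quick numerical check (hypothetical sizes) that the new expression reproduces the removed Dense(use_bias=False) + bias path when given the same weights:

import tensorflow as tf

hidden_size, vocab_size = 64, 33  # hypothetical sizes
x = tf.random.normal((2, hidden_size))
bias = tf.random.normal((vocab_size,))

dense = tf.keras.layers.Dense(vocab_size, use_bias=False)
dense.build((None, hidden_size))
old = dense(x) + bias  # the removed code path

# Dense stores its kernel as (hidden_size, vocab_size); the raw decoder weight is
# (vocab_size, hidden_size), hence the transpose before comparing.
decoder = tf.transpose(dense.kernel)
new = tf.matmul(x, decoder, transpose_b=True) + bias  # the added code path

print(bool(tf.reduce_all(tf.abs(old - new) < 1e-5)))  # True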