@@ -102,29 +102,6 @@ def call(self, inputs: tf.Tensor, mask: tf.Tensor):
         return output
 
 
-# Copied from transformers.models.deberta.modeling_tf_deberta.get_mask
-def get_mask(input, dropout):
-    mask = tf.cast(
-        1 - tf.compat.v1.distributions.Bernoulli(probs=1 - dropout).sample(sample_shape=shape_list(input)), tf.bool
-    )
-    return mask, dropout
-
-
-@tf.custom_gradient
-# Copied from transformers.models.deberta.modeling_tf_deberta.TFDebertaXDropout
-def TFDebertaV2XDropout(input, local_ctx):
-    mask, dropout = get_mask(input, local_ctx)
-    scale = tf.convert_to_tensor(1.0 / (1 - dropout), dtype=tf.float32)
-    input = tf.cond(dropout > 0, lambda: tf.where(mask, 0.0, input) * scale, lambda: input)
-
-    def custom_grad(upstream_grad):
-        return tf.cond(
-            scale > 1, lambda: (tf.where(mask, 0.0, upstream_grad) * scale, None), lambda: (upstream_grad, None)
-        )
-
-    return input, custom_grad
-
-
 # Copied from transformers.models.deberta.modeling_tf_deberta.TFDebertaStableDropout with Deberta->DebertaV2
 class TFDebertaV2StableDropout(tf.keras.layers.Layer):
     """
@@ -136,11 +113,33 @@ class TFDebertaV2StableDropout(tf.keras.layers.Layer):
 
     def __init__(self, drop_prob, **kwargs):
         super().__init__(**kwargs)
-        self.drop_prob = tf.convert_to_tensor(drop_prob, dtype=tf.float32)
+        self.drop_prob = drop_prob
+
+    @tf.custom_gradient
+    def xdropout(self, inputs):
+        """
+        Applies dropout to the inputs, as vanilla dropout, but also scales the remaining elements up by 1/drop_prob.
+        """
+        mask = tf.cast(
+            1
+            - tf.compat.v1.distributions.Bernoulli(probs=1.0 - self.drop_prob).sample(sample_shape=shape_list(inputs)),
+            tf.bool,
+        )
+        scale = tf.convert_to_tensor(1.0 / (1 - self.drop_prob), dtype=tf.float32)
+        if self.drop_prob > 0:
+            inputs = tf.where(mask, 0.0, inputs) * scale
+
+        def grad(upstream):
+            if self.drop_prob > 0:
+                return tf.where(mask, 0.0, upstream) * scale
+            else:
+                return upstream
+
+        return inputs, grad
 
     def call(self, inputs: tf.Tensor, training: tf.Tensor = False):
-        if training and self.drop_prob > 0:
-            return TFDebertaV2XDropout(inputs, self.drop_prob)
+        if training:
+            return self.xdropout(inputs)
         return inputs
 
 
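For context, the new `xdropout` is plain inverted dropout with an explicit custom gradient: roughly `drop_prob` of the elements are zeroed in the forward pass, the survivors are scaled by `1 / (1 - drop_prob)`, and the backward pass applies the same mask and scale to the upstream gradient. Below is a minimal standalone sketch of that behaviour (not the transformers code itself; it closes over a Python-float `drop_prob` instead of using `self` and `shape_list`):

```python
import tensorflow as tf


def make_xdropout(drop_prob: float):
    """Sketch of inverted dropout with a custom gradient, mirroring the diff above."""

    @tf.custom_gradient
    def xdropout(inputs):
        # Sample a boolean mask: True where an element should be dropped.
        mask = tf.cast(
            1 - tf.compat.v1.distributions.Bernoulli(probs=1.0 - drop_prob).sample(sample_shape=tf.shape(inputs)),
            tf.bool,
        )
        scale = tf.convert_to_tensor(1.0 / (1 - drop_prob), dtype=tf.float32)
        if drop_prob > 0:
            outputs = tf.where(mask, 0.0, inputs) * scale
        else:
            outputs = inputs

        def grad(upstream):
            # Backward pass: drop the same positions and rescale the survivors.
            if drop_prob > 0:
                return tf.where(mask, 0.0, upstream) * scale
            return upstream

        return outputs, grad

    return xdropout


x = tf.ones((2, 4, 8))
y = make_xdropout(0.1)(x)
print(float(tf.reduce_mean(y)))  # ~1.0 in expectation, since survivors are scaled by 1/0.9
```

Because `drop_prob` is a plain Python float here (as in the updated layer), ordinary `if` statements replace the `tf.cond` calls that the removed `TFDebertaV2XDropout` needed.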
@@ -525,10 +524,18 @@ def pos_dynamic_expand(pos_index, p2c_att, key_layer):
 def take_along_axis(x, indices):
     # Only a valid port of np.take_along_axis when the gather axis is -1
 
-    flat_x = tf.reshape(x, (-1, x.shape[-1]))
-    flat_indices = tf.reshape(indices, (-1, indices.shape[-1]))
-    gathered = tf.gather(flat_x, flat_indices, batch_dims=1)
-    gathered = tf.reshape(gathered, indices.shape)
+    # TPU + gathers and reshapes don't go along well -- see https://github.com/huggingface/transformers/issues/18239
+    if isinstance(tf.distribute.get_strategy(), tf.distribute.TPUStrategy):
+        # [B, S, P] -> [B, S, P, D]
+        one_hot_indices = tf.one_hot(indices, depth=x.shape[-1], dtype=x.dtype)
+
+        # if we ignore the first two dims, this is equivalent to multiplying a matrix (one hot) by a vector (x)
+        # grossly abusing notation: [B, S, P, D] . [B, S, D] = [B, S, P]
+        gathered = tf.einsum("ijkl,ijl->ijk", one_hot_indices, x)
+
+    # GPUs, on the other hand, prefer gathers instead of large one-hot+matmuls
+    else:
+        gathered = tf.gather(x, indices, batch_dims=2)
 
     return gathered
 
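As a quick sanity check on the two branches above, here is a hypothetical snippet (not part of the PR) verifying that the one-hot + einsum path and the `batch_dims=2` gather agree, and that both match `np.take_along_axis` on the last axis:

```python
import numpy as np
import tensorflow as tf

# Shapes follow the comments above: x is [B, S, D], indices is [B, S, P].
B, S, D, P = 2, 3, 5, 4
x = tf.random.normal((B, S, D))
indices = tf.random.uniform((B, S, P), maxval=D, dtype=tf.int32)

# TPU-friendly path: one-hot the indices and contract against x.
one_hot_indices = tf.one_hot(indices, depth=D, dtype=x.dtype)      # [B, S, P, D]
via_einsum = tf.einsum("ijkl,ijl->ijk", one_hot_indices, x)        # [B, S, P]

# GPU-friendly path: batched gather along the last axis.
via_gather = tf.gather(x, indices, batch_dims=2)                   # [B, S, P]

# Both should also match numpy's take_along_axis on axis -1.
reference = np.take_along_axis(x.numpy(), indices.numpy(), axis=-1)

tf.debugging.assert_near(via_einsum, via_gather)
np.testing.assert_allclose(via_gather.numpy(), reference, rtol=1e-6)
```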