Commit ca3833e

gante authored and amyeroberts committed
TF: XLA-trainable DeBERTa v2 (huggingface#18546)
* fix deberta issues
* add different code paths for gpu and tpu
* shorter gpu take along axis
* Stable Dropout without tf cond
* variable must be float
1 parent 09f36ba commit ca3833e
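
For context, "XLA-trainable" means a jit-compiled training step over the TF DeBERTa models now traces cleanly: the `tf.cond`-based dropout and the reshape+gather pattern changed below are what previously broke XLA compilation. A minimal sketch (not part of the commit; the tiny config and hyperparameters are illustrative assumptions) of the kind of training step this change is meant to unblock:

```python
# Sketch only: a jit-compiled (XLA) training step over TF DeBERTa v2.
import tensorflow as tf
from transformers import DebertaV2Config, TFDebertaV2ForSequenceClassification

# Tiny randomly-initialized model so the sketch is self-contained (no checkpoint assumed)
config = DebertaV2Config(
    vocab_size=128, hidden_size=64, num_hidden_layers=2, num_attention_heads=4, intermediate_size=128
)
model = TFDebertaV2ForSequenceClassification(config)
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-5)

@tf.function(jit_compile=True)  # XLA compilation of the whole step
def train_step(input_ids, attention_mask, labels):
    with tf.GradientTape() as tape:
        # training=True exercises the StableDropout path rewritten in this commit
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels, training=True)
    grads = tape.gradient(outputs.loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return outputs.loss

input_ids = tf.random.uniform((2, 16), maxval=config.vocab_size, dtype=tf.int32)
loss = train_step(input_ids, tf.ones_like(input_ids), tf.constant([0, 1]))
```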

File tree: 2 files changed (+62, -54 lines)

src/transformers/models/deberta/modeling_tf_deberta.py (25 additions, 24 deletions)

```diff
@@ -101,27 +101,6 @@ def call(self, inputs: tf.Tensor, mask: tf.Tensor):
         return output
 
 
-def get_mask(input, dropout):
-    mask = tf.cast(
-        1 - tf.compat.v1.distributions.Bernoulli(probs=1 - dropout).sample(sample_shape=shape_list(input)), tf.bool
-    )
-    return mask, dropout
-
-
-@tf.custom_gradient
-def TFDebertaXDropout(input, local_ctx):
-    mask, dropout = get_mask(input, local_ctx)
-    scale = tf.convert_to_tensor(1.0 / (1 - dropout), dtype=tf.float32)
-    input = tf.cond(dropout > 0, lambda: tf.where(mask, 0.0, input) * scale, lambda: input)
-
-    def custom_grad(upstream_grad):
-        return tf.cond(
-            scale > 1, lambda: (tf.where(mask, 0.0, upstream_grad) * scale, None), lambda: (upstream_grad, None)
-        )
-
-    return input, custom_grad
-
-
 class TFDebertaStableDropout(tf.keras.layers.Layer):
     """
     Optimized dropout module for stabilizing the training
@@ -132,11 +111,33 @@ class TFDebertaStableDropout(tf.keras.layers.Layer):
 
     def __init__(self, drop_prob, **kwargs):
         super().__init__(**kwargs)
-        self.drop_prob = tf.convert_to_tensor(drop_prob, dtype=tf.float32)
+        self.drop_prob = drop_prob
+
+    @tf.custom_gradient
+    def xdropout(self, inputs):
+        """
+        Applies dropout to the inputs, as vanilla dropout, but also scales the remaining elements up by 1/drop_prob.
+        """
+        mask = tf.cast(
+            1
+            - tf.compat.v1.distributions.Bernoulli(probs=1.0 - self.drop_prob).sample(sample_shape=shape_list(inputs)),
+            tf.bool,
+        )
+        scale = tf.convert_to_tensor(1.0 / (1 - self.drop_prob), dtype=tf.float32)
+        if self.drop_prob > 0:
+            inputs = tf.where(mask, 0.0, inputs) * scale
+
+        def grad(upstream):
+            if self.drop_prob > 0:
+                return tf.where(mask, 0.0, upstream) * scale
+            else:
+                return upstream
+
+        return inputs, grad
 
     def call(self, inputs: tf.Tensor, training: tf.Tensor = False):
-        if training and self.drop_prob > 0:
-            return TFDebertaXDropout(inputs, self.drop_prob)
+        if training:
+            return self.xdropout(inputs)
         return inputs
```
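The key point of the rewrite: `drop_prob` is now kept as a plain Python float instead of being converted to a tensor, so the `if self.drop_prob > 0` checks are resolved once at trace time and no `tf.cond` op is emitted into the graph (the op that tripped up XLA). A small usage sketch (illustrative values, not from the commit):

```python
# Sketch: exercising the rewritten layer directly.
import tensorflow as tf
from transformers.models.deberta.modeling_tf_deberta import TFDebertaStableDropout

drop = TFDebertaStableDropout(drop_prob=0.1)
x = tf.random.normal((2, 4, 8))

y_train = drop(x, training=True)   # elements zeroed with p=0.1; survivors scaled by 1 / (1 - 0.1)
y_eval = drop(x, training=False)   # identity: dropout only applies during training
```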
src/transformers/models/deberta_v2/modeling_tf_deberta_v2.py (37 additions, 30 deletions)

```diff
@@ -102,29 +102,6 @@ def call(self, inputs: tf.Tensor, mask: tf.Tensor):
         return output
 
 
-# Copied from transformers.models.deberta.modeling_tf_deberta.get_mask
-def get_mask(input, dropout):
-    mask = tf.cast(
-        1 - tf.compat.v1.distributions.Bernoulli(probs=1 - dropout).sample(sample_shape=shape_list(input)), tf.bool
-    )
-    return mask, dropout
-
-
-@tf.custom_gradient
-# Copied from transformers.models.deberta.modeling_tf_deberta.TFDebertaXDropout
-def TFDebertaV2XDropout(input, local_ctx):
-    mask, dropout = get_mask(input, local_ctx)
-    scale = tf.convert_to_tensor(1.0 / (1 - dropout), dtype=tf.float32)
-    input = tf.cond(dropout > 0, lambda: tf.where(mask, 0.0, input) * scale, lambda: input)
-
-    def custom_grad(upstream_grad):
-        return tf.cond(
-            scale > 1, lambda: (tf.where(mask, 0.0, upstream_grad) * scale, None), lambda: (upstream_grad, None)
-        )
-
-    return input, custom_grad
-
-
 # Copied from transformers.models.deberta.modeling_tf_deberta.TFDebertaStableDropout with Deberta->DebertaV2
 class TFDebertaV2StableDropout(tf.keras.layers.Layer):
     """
@@ -136,11 +113,33 @@ class TFDebertaV2StableDropout(tf.keras.layers.Layer):
 
     def __init__(self, drop_prob, **kwargs):
        super().__init__(**kwargs)
-        self.drop_prob = tf.convert_to_tensor(drop_prob, dtype=tf.float32)
+        self.drop_prob = drop_prob
+
+    @tf.custom_gradient
+    def xdropout(self, inputs):
+        """
+        Applies dropout to the inputs, as vanilla dropout, but also scales the remaining elements up by 1/drop_prob.
+        """
+        mask = tf.cast(
+            1
+            - tf.compat.v1.distributions.Bernoulli(probs=1.0 - self.drop_prob).sample(sample_shape=shape_list(inputs)),
+            tf.bool,
+        )
+        scale = tf.convert_to_tensor(1.0 / (1 - self.drop_prob), dtype=tf.float32)
+        if self.drop_prob > 0:
+            inputs = tf.where(mask, 0.0, inputs) * scale
+
+        def grad(upstream):
+            if self.drop_prob > 0:
+                return tf.where(mask, 0.0, upstream) * scale
+            else:
+                return upstream
+
+        return inputs, grad
 
     def call(self, inputs: tf.Tensor, training: tf.Tensor = False):
-        if training and self.drop_prob > 0:
-            return TFDebertaV2XDropout(inputs, self.drop_prob)
+        if training:
+            return self.xdropout(inputs)
         return inputs
@@ -525,10 +524,18 @@ def pos_dynamic_expand(pos_index, p2c_att, key_layer):
 def take_along_axis(x, indices):
     # Only a valid port of np.take_along_axis when the gather axis is -1
 
-    flat_x = tf.reshape(x, (-1, x.shape[-1]))
-    flat_indices = tf.reshape(indices, (-1, indices.shape[-1]))
-    gathered = tf.gather(flat_x, flat_indices, batch_dims=1)
-    gathered = tf.reshape(gathered, indices.shape)
+    # TPU + gathers and reshapes don't go along well -- see https://github.com/huggingface/transformers/issues/18239
+    if isinstance(tf.distribute.get_strategy(), tf.distribute.TPUStrategy):
+        # [B, S, P] -> [B, S, P, D]
+        one_hot_indices = tf.one_hot(indices, depth=x.shape[-1], dtype=x.dtype)
+
+        # if we ignore the first two dims, this is equivalent to multiplying a matrix (one hot) by a vector (x)
+        # grossly abusing notation: [B, S, P, D] . [B, S, D] = [B, S, P]
+        gathered = tf.einsum("ijkl,ijl->ijk", one_hot_indices, x)
+
+    # GPUs, on the other hand, prefer gathers instead of large one-hot+matmuls
+    else:
+        gathered = tf.gather(x, indices, batch_dims=2)
 
     return gathered
```
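To see why the two `take_along_axis` branches agree, here is a quick equivalence check (a sketch, not part of the commit): the one-hot einsum used under `TPUStrategy` should match the `tf.gather(..., batch_dims=2)` path used on other devices.

```python
# Sketch: the TPU-friendly one-hot einsum gathers the same values as tf.gather.
import tensorflow as tf

x = tf.random.normal((2, 3, 5))                                    # [B, S, D]
indices = tf.random.uniform((2, 3, 4), maxval=5, dtype=tf.int32)   # [B, S, P]

one_hot_indices = tf.one_hot(indices, depth=x.shape[-1], dtype=x.dtype)  # [B, S, P, D]
via_einsum = tf.einsum("ijkl,ijl->ijk", one_hot_indices, x)              # [B, S, P]
via_gather = tf.gather(x, indices, batch_dims=2)                         # [B, S, P]

tf.debugging.assert_near(via_einsum, via_gather)  # both select x[i, j, indices[i, j, k]]
```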