
Commit 50cd169

Author: Flax Authors
Merge pull request #3529 from chiamp:attention

PiperOrigin-RevId: 588392463
2 parents: 3687af0 + 9d989b0

2 files changed: +99, -12 lines

flax/linen/attention.py

Lines changed: 48 additions & 12 deletions
@@ -50,6 +50,7 @@ def dot_product_attention_weights(
     deterministic: bool = False,
     dtype: Optional[Dtype] = None,
     precision: PrecisionLike = None,
+    module: Optional[Module] = None,
 ):
   """Computes dot-product attention weights given query and key.

@@ -76,6 +77,10 @@ def dot_product_attention_weights(
     dtype: the dtype of the computation (default: infer from inputs and params)
     precision: numerical precision of the computation see `jax.lax.Precision`
       for details.
+    module: the Module that will sow the attention weights into the
+      'intermediates' collection. Remember to mark 'intermediates' as mutable via
+      `mutable=['intermediates'] in order to have that collection returned.
+      If `module` is None, the attention weights will not be sowed.

   Returns:
     Output of shape `[batch..., num_heads, q_length, kv_length]`.
@@ -107,6 +112,9 @@ def dot_product_attention_weights(
   # normalize the attention weights
   attn_weights = jax.nn.softmax(attn_weights).astype(dtype)

+  if module:
+    module.sow('intermediates', 'attention_weights', attn_weights)
+
   # apply attention dropout
   if not deterministic and dropout_rate > 0.0:
     keep_prob = 1.0 - dropout_rate
@@ -134,6 +142,7 @@ def dot_product_attention(
     deterministic: bool = False,
     dtype: Optional[Dtype] = None,
     precision: PrecisionLike = None,
+    module: Optional[Module] = None,
 ):
   """Computes dot-product attention given query, key, and value.

@@ -164,6 +173,10 @@ def dot_product_attention(
     dtype: the dtype of the computation (default: infer from inputs)
     precision: numerical precision of the computation see `jax.lax.Precision`
       for details.
+    module: the Module that will sow the attention weights into the
+      'intermediates' collection. Remember to mark 'intermediates' as mutable via
+      `mutable=['intermediates'] in order to have that collection returned.
+      If `module` is None, the attention weights will not be sowed.

   Returns:
     Output of shape `[batch..., q_length, num_heads, v_depth_per_head]`.
@@ -191,6 +204,7 @@ def dot_product_attention(
       deterministic,
       dtype,
       precision,
+      module,
   )

   # return weighted sum over values for each query position
@@ -306,6 +320,7 @@ def __call__(
     mask: Optional[Array] = None,
     deterministic: Optional[bool] = None,
     dropout_rng: Optional[PRNGKey] = None,
+    return_weights: bool = False,
   ):
     ...

@@ -318,6 +333,7 @@ def __call__(
     mask: Optional[Array] = None,
     deterministic: Optional[bool] = None,
     dropout_rng: Optional[PRNGKey] = None,
+    return_weights: bool = False,
   ):
     ...

@@ -332,6 +348,7 @@ def __call__(
     mask: Optional[Array] = None,
     deterministic: Optional[bool] = None,
     dropout_rng: Optional[PRNGKey] = None,
+    return_weights: bool = False,
   ):
     """Applies multi-head dot product attention on the input data.

@@ -358,6 +375,10 @@ def __call__(
         dropout, whereas if true, the attention weights are deterministic.
       dropout_rng: optional rng key to pass to the attention layer's dropout
         mask. Otherwise, self.make_rng('dropout') is used instead.
+      return_weights: if `True`, the attention weights are sowed into the
+        'intermediates' collection. Remember to mark 'intermediates' as
+        mutable via `mutable=['intermediates'] in order to have that
+        collection returned.

     Returns:
       output of shape `[batch_sizes..., length, features]`.
@@ -506,18 +527,33 @@ def __call__(
       m_deterministic = True

     # apply attention
-    x = self.attention_fn(
-      query,
-      key,
-      value,
-      mask=mask,
-      dropout_rng=dropout_rng,
-      dropout_rate=self.dropout_rate,
-      broadcast_dropout=self.broadcast_dropout,
-      deterministic=m_deterministic,
-      dtype=self.dtype,
-      precision=self.precision,
-    ) # pytype: disable=wrong-keyword-args
+    if return_weights:
+      x = self.attention_fn(
+        query,
+        key,
+        value,
+        mask=mask,
+        dropout_rng=dropout_rng,
+        dropout_rate=self.dropout_rate,
+        broadcast_dropout=self.broadcast_dropout,
+        deterministic=m_deterministic,
+        dtype=self.dtype,
+        precision=self.precision,
+        module=self if return_weights else None,
+      ) # pytype: disable=wrong-keyword-args
+    else:
+      x = self.attention_fn(
+        query,
+        key,
+        value,
+        mask=mask,
+        dropout_rng=dropout_rng,
+        dropout_rate=self.dropout_rate,
+        broadcast_dropout=self.broadcast_dropout,
+        deterministic=m_deterministic,
+        dtype=self.dtype,
+        precision=self.precision,
+      )
     # back to the original inputs dimensions
     out = DenseGeneral(
       features=features,
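
With this change, callers opt in per call: passing return_weights=True makes MultiHeadDotProductAttention sow its attention weights, and marking 'intermediates' as mutable returns them alongside the output. A minimal usage sketch; the layer sizes, input shapes, and variable names below are illustrative and not part of the commit:

import jax
import jax.numpy as jnp
import flax.linen as nn

# Illustrative self-attention layer and dummy input.
attn = nn.MultiHeadDotProductAttention(num_heads=4, qkv_features=32)
x = jnp.ones((2, 10, 16))  # [batch, q_length, features]

variables = attn.init(jax.random.key(0), x)

# return_weights=True triggers the sow; mutable=['intermediates'] returns the collection.
y, state = attn.apply(
    variables, x, return_weights=True, mutable=['intermediates']
)
attn_weights = state['intermediates']['attention_weights'][0]
print(attn_weights.shape)  # (2, 4, 10, 10): [batch, num_heads, q_length, kv_length]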

tests/linen/linen_attention_test.py

Lines changed: 51 additions & 0 deletions
@@ -343,6 +343,57 @@ def test_multihead_mask_warning(self):
     with self.assertRaises(errors.ScopeParamShapeError):
       module.apply(initial_vars, query, key, causal_mask)

+  def test_multihead_sow_attention_weights(self):
+    rng = random.key(0)
+    x = jnp.ones((4, 6, 5))
+
+    class Model(nn.Module):
+      attention_kwargs: dict
+
+      @nn.compact
+      def __call__(self, x, return_weights=False):
+        x = nn.MultiHeadDotProductAttention(**self.attention_kwargs)(
+            x, return_weights=return_weights
+        )
+        x = nn.MultiHeadDotProductAttention(**self.attention_kwargs)(x)
+        x = nn.MultiHeadDotProductAttention(**self.attention_kwargs)(
+            x, return_weights=return_weights
+        )
+        return x
+
+    module = Model(
+        dict(
+            num_heads=8,
+            qkv_features=16,
+            kernel_init=initializers.ones,
+            bias_init=initializers.zeros,
+            deterministic=False,
+        )
+    )
+    v = module.init(rng, x)
+    _, intermediates = module.apply(
+        v, x, mutable=['intermediates'], return_weights=True
+    )
+    self.assertEqual(
+        intermediates['intermediates']['MultiHeadDotProductAttention_0'][
+            'attention_weights'
+        ][0].shape,
+        (4, 8, 6, 6),
+    )
+    self.assertNotIn(
+        'MultiHeadDotProductAttention_1', intermediates['intermediates']
+    )
+    self.assertEqual(
+        intermediates['intermediates']['MultiHeadDotProductAttention_2'][
+            'attention_weights'
+        ][0].shape,
+        (4, 8, 6, 6),
+    )
+    _, intermediates = module.apply(
+        v, x, mutable=['intermediates'], return_weights=False
+    )
+    self.assertNotIn('intermediates', intermediates)
+

 if __name__ == '__main__':
   absltest.main()
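
The same hook is available below the module layer: dot_product_attention (and dot_product_attention_weights) now accept module directly, so a custom module that bypasses MultiHeadDotProductAttention can still sow its weights. A rough sketch under that assumption; the RawAttention module and the shapes here are made up for illustration:

import jax.numpy as jnp
import flax.linen as nn


class RawAttention(nn.Module):
  """Toy module that calls the functional attention API directly."""

  @nn.compact
  def __call__(self, q, k, v):
    # module=self sows the weights into this module's 'intermediates' collection.
    return nn.dot_product_attention(q, k, v, module=self)


q = k = v = jnp.ones((1, 8, 2, 16))  # [batch, length, num_heads, head_dim]
y, state = RawAttention().apply({}, q, k, v, mutable=['intermediates'])
print(state['intermediates']['attention_weights'][0].shape)  # (1, 2, 8, 8)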
