@@ -50,6 +50,7 @@ def dot_product_attention_weights(
     deterministic: bool = False,
     dtype: Optional[Dtype] = None,
     precision: PrecisionLike = None,
+    module: Optional[Module] = None,
 ):
   """Computes dot-product attention weights given query and key.
 
@@ -76,6 +77,10 @@ def dot_product_attention_weights(
     dtype: the dtype of the computation (default: infer from inputs and params)
     precision: numerical precision of the computation see `jax.lax.Precision`
       for details.
+    module: the Module that will sow the attention weights into the
+      'intermediates' collection. Remember to mark 'intermediates' as mutable via
+      `mutable=['intermediates']` in order to have that collection returned.
+      If `module` is None, the attention weights will not be sowed.
 
   Returns:
     Output of shape `[batch..., num_heads, q_length, kv_length]`.
@@ -107,6 +112,9 @@ def dot_product_attention_weights(
   # normalize the attention weights
   attn_weights = jax.nn.softmax(attn_weights).astype(dtype)
 
+  if module:
+    module.sow('intermediates', 'attention_weights', attn_weights)
+
   # apply attention dropout
   if not deterministic and dropout_rate > 0.0:
     keep_prob = 1.0 - dropout_rate
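For context, `Module.sow` appends its value to the named collection, and that collection is only handed back to the caller when it is marked mutable at `apply` time. Below is a minimal sketch of the functional-level flow, assuming the change above; the `WeightsProbe` wrapper is purely illustrative and not part of this change:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn


class WeightsProbe(nn.Module):
  """Illustrative module that sows the attention weights it computes."""

  @nn.compact
  def __call__(self, q, k):
    # With module=self, dot_product_attention_weights calls
    # self.sow('intermediates', 'attention_weights', ...) as added above.
    return nn.dot_product_attention_weights(q, k, module=self)


q = k = jnp.ones((1, 4, 2, 8))  # [batch, length, num_heads, depth_per_head]
probe = WeightsProbe()
variables = probe.init(jax.random.PRNGKey(0), q, k)  # no params here

# Mark 'intermediates' as mutable so the sowed weights are returned.
weights, state = probe.apply(variables, q, k, mutable=['intermediates'])
# sow stores values as a tuple; take the first (and only) entry.
attn = state['intermediates']['attention_weights'][0]
print(attn.shape)  # (1, 2, 4, 4): [batch, num_heads, q_length, kv_length]
```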
@@ -134,6 +142,7 @@ def dot_product_attention(
     deterministic: bool = False,
     dtype: Optional[Dtype] = None,
     precision: PrecisionLike = None,
+    module: Optional[Module] = None,
 ):
   """Computes dot-product attention given query, key, and value.
 
@@ -164,6 +173,10 @@ def dot_product_attention(
     dtype: the dtype of the computation (default: infer from inputs)
     precision: numerical precision of the computation see `jax.lax.Precision`
       for details.
+    module: the Module that will sow the attention weights into the
+      'intermediates' collection. Remember to mark 'intermediates' as mutable via
+      `mutable=['intermediates']` in order to have that collection returned.
+      If `module` is None, the attention weights will not be sowed.
 
   Returns:
     Output of shape `[batch..., q_length, num_heads, v_depth_per_head]`.
@@ -191,6 +204,7 @@ def dot_product_attention(
       deterministic,
       dtype,
       precision,
+      module,
   )
 
   # return weighted sum over values for each query position
@@ -306,6 +320,7 @@ def __call__(
       mask: Optional[Array] = None,
       deterministic: Optional[bool] = None,
       dropout_rng: Optional[PRNGKey] = None,
+      return_weights: bool = False,
   ):
     ...
 
@@ -318,6 +333,7 @@ def __call__(
       mask: Optional[Array] = None,
       deterministic: Optional[bool] = None,
       dropout_rng: Optional[PRNGKey] = None,
+      return_weights: bool = False,
   ):
     ...
 
@@ -332,6 +348,7 @@ def __call__(
       mask: Optional[Array] = None,
       deterministic: Optional[bool] = None,
       dropout_rng: Optional[PRNGKey] = None,
+      return_weights: bool = False,
   ):
     """Applies multi-head dot product attention on the input data.
 
@@ -358,6 +375,10 @@ def __call__(
         dropout, whereas if true, the attention weights are deterministic.
       dropout_rng: optional rng key to pass to the attention layer's dropout
         mask. Otherwise, self.make_rng('dropout') is used instead.
+      return_weights: if `True`, the attention weights are sowed into the
+        'intermediates' collection. Remember to mark 'intermediates' as
+        mutable via `mutable=['intermediates']` in order to have that
+        collection returned.
 
     Returns:
       output of shape `[batch_sizes..., length, features]`.
@@ -506,18 +527,33 @@ def __call__(
       m_deterministic = True
 
     # apply attention
-    x = self.attention_fn(
-        query,
-        key,
-        value,
-        mask=mask,
-        dropout_rng=dropout_rng,
-        dropout_rate=self.dropout_rate,
-        broadcast_dropout=self.broadcast_dropout,
-        deterministic=m_deterministic,
-        dtype=self.dtype,
-        precision=self.precision,
-    )  # pytype: disable=wrong-keyword-args
+    if return_weights:
+      x = self.attention_fn(
+          query,
+          key,
+          value,
+          mask=mask,
+          dropout_rng=dropout_rng,
+          dropout_rate=self.dropout_rate,
+          broadcast_dropout=self.broadcast_dropout,
+          deterministic=m_deterministic,
+          dtype=self.dtype,
+          precision=self.precision,
+          module=self if return_weights else None,
+      )  # pytype: disable=wrong-keyword-args
+    else:
+      x = self.attention_fn(
+          query,
+          key,
+          value,
+          mask=mask,
+          dropout_rng=dropout_rng,
+          dropout_rate=self.dropout_rate,
+          broadcast_dropout=self.broadcast_dropout,
+          deterministic=m_deterministic,
+          dtype=self.dtype,
+          precision=self.precision,
+      )
     # back to the original inputs dimensions
     out = DenseGeneral(
         features=features,
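The duplicated call appears intended to keep custom `attention_fn` callables that predate this change, and therefore do not accept a `module` keyword, working as long as `return_weights` stays `False`. A hedged sketch of such a custom function (the name and the defensive `pop` are illustrative):

```python
import flax.linen as nn


def my_attention_fn(query, key, value, **kwargs):
  """Custom attention_fn that predates the `module` keyword."""
  # Tolerate the new keyword if a caller ever forwards it anyway.
  kwargs.pop('module', None)
  return nn.dot_product_attention(query, key, value, **kwargs)


layer = nn.MultiHeadDotProductAttention(num_heads=4, attention_fn=my_attention_fn)
```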