@@ -48,6 +48,30 @@ def __call__(self, inputs):
48
48
return self .activation_final (x )
49
49
50
50
51
class AttentionTuple(nn.Module):
  """Multi-head attention layer whose call returns a 2-tuple.

  The layer takes two inputs and returns two outputs: the attention result
  and the `key_value` input passed through unchanged.
  """
  # Both fields are forwarded verbatim to `nn.MultiHeadDotProductAttention`.
  num_heads: int = 2
  qkv_features: int = 16

  @nn.compact
  def __call__(self, query, key_value):
    """Applies multi-head attention to `query` and `key_value`.

    Args:
      query: first positional input of the attention module.
      key_value: second positional input of the attention module.

    Returns:
      `(attention_output, key_value)`; `key_value` is returned as received.
    """
    attention = nn.MultiHeadDotProductAttention(
        num_heads=self.num_heads,
        qkv_features=self.qkv_features,
    )
    attended = attention(query, key_value)
    return attended, key_value
61
+
62
+
63
class AttentionDict(nn.Module):
  """Multi-head attention layer whose call returns a dict.

  The returned dict uses the keys `'query'` and `'key_value'`, which are the
  same names as the parameters of `__call__`.
  """
  # Both fields are forwarded verbatim to `nn.MultiHeadDotProductAttention`.
  num_heads: int = 2
  qkv_features: int = 16

  @nn.compact
  def __call__(self, query, key_value):
    """Applies multi-head attention to `query` and `key_value`.

    Args:
      query: first positional input of the attention module.
      key_value: second positional input of the attention module.

    Returns:
      A dict with `'query'` mapped to the attention output and `'key_value'`
      mapped to the `key_value` input, returned as received.
    """
    attention = nn.MultiHeadDotProductAttention(
        num_heads=self.num_heads,
        qkv_features=self.qkv_features,
    )
    attended = attention(query, key_value)
    return {'query': attended, 'key_value': key_value}
73
+
74
+
51
75
class SequentialTest (absltest .TestCase ):
52
76
53
77
def test_construction (self ):
@@ -103,5 +127,38 @@ def test_same_output_as_mlp_with_activation(self):
103
127
np .testing .assert_array_equal (output_1 , output_2 )
104
128
105
129
130
def test_tuple_output(self):
  """Checks `nn.Sequential` over two layers that each return a tuple.

  `AttentionTuple.__call__` takes two arguments and returns a 2-tuple, so
  the stack only runs if the first layer's tuple can be consumed by the
  second layer. The final output must still be a 2-tuple whose entries
  keep the shapes of the original inputs.
  """
  sequential = nn.Sequential([
      AttentionTuple(),
      AttentionTuple(),
  ])

  # One independent PRNG key per random draw: reusing a key for both inputs
  # would make `query` and `key_value` correlated samples.
  query_key, kv_key, init_key = random.split(random.PRNGKey(0), 3)
  query = random.uniform(query_key, (3, 5))
  key_value = random.uniform(kv_key, (9, 5))

  variables = sequential.init(init_key, query, key_value)
  outputs = sequential.apply(variables, query, key_value)

  np.testing.assert_equal(len(outputs), 2)
  out_query, out_key_value = outputs
  np.testing.assert_equal(out_query.shape, (3, 5))
  np.testing.assert_equal(out_key_value.shape, (9, 5))
145
+
146
def test_dict_output(self):
  """Checks `nn.Sequential` over two layers that each return a dict.

  `AttentionDict.__call__` takes `query` and `key_value` and returns a dict
  with those same two keys, so the stack only runs if the first layer's
  dict can be consumed by the second layer. The final output must be a
  two-entry dict whose values keep the shapes of the original inputs.
  """
  sequential = nn.Sequential([
      AttentionDict(),
      AttentionDict(),
  ])

  # One independent PRNG key per random draw: reusing a key for both inputs
  # would make `query` and `key_value` correlated samples.
  query_key, kv_key, init_key = random.split(random.PRNGKey(0), 3)
  query = random.uniform(query_key, (3, 5))
  key_value = random.uniform(kv_key, (9, 5))

  variables = sequential.init(init_key, query, key_value)
  outputs = sequential.apply(variables, query, key_value)

  np.testing.assert_equal(len(outputs), 2)
  out_query, out_key_value = outputs['query'], outputs['key_value']
  np.testing.assert_equal(out_query.shape, (3, 5))
  np.testing.assert_equal(out_key_value.shape, (9, 5))
161
+
162
+
106
163
# Script entry point: run the tests through absltest only when this file is
# executed directly, not when it is imported.
if __name__ == '__main__':
  absltest.main()
0 commit comments