Skip to content

Commit 6d5bc2a

Browse files
aboSamoor authored and Flax Authors committed
Support multi output layers in the Sequential combinator
PiperOrigin-RevId: 495905245
1 parent 4f24933 commit 6d5bc2a

File tree

2 files changed

+89
-2
lines changed

2 files changed

+89
-2
lines changed

flax/linen/combinators.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
"""Combinators of modules, such as a Sequential."""
1616

17-
from typing import Any, Callable, Sequence
17+
from typing import Any, Callable, Dict, Sequence
1818

1919
from flax.linen.module import Module
2020

@@ -42,6 +42,31 @@ def __call__(self, x):
4242
nn.relu,
4343
nn.Dense(2),
4444
nn.log_softmax])(x)
45+
46+
This combinator also supports layers that return multiple outputs if returned
47+
as a tuple or a dictionary.
48+
49+
Example usage::
50+
51+
class CrossAttentionBlock(nn.Module):
52+
num_heads: int = 2
53+
qkv_features: int = 16
54+
55+
@nn.compact
56+
def __call__(self, query, key_value):
57+
output = nn.MultiHeadDotProductAttention(
58+
num_heads=self.num_heads, qkv_features=self.qkv_features)(query,
59+
key_value)
60+
output = nn.Dense(self.qkv_features)(output)
61+
return dict(query=output, key_value=key_value) # also works for tuples
62+
63+
class CrossAttentionNetwork(nn.Module):
64+
num_layers: int
65+
66+
@nn.compact
67+
def __call__(self, query, key_value):
68+
return nn.Sequential([CrossAttentionBlock() for _ in
69+
range(self.num_layers)])(query, key_value)
4570
"""
4671
layers: Sequence[Callable[..., Any]]
4772

@@ -51,5 +76,10 @@ def __call__(self, *args, **kwargs):
5176

5277
outputs = self.layers[0](*args, **kwargs)
5378
for layer in self.layers[1:]:
54-
outputs = layer(outputs)
79+
if isinstance(outputs, tuple):
80+
outputs = layer(*outputs)
81+
elif isinstance(outputs, Dict):
82+
outputs = layer(**outputs)
83+
else:
84+
outputs = layer(outputs)
5585
return outputs

tests/linen/linen_combinators_test.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,30 @@ def __call__(self, inputs):
4848
return self.activation_final(x)
4949

5050

51+
class AttentionTuple(nn.Module):
52+
num_heads: int = 2
53+
qkv_features: int = 16
54+
55+
@nn.compact
56+
def __call__(self, query, key_value):
57+
output = nn.MultiHeadDotProductAttention(
58+
num_heads=self.num_heads, qkv_features=self.qkv_features)(query,
59+
key_value)
60+
return output, key_value
61+
62+
63+
class AttentionDict(nn.Module):
64+
num_heads: int = 2
65+
qkv_features: int = 16
66+
67+
@nn.compact
68+
def __call__(self, query, key_value):
69+
output = nn.MultiHeadDotProductAttention(
70+
num_heads=self.num_heads, qkv_features=self.qkv_features)(query,
71+
key_value)
72+
return dict(query=output, key_value=key_value)
73+
74+
5175
class SequentialTest(absltest.TestCase):
5276

5377
def test_construction(self):
@@ -103,5 +127,38 @@ def test_same_output_as_mlp_with_activation(self):
103127
np.testing.assert_array_equal(output_1, output_2)
104128

105129

130+
def test_tuple_output(self):
131+
sequential = nn.Sequential([
132+
AttentionTuple(),
133+
AttentionTuple(),
134+
])
135+
136+
key1, key2 = random.split(random.PRNGKey(0), 2)
137+
query = random.uniform(key1, (3, 5))
138+
key_value = random.uniform(key1, (9, 5))
139+
params_1 = sequential.init(key2, query, key_value)
140+
outputs = sequential.apply(params_1, query, key_value)
141+
np.testing.assert_equal(len(outputs), 2)
142+
out_query, out_key_value = outputs
143+
np.testing.assert_equal(out_query.shape, (3, 5))
144+
np.testing.assert_equal(out_key_value.shape, (9, 5))
145+
146+
def test_dict_output(self):
147+
sequential = nn.Sequential([
148+
AttentionDict(),
149+
AttentionDict(),
150+
])
151+
152+
key1, key2 = random.split(random.PRNGKey(0), 2)
153+
query = random.uniform(key1, (3, 5))
154+
key_value = random.uniform(key1, (9, 5))
155+
params_1 = sequential.init(key2, query, key_value)
156+
outputs = sequential.apply(params_1, query, key_value)
157+
np.testing.assert_equal(len(outputs), 2)
158+
out_query, out_key_value = outputs['query'], outputs['key_value']
159+
np.testing.assert_equal(out_query.shape, (3, 5))
160+
np.testing.assert_equal(out_key_value.shape, (9, 5))
161+
162+
106163
if __name__ == '__main__':
107164
absltest.main()

0 commit comments

Comments
 (0)