
Commit 6d85aac

Flax Authors committed
Merge pull request #2578 from mathisgerdes:transpose-kernel
PiperOrigin-RevId: 504365751
2 parents 6bccee3 + bbf7856 commit 6d85aac

2 files changed: +39 −7 lines changed


flax/linen/linear.py

Lines changed: 22 additions & 7 deletions
@@ -581,6 +581,8 @@ class ConvTranspose(Module):
       for details.
     kernel_init: initializer for the convolutional kernel.
     bias_init: initializer for the bias.
+    transpose_kernel: if True flips spatial axes and swaps the input/output
+      channel axes of the kernel.
   """
   features: int
   kernel_size: Union[int, Tuple[int, ...]]
@@ -594,6 +596,7 @@ class ConvTranspose(Module):
   precision: PrecisionLike = None
   kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_kernel_init
   bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = zeros
+  transpose_kernel: bool = False
 
   @compact
   def __call__(self, inputs: Array) -> Array:
@@ -636,7 +639,10 @@ def __call__(self, inputs: Array) -> Array:
     strides = self.strides or (1,) * (inputs.ndim - 2)
 
     in_features = jnp.shape(inputs)[-1]
-    kernel_shape = kernel_size + (in_features, self.features)
+    if self.transpose_kernel:
+      kernel_shape = kernel_size + (self.features, in_features)
+    else:
+      kernel_shape = kernel_size + (in_features, self.features)
 
     if self.mask is not None and self.mask.shape != kernel_shape:
       raise ValueError('Mask needs to have the same shape as weights. '
@@ -667,6 +673,7 @@ def __call__(self, inputs: Array) -> Array:
         strides,
         padding_lax,
         rhs_dilation=self.kernel_dilation,
+        transpose_kernel=self.transpose_kernel,
         precision=self.precision)
 
     if self.padding == 'CIRCULAR':
@@ -689,12 +696,20 @@ def __call__(self, inputs: Array) -> Array:
           -(y_dim - x_dim) % (2 * x_dim)
           for y_dim, x_dim in zip(y.shape[1:-1], scaled_x_dims)
       ]
-      # Divide the padding equaly between left and right. The choice to put
-      # "+1" on the left (and not on the right) represents a convention for
-      # aligning even-sized kernels.
-      total_pad = [
-          ((size_diff + 1) // 2, size_diff // 2) for size_diff in size_diffs
-      ]
+      if self.transpose_kernel:
+        # If the kernel is transposed, the "+1" is put on the right to
+        # mirror the regular convolution. If the same kernel parameters are used
+        # as for Conv, this layer then computes the proper transpose convolution.
+        total_pad = [
+            (size_diff // 2, (size_diff + 1) // 2) for size_diff in size_diffs
+        ]
+      else:
+        # Divide the padding equally between left and right. The choice to put
+        # "+1" on the left (and not on the right) represents a convention for
+        # aligning even-sized kernels.
+        total_pad = [
+            ((size_diff + 1) // 2, size_diff // 2) for size_diff in size_diffs
+        ]
       y = jnp.pad(y, [(0, 0)] + total_pad + [(0, 0)])
       # Wrap the result periodically around each spatial dimension,
       # one by one.
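In short: with transpose_kernel=True the layer stores its kernel as kernel_size + (features, in_features) instead of kernel_size + (in_features, features), i.e. in the layout a forward nn.Conv running in the opposite direction would use, and forwards the flag to jax.lax.conv_transpose. A minimal sketch of the resulting parameter layout (the shapes below are illustrative assumptions, not taken from the commit):

    import jax
    import jax.numpy as jnp
    import flax.linen as nn

    # Illustrative shapes: a 16 -> 3 channel forward convolution stores its
    # kernel as kernel_size + (in_features, out_features) = (3, 3, 16, 3).
    # A ConvTranspose mapping 3 -> 16 channels with transpose_kernel=True
    # expects kernel_size + (features, in_features) = (3, 3, 16, 3) as well,
    # so the same parameters fit both directions.
    x = jnp.ones((1, 8, 8, 3))  # NHWC input with 3 channels

    convT = nn.ConvTranspose(features=16, kernel_size=(3, 3), strides=(2, 2),
                             use_bias=False, transpose_kernel=True)
    variables = convT.init(jax.random.PRNGKey(0), x)
    print(variables['params']['kernel'].shape)  # (3, 3, 16, 3)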

tests/linen/linen_linear_test.py

Lines changed: 17 additions & 0 deletions
@@ -929,6 +929,23 @@ def test_circular_conv_transpose_2d_custom_bias(self):
     correct_ans = np.expand_dims(correct_ans, (0, 3))
     np.testing.assert_allclose(y, correct_ans)
 
+  @parameterized.product(
+    use_bias=(True, False))
+  def test_transpose_kernel_conv_transpose(self, use_bias):
+    rng = dict(params=random.PRNGKey(0))
+    x = jnp.ones((1, 15, 15, 3))
+    conv_module = nn.ConvTranspose(
+        features=4,
+        use_bias=use_bias,
+        strides=(2, 2),
+        kernel_size=(6, 6),
+        padding='CIRCULAR',
+        transpose_kernel=True,
+    )
+    y, initial_params = conv_module.init_with_output(rng, x)
+    self.assertEqual(initial_params['params']['kernel'].shape, (6, 6, 4, 3))
+    self.assertEqual(y.shape, (1, 30, 30, 4))
+
   @parameterized.product(
       module=(nn.Conv, nn.ConvLocal)
   )
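The added test checks the swapped kernel shape and the circularly padded output shape (stride 2 upsamples 15 -> 30). The comment in the diff states the motivation: with CIRCULAR padding and shared kernel parameters, the layer computes the proper transpose of nn.Conv. A sketch of how one might verify that property with jax.vjp, since the vjp of a linear map applies its adjoint; this check is not part of the commit's test suite, and the shapes and tolerance are assumptions:

    import jax
    import jax.numpy as jnp
    import flax.linen as nn

    # Forward convolution f: 3 -> 4 channels, stride 2, circular padding.
    conv = nn.Conv(features=4, kernel_size=(6, 6), strides=(2, 2),
                   padding='CIRCULAR', use_bias=False)
    x = jnp.ones((1, 16, 16, 3))
    variables = conv.init(jax.random.PRNGKey(0), x)  # kernel: (6, 6, 3, 4)

    # Its transpose, reusing the identical kernel parameters: with
    # transpose_kernel=True the expected kernel shape is also (6, 6, 3, 4).
    convT = nn.ConvTranspose(features=3, kernel_size=(6, 6), strides=(2, 2),
                             padding='CIRCULAR', use_bias=False,
                             transpose_kernel=True)

    y = conv.apply(variables, x)
    # For a linear map with no bias, vjp applies the adjoint to the cotangent.
    _, vjp_fn = jax.vjp(lambda v: conv.apply(variables, v), x)
    adjoint_y, = vjp_fn(y)

    print(jnp.allclose(convT.apply(variables, y), adjoint_y, atol=1e-5))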
