@@ -581,6 +581,8 @@ class ConvTranspose(Module):
       for details.
     kernel_init: initializer for the convolutional kernel.
     bias_init: initializer for the bias.
+    transpose_kernel: if True flips spatial axes and swaps the input/output
+      channel axes of the kernel.
   """
   features: int
   kernel_size: Union[int, Tuple[int, ...]]
@@ -594,6 +596,7 @@ class ConvTranspose(Module):
   precision: PrecisionLike = None
   kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_kernel_init
   bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = zeros
+  transpose_kernel: bool = False

   @compact
   def __call__(self, inputs: Array) -> Array:
@@ -636,7 +639,10 @@ def __call__(self, inputs: Array) -> Array:
     strides = self.strides or (1,) * (inputs.ndim - 2)

     in_features = jnp.shape(inputs)[-1]
-    kernel_shape = kernel_size + (in_features, self.features)
+    if self.transpose_kernel:
+      kernel_shape = kernel_size + (self.features, in_features)
+    else:
+      kernel_shape = kernel_size + (in_features, self.features)

     if self.mask is not None and self.mask.shape != kernel_shape:
       raise ValueError('Mask needs to have the same shape as weights. '
@@ -667,6 +673,7 @@ def __call__(self, inputs: Array) -> Array:
         strides,
         padding_lax,
         rhs_dilation=self.kernel_dilation,
+        transpose_kernel=self.transpose_kernel,
         precision=self.precision)

     if self.padding == 'CIRCULAR':
@@ -689,12 +696,20 @@ def __call__(self, inputs: Array) -> Array:
           - (y_dim - x_dim) % (2 * x_dim)
           for y_dim, x_dim in zip(y.shape[1:-1], scaled_x_dims)
       ]
-      # Divide the padding equaly between left and right. The choice to put
-      # "+1" on the left (and not on the right) represents a convention for
-      # aligning even-sized kernels.
-      total_pad = [
-          ((size_diff + 1) // 2, size_diff // 2) for size_diff in size_diffs
-      ]
+      if self.transpose_kernel:
+        # If the kernel is transposed, the "+1" is put on the right to
+        # mirror the regular convolution. If the same kernel parameters are used
+        # as for Conv, this layer then computes the proper transpose convolution.
+        total_pad = [
+            (size_diff // 2, (size_diff + 1) // 2) for size_diff in size_diffs
+        ]
+      else:
+        # Divide the padding equally between left and right. The choice to put
+        # "+1" on the left (and not on the right) represents a convention for
+        # aligning even-sized kernels.
+        total_pad = [
+            ((size_diff + 1) // 2, size_diff // 2) for size_diff in size_diffs
+        ]
       y = jnp.pad(y, [(0, 0)] + total_pad + [(0, 0)])
       # Wrap the result periodically around each spatial dimension,
       # one by one.
0 commit comments