Skip to content

Commit 8c5a7b9

Browse files
author
DarrenZhang01
committed
Add the helper function _eval_output_shape.
1 parent acd76cc commit 8c5a7b9

File tree

1 file changed

+14
-0
lines changed

1 file changed

+14
-0
lines changed

trax/tf_numpy/extensions/extensions.py

+14
Original file line numberDiff line numberDiff line change
@@ -565,6 +565,20 @@ def tf_dot_general(lhs, rhs, dimension_numbers):
565565
return tf.einsum(equation, lhs, rhs)
566566

567567

568+
# TODO (Zhibo Zhang): Run pylint and complement the docstring.
569+
def _eval_output_shape(lhs_shape, rhs_shape, padding, window_strides):
570+
""" Evaluate the output shape in for transpose convolutions.
571+
"""
572+
output_shape = [lhs_shape[0]]
573+
for i in range(1, len(lhs_shape) - 1):
574+
if padding == "SAME":
575+
output_shape.append((lhs_shape[i] - 1) * window_strides[i-1] + rhs_shape[i])
576+
if padding == "VALID":
577+
output_shape.append((lhs_shape[i] - 1) * window_strides[i-1])
578+
output_shape.append(lhs_shape[-1])
579+
return tf.constant(output_shape)
580+
581+
568582
# TODO (Zhibo Zhang): Run pylint and complement the docstring.
569583
def _conv_general_param_type_converter(window_strides, lhs_dilation, rhs_dilation):
570584
""" Convert the inputs strides, lhs_dilation, rhs_dilation to the standard

0 commit comments

Comments
 (0)