Skip to content

Commit ea29343

Browse files
author
DarrenZhang01
committed
Define inner functions as pointed out in the code review.
1 parent e169591 commit ea29343

File tree

1 file changed

+18
-20
lines changed

1 file changed

+18
-20
lines changed

trax/tf_numpy/extensions/extensions.py

+18-20
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,9 @@
3939
tf.int64, tf.int32, tf.int16, tf.int8, tf.uint8, tf.uint16, tf.uint32,
4040
tf.uint64
4141
]
42+
_tf_nn_APIs = {1: [nn.conv1d, nn.conv1d_transpose],
43+
2: [nn.conv2d, nn.conv2d_transpose],
44+
3: [nn.conv3d, nn.conv3d_transpose]}
4245

4346

4447
def most_precise_int_dtype(x):
@@ -583,22 +586,20 @@ def _eval_output_shape(lhs_shape, rhs_shape, padding, window_strides):
583586

584587
# TODO (DarrenZhang01): Complement the docstring.
585588
def _conv_general_param_type_converter(window_strides, lhs_dilation,
586-
rhs_dilation):
589+
rhs_dilation, dim):
587590
""" Convert the inputs strides, lhs_dilation, rhs_dilation to the standard
588591
TF conv inputs.
589592
For example,
590593
in the 3D case, if lhs_dilation = 2, then convert it to [2, 2, 2]
591594
if lhs_dilation = (2, 2, 2), convert it also to [2, 2, 2]
592595
"""
593-
strides = [window_strides] * dim if isinstance(window_strides, int) else \
594-
list(window_strides)
595-
if lhs_dilation:
596-
lhs_dilation = [lhs_dilation] * dim if isinstance(lhs_dilation, int) else \
597-
list(lhs_dilation)
598-
if rhs_dilation:
599-
rhs_dilation = [rhs_dilation] * dim if isinstance(rhs_dilation, int) else \
600-
list(rhs_dilation)
601-
return (strides, lhs_dilation, rhs_dilation)
596+
def _as_list_of_size(item, size):
597+
if item is None:
598+
return None
599+
return [item] * size if isinstance(item, int) else list(item)
600+
return (_as_list_of_size(window_strides, dim),
601+
_as_list_of_size(lhs_dilation, dim),
602+
_as_list_of_size(rhs_dilation, dim))
602603

603604

604605
# TODO (DarrenZhang01): Expand the test cases of general convolution and revise
@@ -634,7 +635,7 @@ def conv_general_dilated(lhs, rhs, window_strides, padding, output_shape,
634635
"to be either 'VALID' or 'SAME', but got: ", padding)
635636
# Convert params from int/Sequence[int] to list of ints.
636637
strides, lhs_dilation, rhs_dilation = _conv_general_param_type_converter(
637-
window_strides, lhs_dilation, rhs_dilation
638+
window_strides, lhs_dilation, rhs_dilation, dim
638639
)
639640
# Preprocess the shapes
640641
dim_maps = {}
@@ -649,24 +650,21 @@ def conv_general_dilated(lhs, rhs, window_strides, padding, output_shape,
649650
dim_maps['N'] = lhs_spec[0]
650651
dim_maps['C'] = lhs_spec[1]
651652

652-
lhs = np.moveaxis(lhs, (dim_maps['N'], dim_maps['C']), (0, dim + 1))
653+
lhs = tf_np.moveaxis(lhs, (dim_maps['N'], dim_maps['C']), (0, dim + 1))
653654
# Adjust the filters, put the dimension 'I' and 'O' at last.
654-
rhs = np.moveaxis(rhs, (dim_maps['O'], dim_maps['I']), (dim + 1, dim))
655+
rhs = tf_np.moveaxis(rhs, (dim_maps['O'], dim_maps['I']), (dim + 1, dim))
655656
spatial_dim_maps = {1: 'W', 2: "HW", 3: "DHW"}
656657
data_format = 'N' + spatial_dim_maps[dim] + 'C'
657-
tf_nn_APIs = {1: [nn.conv1d, nn.conv1d_transpose],
658-
2: [nn.conv2d, nn.conv2d_transpose],
659-
3: [nn.conv3d, nn.conv3d_transpose]}
660658

661659
output = None
662660
if rhs_dilation or (lhs_dilation is None and rhs_dilation is None):
663-
output = tf_nn_APIs[dim][0](lhs, rhs, strides, padding, data_format,
661+
output = _tf_nn_APIs[dim][0](lhs, rhs, strides, padding, data_format,
664662
rhs_dilation)
665663
else:
666-
output = tf_nn_APIs[dim][1](lhs, rhs, tf.constant(output_shape), strides,
664+
output = _tf_nn_APIs[dim][1](lhs, rhs, tf.constant(output_shape), strides,
667665
padding, data_format, lhs_dilation)
668-
output = np.moveaxis(output, (0, dim + 1), (dim_maps['N'], dim_maps['C']))
669-
return np.asarray(output)
666+
output = tf_np.moveaxis(output, (0, dim + 1), (dim_maps['N'], dim_maps['C']))
667+
return tf_np.asarray(output)
670668

671669

672670
def conv(inp,

0 commit comments

Comments
 (0)