
Commit a13a238

DarrenZhang01 committed: Revise the rest of the issues according to the code review.
1 parent ea29343 commit a13a238

2 files changed (+66, -53 lines)


trax/tf_numpy/extensions/extensions.py (+41, -33)
@@ -25,7 +25,6 @@
 import threading
 import numpy as np
 import six
-from more_itertools import sort_together

 import tensorflow.compat.v2 as tf
 from tensorflow import nn
@@ -569,22 +568,6 @@ def tf_dot_general(lhs, rhs, dimension_numbers):
   return tf.einsum(equation, lhs, rhs)


-# TODO (DarrenZhang01): Complement the docstring.
-def _eval_output_shape(lhs_shape, rhs_shape, padding, window_strides):
-  """ Evaluate the output shape for transpose convolutions.
-  """
-  output_shape = [lhs_shape[0]]
-  for i in range(1, len(lhs_shape) - 1):
-    if padding == "SAME":
-      output_shape.append((lhs_shape[i] - 1) * window_strides[i-1] +
-                          rhs_shape[i])
-    if padding == "VALID":
-      output_shape.append((lhs_shape[i] - 1) * window_strides[i-1])
-  output_shape.append(lhs_shape[-1])
-  return tf.constant(output_shape)
-
-
-# TODO (DarrenZhang01): Complement the docstring.
 def _conv_general_param_type_converter(window_strides, lhs_dilation,
                                        rhs_dilation, dim):
   """ Convert the inputs strides, lhs_dilation, rhs_dilation to the standard
@@ -607,32 +590,58 @@ def _as_list_of_size(item, size):
 # TODO (DarrenZhang01): Support feature_group_count, batch_group_count and
 # precision, and allow lhs_dilation and rhs_dilation to happen at the
 # same time.
-def conv_general_dilated(lhs, rhs, window_strides, padding, output_shape,
-                         lhs_dilation=None, rhs_dilation=None,
-                         dimension_numbers=None, feature_group_count=1,
-                         batch_group_count=1, precision=None):
-  """ A general conv API that integrates normal conv, deconvolution,
-  dilated convolution, etc."""
+def tf_conv_general_dilated(lhs, rhs, window_strides, padding, output_shape,
+                            lhs_dilation=None, rhs_dilation=None,
+                            dimension_numbers=None, feature_group_count=1,
+                            batch_group_count=1, precision=None):
+  """ A general conv API for TensorFlow.
+
+  According to the JAX version:
+  https://jax.readthedocs.io/en/stable/_autosummary/jax.lax.conv_general_dilated.html
+
+  Args: (Use JAX documentation as a reference)
+    lhs: a rank n+2 dimensional input array.
+    rhs: a rank n+2 dimensional array of kernel weights.
+    window_strides: a sequence of n integers, representing the inter-window
+                    strides.
+    padding: either the string 'SAME', the string 'VALID', or a sequence of n
+             (low, high) integer pairs that give the padding to apply before
+             and after each spatial dimension.
+    output_shape: the output shape of the convolution.
+    lhs_dilation: None, or a sequence of n integers, giving the dilation factor
+                  to apply in each spatial dimension of lhs. LHS dilation is
+                  also known as transposed convolution.
+    rhs_dilation: None, or a sequence of n integers, giving the dilation factor
+                  to apply in each spatial dimension of rhs. RHS dilation is
+                  also known as atrous convolution.
+    dimension_numbers: either None, a ConvDimensionNumbers object, or a 3-tuple
+                       (lhs_spec, rhs_spec, out_spec), where each element is a
+                       string of length n+2.
+    feature_group_count: integer, default 1.
+    batch_group_count: integer, default 1.
+    precision: Optional. Either None, which means the default precision for the
+               backend, or a Precision enum value.
+  """
   dim = None
   lhs_spec, rhs_spec, out_spec = dimension_numbers
   if lhs_spec != out_spec:
     raise ValueError("Current implementation requires the `data_format` of the "
-        "inputs and outputs to be the same.")
+                     "inputs and outputs to be the same.")
   if len(lhs_spec) >= 6:
     raise ValueError("Current implementation does not support 4 or higher"
-        "dimensional convolution, but got: ", len(lhs_spec) - 2)
+                     "dimensional convolution, but got: ", len(lhs_spec) - 2)
   dim = len(lhs_spec) - 2
   if lhs_dilation and rhs_dilation:
     if lhs_dilation == (1,) * dim and rhs_dilation == (1,) * dim:
       lhs_dilation, rhs_dilation = None, None
     else:
       raise ValueError("Current implementation does not support that "
-          "deconvolution and dilation to be performed at the same "
-          "time, but got lhs_dilation: {}, rhs_dilation: {}".format(
-              lhs_dilation, rhs_dilation))
+                       "deconvolution and dilation to be performed at the same "
+                       "time, but got lhs_dilation: {}, rhs_dilation: {}"
+                       .format(lhs_dilation, rhs_dilation))
   if padding not in ["SAME", "VALID"]:
     raise ValueError("Current implementation requires the padding parameter"
-        "to be either 'VALID' or 'SAME', but got: ", padding)
+                     "to be either 'VALID' or 'SAME', but got: ", padding)
   # Convert params from int/Sequence[int] to list of ints.
   strides, lhs_dilation, rhs_dilation = _conv_general_param_type_converter(
       window_strides, lhs_dilation, rhs_dilation, dim
@@ -656,15 +665,14 @@ def conv_general_dilated(lhs, rhs, window_strides, padding, output_shape,
   spatial_dim_maps = {1: 'W', 2: "HW", 3: "DHW"}
   data_format = 'N' + spatial_dim_maps[dim] + 'C'

-  output = None
   if rhs_dilation or (lhs_dilation is None and rhs_dilation is None):
     output = _tf_nn_APIs[dim][0](lhs, rhs, strides, padding, data_format,
-        rhs_dilation)
+                                 rhs_dilation)
   else:
     output = _tf_nn_APIs[dim][1](lhs, rhs, tf.constant(output_shape), strides,
-        padding, data_format, lhs_dilation)
+                                 padding, data_format, lhs_dilation)
   output = tf_np.moveaxis(output, (0, dim + 1), (dim_maps['N'], dim_maps['C']))
-  return tf_np.asarray(output)
+  return output


 def conv(inp,
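
For orientation, here is a minimal sketch of calling the renamed API after this change; the concrete shapes, dtype, and the NHWC/HWIO dimension_numbers are illustrative assumptions, not part of the commit:

import numpy as np
from trax.tf_numpy import extensions

# Assumed example: 2-D convolution, NHWC input, HWIO kernel, stride 1, SAME padding.
lhs = np.ones((1, 9, 10, 3), dtype=np.float32)   # (batch, H, W, in_channels)
rhs = np.ones((3, 3, 3, 8), dtype=np.float32)    # (kH, kW, in_channels, out_channels)
out = extensions.tf_conv_general_dilated(
    lhs, rhs,
    window_strides=(1, 1),
    padding="SAME",
    output_shape=None,          # only consulted on the transposed-conv path
    lhs_dilation=None,
    rhs_dilation=(1, 1),
    dimension_numbers=("NHWC", "HWIO", "NHWC"))
# With SAME padding and unit strides, the spatial dims should be preserved: (1, 9, 10, 8).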

trax/tf_numpy/extensions/extensions_test.py (+25, -20)
@@ -21,16 +21,18 @@

 import functools
 from absl import flags
+import itertools
 from absl.testing import parameterized

+import jax
 from jax import lax
-import jax.numpy as jnp
 import numpy as np
 import tensorflow.compat.v2 as tf

 from trax.tf_numpy import extensions
 import trax.tf_numpy.numpy as tf_np

+
 FLAGS = flags.FLAGS

 flags.DEFINE_bool("requires_tpu", False, "Requires TPU.")
@@ -423,16 +425,16 @@ def test_tf_dot_general(self, lhs_np, rhs_np, dims):
     self.assertAllClose(result, np.array(ans))


-  # TODO (Zhibo Zhang): Run pylint on this function.
   @parameterized.named_parameters([
       ("_lhs_shape={}_rhs_shape={}_strides={}_padding={}"
        "_lhs_dilation={}_rhs_dilation={}"
        "_feature_group_count={}_batch_group_count={}_dims={}"
        "_perms={}".format(lhs_shape, rhs_shape,
-           strides, padding, lhs_dilation, rhs_dilation,
-           feature_group_count, batch_group_count, ",".join(dimension_numbers), perms),
-       lhs_shape, rhs_shape, strides, padding, lhs_dilation, rhs_dilation,
-       feature_group_count, batch_group_count, dimension_numbers, perms)
+                          strides, padding, lhs_dilation, rhs_dilation,
+                          feature_group_count, batch_group_count, ",".join(
+                              dimension_numbers), perms),
+       lhs_shape, rhs_shape, strides, padding, lhs_dilation, rhs_dilation,
+       feature_group_count, batch_group_count, dimension_numbers, perms)
       for batch_group_count, feature_group_count in [(1, 1)]
       for lhs_shape, rhs_shape in [
           ((b * batch_group_count, i * feature_group_count, 9, w),
@@ -442,29 +444,32 @@ def test_tf_dot_general(self, lhs_np, rhs_np, dims):
       for strides in [(1, 1), (2, 1)]
       for padding in ['SAME']
       for lhs_dilation, rhs_dilation in [
-          (None, (1, 1))
+          (None, (1, 1))
       ]
       for dimension_numbers, perms in [
-          (("NHWC", "HWIO", "NHWC"), ([0, 2, 3, 1], [2, 3, 1, 0]))
+          (("NHWC", "HWIO", "NHWC"), ([0, 2, 3, 1], [2, 3, 1, 0]))
       ]])
   def testConvGeneralDilated(self, lhs_shape, rhs_shape, strides,
                              padding, lhs_dilation, rhs_dilation,
                              feature_group_count, batch_group_count,
                              dimension_numbers, perms):
-    tf.print("dimension_numbers: {}".format(dimension_numbers), output_stream=sys.stdout)
     lhs_perm, rhs_perm = perms  # permute to compatible shapes

-    lhs_tf = tf_np.transpose(tf_np.ones(lhs_shape), lhs_perm)
-    rhs_tf = tf_np.transpose(tf_np.ones(rhs_shape), rhs_perm)
-
-    lhs_jax = jnp.transpose(jnp.ones(lhs_shape), lhs_perm)
-    rhs_jax = jnp.transpose(jnp.ones(rhs_shape), rhs_perm)
-
-    jax_conv = jax.lax.conv_general_dilated(lhs_jax, rhs_jax, strides, padding, lhs_dilation,
-        rhs_dilation, dimension_numbers, feature_group_count, batch_group_count)
-
-    tf_conv = lax.conv_general_dilated(lhs_tf, rhs_tf, strides, padding, jax_conv.shape, lhs_dilation,
-        rhs_dilation, dimension_numbers, feature_group_count, batch_group_count)
+    lhs = np.transpose(np.ones(lhs_shape), lhs_perm)
+    rhs = np.transpose(np.ones(rhs_shape), rhs_perm)
+
+    jax_conv = jax.lax.conv_general_dilated(lhs, rhs, strides, padding,
+                                            lhs_dilation, rhs_dilation,
+                                            dimension_numbers,
+                                            feature_group_count,
+                                            batch_group_count)
+
+    tf_conv = extensions.tf_conv_general_dilated(lhs, rhs, strides,
+                                                 padding, jax_conv.shape,
+                                                 lhs_dilation, rhs_dilation,
+                                                 dimension_numbers,
+                                                 feature_group_count,
+                                                 batch_group_count)

     self.assertAllEqual(tf_conv, tf_np.asarray(jax_conv))
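
Outside the test harness, the same parity check can be sketched as a standalone script; the shapes below are illustrative assumptions, and the final comparison assumes both results convert via np.asarray:

import jax
import numpy as np
from trax.tf_numpy import extensions

# Assumed example shapes (not from the commit): NHWC input, HWIO kernel,
# stride 1, SAME padding; mirrors the parameterized test above.
lhs = np.ones((2, 9, 10, 3))
rhs = np.ones((4, 5, 3, 3))
dimension_numbers = ("NHWC", "HWIO", "NHWC")

jax_conv = jax.lax.conv_general_dilated(lhs, rhs, (1, 1), "SAME", None, (1, 1),
                                        dimension_numbers, 1, 1)
tf_conv = extensions.tf_conv_general_dilated(lhs, rhs, (1, 1), "SAME",
                                             jax_conv.shape, None, (1, 1),
                                             dimension_numbers, 1, 1)

# The TF-numpy result should match the JAX reference elementwise.
np.testing.assert_allclose(np.asarray(tf_conv), np.asarray(jax_conv))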