Skip to content

Commit e169591

Browse files
author
DarrenZhang01
committed
Revise some format problems according to pylint.
1 parent 8c5a7b9 commit e169591

File tree

1 file changed

+19
-16
lines changed

1 file changed

+19
-16
lines changed

trax/tf_numpy/extensions/extensions.py

+19-16
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from more_itertools import sort_together
2929

3030
import tensorflow.compat.v2 as tf
31+
from tensorflow import nn
3132

3233
import trax.tf_numpy.numpy as tf_np
3334

@@ -565,22 +566,24 @@ def tf_dot_general(lhs, rhs, dimension_numbers):
565566
return tf.einsum(equation, lhs, rhs)
566567

567568

568-
# TODO (Zhibo Zhang): Run pylint and complement the docstring.
569+
# TODO (DarrenZhang01): Complement the docstring.
569570
def _eval_output_shape(lhs_shape, rhs_shape, padding, window_strides):
570571
""" Evaluate the output shape in for transpose convolutions.
571572
"""
572573
output_shape = [lhs_shape[0]]
573574
for i in range(1, len(lhs_shape) - 1):
574575
if padding == "SAME":
575-
output_shape.append((lhs_shape[i] - 1) * window_strides[i-1] + rhs_shape[i])
576+
output_shape.append((lhs_shape[i] - 1) * window_strides[i-1] +
577+
rhs_shape[i])
576578
if padding == "VALID":
577579
output_shape.append((lhs_shape[i] - 1) * window_strides[i-1])
578580
output_shape.append(lhs_shape[-1])
579581
return tf.constant(output_shape)
580582

581583

582-
# TODO (Zhibo Zhang): Run pylint and complement the docstring.
583-
def _conv_general_param_type_converter(window_strides, lhs_dilation, rhs_dilation):
584+
# TODO (DarrenZhang01): Complement the docstring.
585+
def _conv_general_param_type_converter(window_strides, lhs_dilation,
586+
rhs_dilation):
584587
""" Convert the inputs strides, lhs_dilation, rhs_dilation to the standard
585588
TF conv inputs.
586589
For example,
@@ -598,11 +601,11 @@ def _conv_general_param_type_converter(window_strides, lhs_dilation, rhs_dilatio
598601
return (strides, lhs_dilation, rhs_dilation)
599602

600603

601-
# TODO (Zhibo Zhang): Run pylint and complement the docstring.
602-
# TOTO (Zhibo Zhang): Expand the test cases of general convolution and revise
604+
# TODO (DarrenZhang01): Expand the test cases of general convolution and revise
603605
# the according bugs.
604-
# TODO (Zhibo Zhang): Support feature_group_count, batch_group_count and precision, and
605-
# allow lhs_dilation and rhs_dilation to happen at the same time.
606+
# TODO (DarrenZhang01): Support feature_group_count, batch_group_count and
607+
# precision, and allow lhs_dilation and rhs_dilation to happen at the
608+
# same time.
606609
def conv_general_dilated(lhs, rhs, window_strides, padding, output_shape,
607610
lhs_dilation=None, rhs_dilation=None,
608611
dimension_numbers=None, feature_group_count=1,
@@ -612,26 +615,26 @@ def conv_general_dilated(lhs, rhs, window_strides, padding, output_shape,
612615
dim = None
613616
lhs_spec, rhs_spec, out_spec = dimension_numbers
614617
if lhs_spec != out_spec:
615-
raise TypeError("Current implementation requires the `data_format` of the "
618+
raise ValueError("Current implementation requires the `data_format` of the "
616619
"inputs and outputs to be the same.")
617620
if len(lhs_spec) >= 6:
618-
raise TypeError("Current implmentation does not support 4 or higher"
621+
raise ValueError("Current implmentation does not support 4 or higher"
619622
"dimensional convolution, but got: ", len(lhs_spec) - 2)
620623
dim = len(lhs_spec) - 2
621624
if lhs_dilation and rhs_dilation:
622625
if lhs_dilation == (1,) * dim and rhs_dilation == (1,) * dim:
623626
lhs_dilation, rhs_dilation = None, None
624627
else:
625-
raise TypeError("Current implementation does not support that deconvolution"
626-
"and dilation to be performed at the same time, but got"
627-
" lhs_dilation: {}, rhs_dilation: {}".format(lhs_dilation,
628-
rhs_dilation))
628+
raise ValueError("Current implementation does not support that "
629+
"deconvolution and dilation to be performed at the same "
630+
"time, but got lhs_dilation: {}, rhs_dilation: {}".format(
631+
lhs_dilation, rhs_dilation))
629632
if padding not in ["SAME", "VALID"]:
630-
raise TypeError("Current implementation requires the padding parameter"
633+
raise ValueError("Current implementation requires the padding parameter"
631634
"to be either 'VALID' or 'SAME', but got: ", padding)
632635
# Convert params from int/Sequence[int] to list of ints.
633636
strides, lhs_dilation, rhs_dilation = _conv_general_param_type_converter(
634-
window_strides, lhs_dilation, rhs_dilation
637+
window_strides, lhs_dilation, rhs_dilation
635638
)
636639
# Preprocess the shapes
637640
dim_maps = {}

0 commit comments

Comments
 (0)