28
28
from more_itertools import sort_together
29
29
30
30
import tensorflow .compat .v2 as tf
31
+ from tensorflow import nn
31
32
32
33
import trax .tf_numpy .numpy as tf_np
33
34
@@ -565,22 +566,24 @@ def tf_dot_general(lhs, rhs, dimension_numbers):
565
566
return tf .einsum (equation , lhs , rhs )
566
567
567
568
568
- # TODO (Zhibo Zhang ): Run pylint and complement the docstring.
569
+ # TODO (DarrenZhang01 ): Complement the docstring.
569
570
def _eval_output_shape (lhs_shape , rhs_shape , padding , window_strides ):
570
571
""" Evaluate the output shape in for transpose convolutions.
571
572
"""
572
573
output_shape = [lhs_shape [0 ]]
573
574
for i in range (1 , len (lhs_shape ) - 1 ):
574
575
if padding == "SAME" :
575
- output_shape .append ((lhs_shape [i ] - 1 ) * window_strides [i - 1 ] + rhs_shape [i ])
576
+ output_shape .append ((lhs_shape [i ] - 1 ) * window_strides [i - 1 ] +
577
+ rhs_shape [i ])
576
578
if padding == "VALID" :
577
579
output_shape .append ((lhs_shape [i ] - 1 ) * window_strides [i - 1 ])
578
580
output_shape .append (lhs_shape [- 1 ])
579
581
return tf .constant (output_shape )
580
582
581
583
582
- # TODO (Zhibo Zhang): Run pylint and complement the docstring.
583
- def _conv_general_param_type_converter (window_strides , lhs_dilation , rhs_dilation ):
584
+ # TODO (DarrenZhang01): Complement the docstring.
585
+ def _conv_general_param_type_converter (window_strides , lhs_dilation ,
586
+ rhs_dilation ):
584
587
""" Convert the inputs strides, lhs_dilation, rhs_dilation to the standard
585
588
TF conv inputs.
586
589
For example,
@@ -598,11 +601,11 @@ def _conv_general_param_type_converter(window_strides, lhs_dilation, rhs_dilatio
598
601
return (strides , lhs_dilation , rhs_dilation )
599
602
600
603
601
- # TODO (Zhibo Zhang): Run pylint and complement the docstring.
602
- # TOTO (Zhibo Zhang): Expand the test cases of general convolution and revise
604
+ # TODO (DarrenZhang01): Expand the test cases of general convolution and revise
603
605
# the according bugs.
604
- # TODO (Zhibo Zhang): Support feature_group_count, batch_group_count and precision, and
605
- # allow lhs_dilation and rhs_dilation to happen at the same time.
606
+ # TODO (DarrenZhang01): Support feature_group_count, batch_group_count and
607
+ # precision, and allow lhs_dilation and rhs_dilation to happen at the
608
+ # same time.
606
609
def conv_general_dilated (lhs , rhs , window_strides , padding , output_shape ,
607
610
lhs_dilation = None , rhs_dilation = None ,
608
611
dimension_numbers = None , feature_group_count = 1 ,
@@ -612,26 +615,26 @@ def conv_general_dilated(lhs, rhs, window_strides, padding, output_shape,
612
615
dim = None
613
616
lhs_spec , rhs_spec , out_spec = dimension_numbers
614
617
if lhs_spec != out_spec :
615
- raise TypeError ("Current implementation requires the `data_format` of the "
618
+ raise ValueError ("Current implementation requires the `data_format` of the "
616
619
"inputs and outputs to be the same." )
617
620
if len (lhs_spec ) >= 6 :
618
- raise TypeError ("Current implmentation does not support 4 or higher"
621
+ raise ValueError ("Current implmentation does not support 4 or higher"
619
622
"dimensional convolution, but got: " , len (lhs_spec ) - 2 )
620
623
dim = len (lhs_spec ) - 2
621
624
if lhs_dilation and rhs_dilation :
622
625
if lhs_dilation == (1 ,) * dim and rhs_dilation == (1 ,) * dim :
623
626
lhs_dilation , rhs_dilation = None , None
624
627
else :
625
- raise TypeError ("Current implementation does not support that deconvolution "
626
- " and dilation to be performed at the same time, but got "
627
- " lhs_dilation: {}, rhs_dilation: {}" .format (lhs_dilation ,
628
- rhs_dilation ))
628
+ raise ValueError ("Current implementation does not support that "
629
+ "deconvolution and dilation to be performed at the same "
630
+ "time, but got lhs_dilation: {}, rhs_dilation: {}" .format (
631
+ lhs_dilation , rhs_dilation ))
629
632
if padding not in ["SAME" , "VALID" ]:
630
- raise TypeError ("Current implementation requires the padding parameter"
633
+ raise ValueError ("Current implementation requires the padding parameter"
631
634
"to be either 'VALID' or 'SAME', but got: " , padding )
632
635
# Convert params from int/Sequence[int] to list of ints.
633
636
strides , lhs_dilation , rhs_dilation = _conv_general_param_type_converter (
634
- window_strides , lhs_dilation , rhs_dilation
637
+ window_strides , lhs_dilation , rhs_dilation
635
638
)
636
639
# Preprocess the shapes
637
640
dim_maps = {}
0 commit comments