    tf.int64, tf.int32, tf.int16, tf.int8, tf.uint8, tf.uint16, tf.uint32,
    tf.uint64
]
+ _tf_nn_APIs = {1: [nn.conv1d, nn.conv1d_transpose],
+                2: [nn.conv2d, nn.conv2d_transpose],
+                3: [nn.conv3d, nn.conv3d_transpose]}


def most_precise_int_dtype(x):
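A quick sketch (not part of the commit) of how the new module-level `_tf_nn_APIs` table is indexed, assuming `nn` refers to `tf.nn` as elsewhere in this file: the spatial dimensionality selects the entry, index 0 is the forward convolution and index 1 the transposed one.

  forward_2d, transpose_2d = _tf_nn_APIs[2]   # tf.nn.conv2d, tf.nn.conv2d_transpose
  forward_3d = _tf_nn_APIs[3][0]              # tf.nn.conv3d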
@@ -583,22 +586,20 @@ def _eval_output_shape(lhs_shape, rhs_shape, padding, window_strides):

# TODO (DarrenZhang01): Complement the docstring.
def _conv_general_param_type_converter(window_strides, lhs_dilation,
-                                       rhs_dilation):
+                                       rhs_dilation, dim):
  """Convert the inputs window_strides, lhs_dilation and rhs_dilation to the
  standard TF conv inputs.

  For example, in the 3D case, if lhs_dilation = 2, convert it to [2, 2, 2];
  if lhs_dilation = (2, 2, 2), convert it also to [2, 2, 2].
  """
-  strides = [window_strides] * dim if isinstance(window_strides, int) else \
-      list(window_strides)
-  if lhs_dilation:
-    lhs_dilation = [lhs_dilation] * dim if isinstance(lhs_dilation, int) else \
-        list(lhs_dilation)
-  if rhs_dilation:
-    rhs_dilation = [rhs_dilation] * dim if isinstance(rhs_dilation, int) else \
-        list(rhs_dilation)
-  return (strides, lhs_dilation, rhs_dilation)
+  def _as_list_of_size(item, size):
+    if item is None:
+      return None
+    return [item] * size if isinstance(item, int) else list(item)
+  return (_as_list_of_size(window_strides, dim),
+          _as_list_of_size(lhs_dilation, dim),
+          _as_list_of_size(rhs_dilation, dim))


# TODO (DarrenZhang01): Expand the test cases of general convolution and revise
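As an illustration of the converter's behavior under the new signature, a minimal sketch (the argument values here are chosen purely for illustration):

  # Scalars are broadcast to `dim` entries, sequences are listified, None passes through.
  strides, lhs_d, rhs_d = _conv_general_param_type_converter(2, None, (1, 2, 1), 3)
  # strides == [2, 2, 2]; lhs_d is None; rhs_d == [1, 2, 1]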
@@ -634,7 +635,7 @@ def conv_general_dilated(lhs, rhs, window_strides, padding, output_shape,
                     "to be either 'VALID' or 'SAME', but got: ", padding)
  # Convert params from int/Sequence[int] to list of ints.
  strides, lhs_dilation, rhs_dilation = _conv_general_param_type_converter(
-      window_strides, lhs_dilation, rhs_dilation
+      window_strides, lhs_dilation, rhs_dilation, dim
  )
  # Preprocess the shapes
  dim_maps = {}
@@ -649,24 +650,21 @@ def conv_general_dilated(lhs, rhs, window_strides, padding, output_shape,
  dim_maps['N'] = lhs_spec[0]
  dim_maps['C'] = lhs_spec[1]

-  lhs = np.moveaxis(lhs, (dim_maps['N'], dim_maps['C']), (0, dim + 1))
+  lhs = tf_np.moveaxis(lhs, (dim_maps['N'], dim_maps['C']), (0, dim + 1))
  # Adjust the filters, put the dimension 'I' and 'O' at last.
-  rhs = np.moveaxis(rhs, (dim_maps['O'], dim_maps['I']), (dim + 1, dim))
+  rhs = tf_np.moveaxis(rhs, (dim_maps['O'], dim_maps['I']), (dim + 1, dim))
  spatial_dim_maps = {1: 'W', 2: "HW", 3: "DHW"}
  data_format = 'N' + spatial_dim_maps[dim] + 'C'
-  tf_nn_APIs = {1: [nn.conv1d, nn.conv1d_transpose],
-                2: [nn.conv2d, nn.conv2d_transpose],
-                3: [nn.conv3d, nn.conv3d_transpose]}

  output = None
  if rhs_dilation or (lhs_dilation is None and rhs_dilation is None):
-    output = tf_nn_APIs[dim][0](lhs, rhs, strides, padding, data_format,
+    output = _tf_nn_APIs[dim][0](lhs, rhs, strides, padding, data_format,
                                 rhs_dilation)
  else:
-    output = tf_nn_APIs[dim][1](lhs, rhs, tf.constant(output_shape), strides,
+    output = _tf_nn_APIs[dim][1](lhs, rhs, tf.constant(output_shape), strides,
                                 padding, data_format, lhs_dilation)
-  output = np.moveaxis(output, (0, dim + 1), (dim_maps['N'], dim_maps['C']))
-  return np.asarray(output)
+  output = tf_np.moveaxis(output, (0, dim + 1), (dim_maps['N'], dim_maps['C']))
+  return tf_np.asarray(output)


def conv(inp,
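For context on the layout shuffling the last hunk performs, here is a small self-contained sketch of the 2-D case, assuming a JAX-style 'NCHW' input spec so that dim_maps['N'] == 0 and dim_maps['C'] == 1 (plain numpy stands in for this module's `tf_np`):

  import numpy as np

  dim = 2
  lhs = np.zeros((8, 3, 32, 32))                  # batch, channels, height, width
  nhwc = np.moveaxis(lhs, (0, 1), (0, dim + 1))   # -> (8, 32, 32, 3), i.e. 'NHWC'
  # ... the convolution runs in 'NHWC', then the output is moved back:
  out = np.moveaxis(nhwc, (0, dim + 1), (0, 1))   # -> (8, 3, 32, 32), caller's layout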