25
25
import threading
26
26
import numpy as np
27
27
import six
28
- from more_itertools import sort_together
29
28
30
29
import tensorflow .compat .v2 as tf
31
30
from tensorflow import nn
@@ -569,22 +568,6 @@ def tf_dot_general(lhs, rhs, dimension_numbers):
569
568
return tf .einsum (equation , lhs , rhs )
570
569
571
570
572
- # TODO (DarrenZhang01): Complement the docstring.
573
- def _eval_output_shape (lhs_shape , rhs_shape , padding , window_strides ):
574
- """ Evaluate the output shape in for transpose convolutions.
575
- """
576
- output_shape = [lhs_shape [0 ]]
577
- for i in range (1 , len (lhs_shape ) - 1 ):
578
- if padding == "SAME" :
579
- output_shape .append ((lhs_shape [i ] - 1 ) * window_strides [i - 1 ] +
580
- rhs_shape [i ])
581
- if padding == "VALID" :
582
- output_shape .append ((lhs_shape [i ] - 1 ) * window_strides [i - 1 ])
583
- output_shape .append (lhs_shape [- 1 ])
584
- return tf .constant (output_shape )
585
-
586
-
587
- # TODO (DarrenZhang01): Complement the docstring.
588
571
def _conv_general_param_type_converter (window_strides , lhs_dilation ,
589
572
rhs_dilation , dim ):
590
573
""" Convert the inputs strides, lhs_dilation, rhs_dilation to the standard
@@ -607,32 +590,58 @@ def _as_list_of_size(item, size):
607
590
# TODO (DarrenZhang01): Support feature_group_count, batch_group_count and
608
591
# precision, and allow lhs_dilation and rhs_dilation to happen at the
609
592
# same time.
610
- def conv_general_dilated (lhs , rhs , window_strides , padding , output_shape ,
611
- lhs_dilation = None , rhs_dilation = None ,
612
- dimension_numbers = None , feature_group_count = 1 ,
613
- batch_group_count = 1 , precision = None ):
614
- """ A general conv API that integrates normal conv, deconvolution,
615
- dilated convolution, etc."""
593
+ def tf_conv_general_dilated (lhs , rhs , window_strides , padding , output_shape ,
594
+ lhs_dilation = None , rhs_dilation = None ,
595
+ dimension_numbers = None , feature_group_count = 1 ,
596
+ batch_group_count = 1 , precision = None ):
597
+ """ A general conv API for TensorFlow.
598
+
599
+ According JAX version:
600
+ https://jax.readthedocs.io/en/stable/_autosummary/jax.lax.conv_general_dilated.html
601
+
602
+ Args: (Use JAX documentation as a reference)
603
+ lhs: a rank n+2 dimensional input array.
604
+ rhs: a rank n+2 dimensional array of kernel weights.
605
+ window_strides: a sequence of n integers, representing the inter-window
606
+ strides.
607
+ padding: either the string ‘SAME’, the string ‘VALID’, or a sequence of n
608
+ (low, high) integer pairs that give the padding to apply before and
609
+ after each spatial dimension.
610
+ output_shape: the output shape of the convolution.
611
+ lhs_dilation: None, or a sequence of n integers, giving the dilation factor
612
+ to apply in each spatial dimension of lhs. LHS dilation is
613
+ also known as transposed convolution.
614
+ rhs_dilation: None, or a sequence of n integers, giving the dilation factor
615
+ to apply in each spatial dimension of rhs. RHS dilation is
616
+ also known as atrous convolution.
617
+ dimension_numbers: either None, a ConvDimensionNumbers object, or a 3-tuple
618
+ (lhs_spec, rhs_spec, out_spec), where each element is a
619
+ string of length n+2.
620
+ feature_group_count: integer, default 1.
621
+ batch_group_count: integer, default 1.
622
+ precision: Optional. Either None, which means the default precision for the
623
+ backend, or a Precision enum value.
624
+ """
616
625
dim = None
617
626
lhs_spec , rhs_spec , out_spec = dimension_numbers
618
627
if lhs_spec != out_spec :
619
628
raise ValueError ("Current implementation requires the `data_format` of the "
620
- "inputs and outputs to be the same." )
629
+ "inputs and outputs to be the same." )
621
630
if len (lhs_spec ) >= 6 :
622
631
raise ValueError ("Current implmentation does not support 4 or higher"
623
- "dimensional convolution, but got: " , len (lhs_spec ) - 2 )
632
+ "dimensional convolution, but got: " , len (lhs_spec ) - 2 )
624
633
dim = len (lhs_spec ) - 2
625
634
if lhs_dilation and rhs_dilation :
626
635
if lhs_dilation == (1 ,) * dim and rhs_dilation == (1 ,) * dim :
627
636
lhs_dilation , rhs_dilation = None , None
628
637
else :
629
638
raise ValueError ("Current implementation does not support that "
630
- "deconvolution and dilation to be performed at the same "
631
- "time, but got lhs_dilation: {}, rhs_dilation: {}" . format (
632
- lhs_dilation , rhs_dilation ))
639
+ "deconvolution and dilation to be performed at the same "
640
+ "time, but got lhs_dilation: {}, rhs_dilation: {}"
641
+ . format ( lhs_dilation , rhs_dilation ))
633
642
if padding not in ["SAME" , "VALID" ]:
634
643
raise ValueError ("Current implementation requires the padding parameter"
635
- "to be either 'VALID' or 'SAME', but got: " , padding )
644
+ "to be either 'VALID' or 'SAME', but got: " , padding )
636
645
# Convert params from int/Sequence[int] to list of ints.
637
646
strides , lhs_dilation , rhs_dilation = _conv_general_param_type_converter (
638
647
window_strides , lhs_dilation , rhs_dilation , dim
@@ -656,15 +665,14 @@ def conv_general_dilated(lhs, rhs, window_strides, padding, output_shape,
656
665
spatial_dim_maps = {1 : 'W' , 2 : "HW" , 3 : "DHW" }
657
666
data_format = 'N' + spatial_dim_maps [dim ] + 'C'
658
667
659
- output = None
660
668
if rhs_dilation or (lhs_dilation is None and rhs_dilation is None ):
661
669
output = _tf_nn_APIs [dim ][0 ](lhs , rhs , strides , padding , data_format ,
662
- rhs_dilation )
670
+ rhs_dilation )
663
671
else :
664
672
output = _tf_nn_APIs [dim ][1 ](lhs , rhs , tf .constant (output_shape ), strides ,
665
- padding , data_format , lhs_dilation )
673
+ padding , data_format , lhs_dilation )
666
674
output = tf_np .moveaxis (output , (0 , dim + 1 ), (dim_maps ['N' ], dim_maps ['C' ]))
667
- return tf_np . asarray ( output )
675
+ return output
668
676
669
677
670
678
def conv (inp ,
0 commit comments