@@ -1326,6 +1326,26 @@ def lesser_equal(x, y):
    return tf.less_equal(x, y)


+ def where (x ):
1330
+ """Returns locations of true values in a boolean tensor.
1331
+
1332
+ This operation returns the coordinates of true elements in input. The coordinates are
1333
+ returned in a 2-D tensor where the first dimension (rows) represents the number of
1334
+ true elements, and the second dimension (columns) represents the coordinates of the
1335
+ true elements. Keep in mind, the shape of the output tensor can vary depending on
1336
+ how many true values there are in input.
1337
+
1338
+ # Arguments
1339
+ x: input bool tensor.
1340
+
1341
+ # Returns
1342
+ An integer tensor of indices.
1343
+
1344
+ """
1345
+ x = tf .cast (x , tf .bool )
1346
+ return tf .where (x )
1347
+
1348
+
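For reference, a minimal usage sketch of the new backend function (assuming it is re-exported through `keras.backend` as `K.where`; the expected output in the comment follows TensorFlow's single-argument `tf.where` semantics):

    import numpy as np
    from keras import backend as K

    # Nonzero entries are treated as true after the internal cast to tf.bool.
    mask = K.variable(np.array([[1., 0.], [0., 1.], [1., 0.]]))
    indices = K.where(mask)
    print(K.eval(indices))
    # [[0 0]
    #  [1 1]
    #  [2 0]]  -- one row of coordinates per true element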
def maximum(x, y):
    """Element-wise maximum of two tensors.

@@ -1587,13 +1607,27 @@ def tile(x, n):
    return tf.tile(x, n)


-def flatten(x):
-    """Flatten a tensor.
+def flatten(x, outdim=1):
+    """Returns a view of this tensor with `outdim` dimensions, whose shape
+    for the first `outdim - 1` dimensions will be the same as `x`, and whose
+    shape in the remaining dimension will be expanded to fit all the data
+    from `x`.
+
+    # Arguments
+        x: input tensor.
+        outdim: number of dimensions in the output tensor.

    # Returns
-        A tensor, reshaped into 1-D
+        A tensor, reshaped to `outdim` dimensions.
    """
-    return tf.reshape(x, [-1])
+    if outdim > 1:
+        shape = concatenate([tf.shape(x)[:outdim - 1], variable([-1], dtype='int32')])
+    else:
+        shape = [-1]
+    return tf.reshape(x, shape)


def batch_flatten(x):
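A quick sketch of the new `outdim` behaviour: flattening a `(2, 3, 4)` tensor with `outdim=2` keeps the first axis and collapses the rest (assuming the function is called through the TensorFlow backend as `K.flatten`):

    import numpy as np
    from keras import backend as K

    x = K.variable(np.zeros((2, 3, 4)))
    print(K.eval(K.flatten(x)).shape)            # (24,)    -- default outdim=1
    print(K.eval(K.flatten(x, outdim=2)).shape)  # (2, 12)  -- keep first axis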
@@ -2023,7 +2057,10 @@ def rnn(step_function, inputs, initial_states,

    # TODO: remove later.
    if hasattr(tf, 'select'):
-        tf.where = tf.select
+        where_op = tf.select
+    else:
+        where_op = tf.where
+
    if hasattr(tf, 'stack'):
        stack = tf.stack
        unstack = tf.unstack
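The version check above replaces the earlier monkey-patch of the `tf` namespace with a local alias: older TensorFlow releases expose the conditional-select op as `tf.select`, newer ones as `tf.where`, and binding whichever exists to `where_op` avoids mutating the imported module. A one-line sketch of the same check:

    where_op = tf.select if hasattr(tf, 'select') else tf.where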
@@ -2069,14 +2106,14 @@ def rnn(step_function, inputs, initial_states,
        else:
            prev_output = successive_outputs[-1]

-        output = tf.where(tiled_mask_t, output, prev_output)
+        output = where_op(tiled_mask_t, output, prev_output)

        return_states = []
        for state, new_state in zip(states, new_states):
            # (see earlier comment for tile explanation)
            tiled_mask_t = tf.tile(mask_t,
                                   stack([1, tf.shape(new_state)[1]]))
-            return_states.append(tf.where(tiled_mask_t,
+            return_states.append(where_op(tiled_mask_t,
                                          new_state,
                                          state))
        states = return_states
@@ -2145,8 +2182,8 @@ def _step(time, output_ta_t, *states):
                new_state.set_shape(state.get_shape())
            tiled_mask_t = tf.tile(mask_t,
                                   stack([1, tf.shape(output)[1]]))
-            output = tf.where(tiled_mask_t, output, states[0])
-            new_states = [tf.where(tiled_mask_t, new_states[i], states[i]) for i in range(len(states))]
+            output = where_op(tiled_mask_t, output, states[0])
+            new_states = [where_op(tiled_mask_t, new_states[i], states[i]) for i in range(len(states))]
            output_ta_t = output_ta_t.write(time, output)
            return (time + 1, output_ta_t) + tuple(new_states)
        else:
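For context, the `where_op(mask, new, old)` calls above implement per-sample masking: wherever the mask is true the freshly computed value is kept, otherwise the previous value is carried forward. A standalone sketch of the pattern in plain TensorFlow (hypothetical values, assuming an eager-enabled run):

    import tensorflow as tf

    mask_t = tf.constant([[True], [False]])     # (samples, 1) step mask
    new = tf.constant([[1., 2.], [3., 4.]])     # freshly computed output
    old = tf.constant([[9., 9.], [9., 9.]])     # carried-over previous output

    # Broadcast the mask across the feature axis, mirroring the tf.tile calls above.
    tiled = tf.tile(mask_t, tf.stack([1, tf.shape(new)[1]]))
    out = tf.where(tiled, new, old)
    # Row 0 takes the new values, row 1 keeps the old ones:
    # [[1. 2.]
    #  [9. 9.]]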