Skip to content

Commit 5b84c45

Browse files
author
Bojan Karlas
committed
Correct calculation of metrics with masking (keras-team#2260)
* Reshape y_pred and y_true from (samples, timesteps, ... ) to (samples * timesteps, ... ) * Filter out masked timesteps from y_pred and y_true * Added K.where() and extended functionality of K.flatten() * Added/changed corresponding tests
1 parent ab3b93e commit 5b84c45

File tree

5 files changed

+147
-16
lines changed

5 files changed

+147
-16
lines changed

keras/backend/tensorflow_backend.py

+46-9
Original file line numberDiff line numberDiff line change
@@ -1326,6 +1326,26 @@ def lesser_equal(x, y):
13261326
return tf.less_equal(x, y)
13271327

13281328

1329+
def where(x):
1330+
"""Returns locations of true values in a boolean tensor.
1331+
1332+
This operation returns the coordinates of true elements in input. The coordinates are
1333+
returned in a 2-D tensor where the first dimension (rows) represents the number of
1334+
true elements, and the second dimension (columns) represents the coordinates of the
1335+
true elements. Keep in mind, the shape of the output tensor can vary depending on
1336+
how many true values there are in input.
1337+
1338+
# Arguments
1339+
x: input bool tensor.
1340+
1341+
# Returns
1342+
An integer tensor of indices.
1343+
1344+
"""
1345+
x = tf.cast(x, tf.bool)
1346+
return tf.where(x)
1347+
1348+
13291349
def maximum(x, y):
13301350
"""Element-wise maximum of two tensors.
13311351
@@ -1587,13 +1607,27 @@ def tile(x, n):
15871607
return tf.tile(x, n)
15881608

15891609

1590-
def flatten(x):
1591-
"""Flatten a tensor.
1610+
def flatten(x, outdim=1):
1611+
"""Returns a view of this tensor with `outdim` dimensions, whose shape
1612+
for the first `outdim-1` dimensions will be the same as `x`, and
1613+
shape in the remaining dimension will be expanded to fit in
1614+
all the data from `x`.
1615+
1616+
# Arguments
1617+
x: input tensor.
1618+
outdim: number of dimensions in the output tensor.
15921619
15931620
# Returns
1594-
A tensor, reshaped into 1-D
1621+
A tensor, reshaped to `outdim` dimensions.
1622+
15951623
"""
1596-
return tf.reshape(x, [-1])
1624+
1625+
if outdim > 1:
1626+
shape = concatenate([tf.shape(x)[:outdim - 1], variable([-1], dtype='int32')])
1627+
else:
1628+
shape = [-1]
1629+
1630+
return tf.reshape(x, shape)
15971631

15981632

15991633
def batch_flatten(x):
@@ -2023,7 +2057,10 @@ def rnn(step_function, inputs, initial_states,
20232057

20242058
# TODO: remove later.
20252059
if hasattr(tf, 'select'):
2026-
tf.where = tf.select
2060+
where_op = tf.select
2061+
else:
2062+
where_op = tf.where
2063+
20272064
if hasattr(tf, 'stack'):
20282065
stack = tf.stack
20292066
unstack = tf.unstack
@@ -2069,14 +2106,14 @@ def rnn(step_function, inputs, initial_states,
20692106
else:
20702107
prev_output = successive_outputs[-1]
20712108

2072-
output = tf.where(tiled_mask_t, output, prev_output)
2109+
output = where_op(tiled_mask_t, output, prev_output)
20732110

20742111
return_states = []
20752112
for state, new_state in zip(states, new_states):
20762113
# (see earlier comment for tile explanation)
20772114
tiled_mask_t = tf.tile(mask_t,
20782115
stack([1, tf.shape(new_state)[1]]))
2079-
return_states.append(tf.where(tiled_mask_t,
2116+
return_states.append(where_op(tiled_mask_t,
20802117
new_state,
20812118
state))
20822119
states = return_states
@@ -2145,8 +2182,8 @@ def _step(time, output_ta_t, *states):
21452182
new_state.set_shape(state.get_shape())
21462183
tiled_mask_t = tf.tile(mask_t,
21472184
stack([1, tf.shape(output)[1]]))
2148-
output = tf.where(tiled_mask_t, output, states[0])
2149-
new_states = [tf.where(tiled_mask_t, new_states[i], states[i]) for i in range(len(states))]
2185+
output = where_op(tiled_mask_t, output, states[0])
2186+
new_states = [where_op(tiled_mask_t, new_states[i], states[i]) for i in range(len(states))]
21502187
output_ta_t = output_ta_t.write(time, output)
21512188
return (time + 1, output_ta_t) + tuple(new_states)
21522189
else:

keras/backend/theano_backend.py

+35-2
Original file line numberDiff line numberDiff line change
@@ -432,6 +432,25 @@ def lesser_equal(x, y):
432432
return T.le(x, y)
433433

434434

435+
def where(x):
436+
"""Returns locations of true values in a boolean tensor.
437+
438+
This operation returns the coordinates of true elements in input. The coordinates are
439+
returned in a 2-D tensor where the first dimension (rows) represents the number of
440+
true elements, and the second dimension (columns) represents the coordinates of the
441+
true elements. Keep in mind, the shape of the output tensor can vary depending on
442+
how many true values there are in input.
443+
444+
# Arguments
445+
x: input bool tensor.
446+
447+
# Returns
448+
An integer tensor of indices.
449+
450+
"""
451+
return transpose(x.nonzero(return_matrix=True))
452+
453+
435454
def maximum(x, y):
436455
return T.maximum(x, y)
437456

@@ -687,9 +706,23 @@ def tile(x, n):
687706
return T.tile(x, n)
688707

689708

690-
def flatten(x):
709+
def flatten(x, outdim=1):
710+
"""Returns a view of this tensor with `outdim` dimensions, whose shape
711+
for the first `outdim-1` dimensions will be the same as `x`, and
712+
shape in the remaining dimension will be expanded to fit in
713+
all the data from `x`.
714+
715+
# Arguments
716+
x: input tensor.
717+
outdim: number of dimensions in the output tensor.
718+
719+
# Returns
720+
A tensor, reshaped to `outdim` dimensions.
721+
722+
"""
723+
691724
# TODO: `keras_shape` inference.
692-
return T.flatten(x)
725+
return T.flatten(x, outdim)
693726

694727

695728
def batch_flatten(x):

keras/engine/training.py

+38
Original file line numberDiff line numberDiff line change
@@ -486,6 +486,39 @@ def stop(self, timeout=None):
486486
self.queue = None
487487

488488

489+
def masked_tensor(x, mask):
490+
""" Applies a mask to an input tensor.
491+
492+
# Arguments
493+
x: a tensor of shape `(samples, timesteps, ... )`
494+
mask: a mask boolean tensor of shape `(samples, timesteps)` where each
495+
value represents whether the given timestep in a given sample
496+
should be masked out or not.
497+
498+
# Returns
499+
A tensor of shape `(n_kept, ... )`, where `n_kept` is the number of
500+
non-zero mask entries, holding the timesteps from all samples that had value 1 in the mask tensor.
501+
502+
"""
503+
504+
# Flatten first two dimensions of input tensor. We do it by shifting the first two
505+
# dimensions to the end, flattening them and shifting them back to the front.
506+
ndim = K.ndim(x)
507+
shift_end_pattern = tuple(list(range(2, ndim)) + [0, 1])
508+
shift_front_pattern = tuple([ndim - 2] + list(range(0, ndim - 2)))
509+
x = K.permute_dimensions(x, shift_end_pattern)
510+
x = K.flatten(x, ndim - 1)
511+
x = K.permute_dimensions(x, shift_front_pattern)
512+
513+
# Also flatten the 2D mask tensor.
514+
mask = K.flatten(mask)
515+
516+
# Extract indices of flattened mask tensor to keep.
517+
indices = K.flatten(K.where(mask))
518+
519+
return K.gather(x, indices)
520+
521+
489522
class Model(Container):
490523

491524
def compile(self, optimizer, loss, metrics=None, loss_weights=None,
@@ -694,6 +727,11 @@ def append_metric(layer_num, metric_name, metric_tensor):
694727
y_true = self.targets[i]
695728
y_pred = self.outputs[i]
696729
output_metrics = nested_metrics[i]
730+
mask = masks[i]
731+
732+
if mask is not None:
733+
y_true = masked_tensor(y_true, mask)
734+
y_pred = masked_tensor(y_pred, mask)
697735

698736
for metric in output_metrics:
699737
if metric == 'accuracy' or metric == 'acc':

tests/keras/backend/test_backends.py

+8
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ def test_shape_operations(self):
105105
pattern=(2, 0, 1))
106106
check_single_tensor_operation('repeat', (4, 1), n=3)
107107
check_single_tensor_operation('flatten', (4, 1))
108+
check_single_tensor_operation('flatten', (4, 4, 4), outdim=2)
108109
check_single_tensor_operation('expand_dims', (4, 3), dim=-1)
109110
check_single_tensor_operation('expand_dims', (4, 3, 2), dim=1)
110111
check_single_tensor_operation('squeeze', (4, 3, 1), axis=2)
@@ -839,6 +840,13 @@ def test_one_hot(self):
839840
koh = K.eval(K.one_hot(K.variable(indices, dtype='int32'), nb_classes))
840841
assert np.all(koh == oh)
841842

843+
def test_where(self):
844+
x = np.random.randint(0, 2, size=(10, 10))
845+
exp_out = np.stack(np.nonzero(x), axis=1)
846+
for K in [KTH, KTF]:
847+
k_out = K.eval(K.where(K.variable(x, dtype='int32')))
848+
assert np.all(k_out == exp_out)
849+
842850
def test_sparse_dot(self):
843851
x_d = np.array([0, 7, 2, 3], dtype=np.float32)
844852
x_r = np.array([0, 2, 2, 3], dtype=np.int64)

tests/test_loss_masking.py

+20-5
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@
22
import pytest
33

44
from keras.models import Sequential
5-
from keras.engine.training import weighted_objective
6-
from keras.layers.core import TimeDistributedDense, Masking
5+
from keras.engine.training import weighted_objective, masked_tensor
6+
from keras.layers.core import Dense, Masking
7+
from keras.layers.wrappers import TimeDistributed
78
from keras.utils.test_utils import keras_test
89
from keras import objectives
910
from keras import backend as K
@@ -16,12 +17,14 @@ def test_masking():
1617
[[0], [0]]])
1718
model = Sequential()
1819
model.add(Masking(mask_value=0, input_shape=(2, 1)))
19-
model.add(TimeDistributedDense(1, init='one'))
20-
model.compile(loss='mse', optimizer='sgd')
20+
model.add(TimeDistributed(Dense(1, init='one')))
21+
model.compile(loss='mse', optimizer='sgd', metrics=['accuracy'])
2122
y = np.array([[[1], [1]],
2223
[[1], [1]]])
23-
loss = model.train_on_batch(X, y)
24+
(loss, acc) = model.train_on_batch(X, y)
25+
2426
assert loss == 0
27+
assert acc == 1
2528

2629

2730
@keras_test
@@ -42,5 +45,17 @@ def test_loss_masking():
4245
K.variable(mask)))
4346

4447

48+
@keras_test
49+
def test_masked_tensor():
50+
x = np.random.randint(0, 10, size=(5, 10, 5))
51+
mask = np.random.randint(0, 2, size=(5, 10))
52+
i = np.nonzero(mask)
53+
exp_out = x[i[0], i[1], :]
54+
55+
k_out = K.eval(masked_tensor(K.variable(x, dtype='int32'), K.variable(mask, dtype='int32')))
56+
57+
assert np.all(k_out == exp_out)
58+
59+
4560
if __name__ == '__main__':
4661
pytest.main([__file__])

0 commit comments

Comments
 (0)