Commit f1b33bf
Model loading, test-time augmentation.

- Set up UNet2DS and the unet2ds_nf example to load the architecture and weights from the same HDF5 file.
- Implemented 8x test-time augmentation for UNet2DS. This improved the test score from 0.535 to 0.542.
1 parent 87869d9 commit f1b33bf
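The single-HDF5 loading described above is standard Keras behavior: model.save writes the architecture, the weights, and the optimizer state into one file, and keras.models.load_model restores all three, provided any custom metrics or losses are passed via custom_objects. A minimal sketch with a toy model (not the repo's UNet2DS):

import numpy as np
from keras.models import Sequential, load_model
from keras.layers import Conv2D

# Save: architecture, weights, and optimizer state all land in one HDF5 file.
model = Sequential([Conv2D(1, 3, padding='same', input_shape=(None, None, 1))])
model.compile(optimizer='adam', loss='binary_crossentropy')
model.save('model.hdf5')

# Load: custom metrics/losses must be supplied via custom_objects, which is
# why this commit has UNet2DS build a self.custom_objects dict.
model = load_model('model.hdf5')  # add custom_objects={...} when needed
print(model.predict(np.zeros((1, 8, 8, 1))).shape)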

5 files changed: +247 -202 lines

deepcalcium/models/neurons/unet_2d_summary.py (+119, -137)
@@ -7,6 +7,7 @@
 from itertools import cycle
 from keras.callbacks import Callback, ModelCheckpoint, EarlyStopping, CSVLogger, ReduceLROnPlateau
 from keras.optimizers import Adam
+from keras.losses import binary_crossentropy
 from keras_contrib.callbacks import DeadReluDetector
 from math import ceil
 from os import path, mkdir, remove
@@ -23,8 +24,9 @@
 
 from deepcalcium.utils.runtime import funcname
 from deepcalcium.datasets.nf import nf_mask_metrics
-from deepcalcium.utils.keras_helpers import MetricsPlotCallback
+from deepcalcium.utils.keras_helpers import MetricsPlotCallback, F1, prec, reca, dice, dicesq, dice_loss, dicesq_loss, posyt, posyp, load_model_with_new_input_shape
 from deepcalcium.utils.visuals import mask_outlines
+from deepcalcium.utils.data_utils import INVERTIBLE_2D_AUGMENTATIONS
 
 
 class ValidationMetricsCB(Callback):
@@ -64,11 +66,14 @@ def on_epoch_end(self, epoch, logs={}):
         logger.info('\n')
         tic = time()
 
-        # Save weights from the training model and load them into validation model.
-        path_weights = '%s/weights.tmp' % self.cpdir
-        self.model.save_weights(path_weights)
-        self.model_val.load_weights(path_weights)
-        remove(path_weights)
+        # Transfer weights from the training model to the validation model.
+        self.model_val.set_weights(self.model.get_weights())
+
+        # # Save weights from the training model and load them into validation model.
+        # path_weights = '%s/weights.tmp' % self.cpdir
+        # self.model.save_weights(path_weights)
+        # self.model_val.load_weights(path_weights)
+        # remove(path_weights)
 
         # Tracking precision, recall, f1 values.
         pp, rr, ff = [], [], []
@@ -207,7 +212,6 @@ def conv_layer(nb_filters, x):
     x = Conv2D(2, 1)(x)
    x = Activation('softmax')(x)
    x = Lambda(lambda x: x[:, :, :, -1])(x)
-    # x = Lambda(lambda x: x[:, :, :, -1], output_shape=window_shape)(x)
 
     return Model(inputs=inputs, outputs=x)
 
@@ -239,16 +243,21 @@ def __init__(self, cpdir='%s/.deep-calcium-datasets/tmp' % path.expanduser('~'),
         if not path.exists(self.cpdir):
             mkdir(self.cpdir)
 
-    def fit(self, datasets, weights_path=None, shape_trn=(96, 96), shape_val=(512, 512), batch_size_trn=32,
+        cobj = [F1, prec, reca, dice, dicesq, posyt, posyp, dice_loss, dicesq_loss]
+        self.custom_objects = {x.__name__: x for x in cobj}
+
+    def fit(self, datasets, model_path=None, proceed=False, shape_trn=(96, 96), shape_val=(512, 512), batch_size_trn=32,
             batch_size_val=1, nb_steps_trn=200, nb_epochs=20, prop_trn=0.75, prop_val=0.25, keras_callbacks=[],
-            optimizer=Adam(0.002), loss='binary_crossentropy'):
+            optimizer=Adam(0.002), loss=binary_crossentropy):
         """Constructs network based on parameters and trains with the given data.
 
         # Arguments
             datasets: List of HDF5 datasets. Each of these will be passed to self.series_summary_func and
                 self.mask_summary_func to compute its series and mask summaries, so the HDF5 structure
                 should be compatible with those functions.
-            weights_path: filesystem path to weights that should be loaded into the network.
+            model_path: filesystem path to a serialized model that should be loaded into the network.
+            proceed: whether to continue training where the model left off or start over. Only relevant when a
+                model_path is given, because it uses the saved optimizer state.
             shape_trn: (height, width) shape of the windows cropped for training.
             shape_val: (height, width) shape of the windows used for validation.
             batch_size_trn: Batch size used for training.
@@ -258,92 +267,67 @@ def fit(self, datasets, weights_path=None, shape_trn=(96, 96), shape_val=(512, 512),
             prop_val: Proportion of each summary image used to validate, cropped from the bottom of the image.
             keras_callbacks: List of callbacks appended to internal callbacks for training.
             optimizer: Instantiated keras optimizer.
-            loss: Loss function, currently either binary_crossentropy or dice_squared from https://arxiv.org/abs/1606.04797.
+            loss: Loss function, one of binary_crossentropy, dice_loss, or dicesq_loss (squared dice from https://arxiv.org/abs/1606.04797), given by name or as a function.
+
+        # Returns
+            history: the Keras training history as a dictionary of metrics and their values after each epoch.
+
         """
 
+        # Error check.
         assert len(shape_trn) == 2
         assert len(shape_val) == 2
         assert shape_trn[0] == shape_trn[1]
         assert shape_val[0] == shape_val[1]
         assert 0 < prop_trn < 1
         assert 0 < prop_val < 1
-        assert loss in {'binary_crossentropy', 'dice_squared'}
-
-        logger = logging.getLogger(funcname())
-
-        # Define, compile neural net.
-        model = self.net_builder(shape_trn)
-        model_val = self.net_builder(shape_val)
-        json.dump(model.to_json(), open('%s/model.json' % self.cpdir, 'w'), indent=2)
-
-        # Metric: True positive proportion.
-        def ytpos(yt, yp):
-            size = K.sum(K.ones_like(yt))
-            return K.sum(yt) / (size + K.epsilon())
-
-        # Metric: Predicted positive proportion.
-        def yppos(yt, yp):
-            size = K.sum(K.ones_like(yp))
-            return K.sum(K.round(yp)) / (size + K.epsilon())
-
-        # Metric: Binary pixel-wise precision.
-        def prec(yt, yp):
-            yp = K.round(yp)
-            tp = K.sum(yt * yp)
-            fp = K.sum(K.clip(yp - yt, 0, 1))
-            return tp / (tp + fp + K.epsilon())
-
-        # Metric: Binary pixel-wise recall.
-        def reca(yt, yp):
-            yp = K.round(yp)
-            tp = K.sum(yt * yp)
-            fn = K.sum(K.clip(yt - yp, 0, 1))
-            return tp / (tp + fn + K.epsilon())
-
-        # Metric: Squared dice coefficient from VNet paper.
-        def dice_squared(yt, yp):
-            nmr = 2 * K.sum(yt * yp)
-            dnm = K.sum(yt**2) + K.sum(yp**2) + K.epsilon()
-            return (nmr / dnm)
-
-        def dice_squared_loss(yt, yp):
-            return 1 - dice_squared(yt, yp)
-
-        if loss == 'dice_squared':
-            loss = dice_squared_loss
+        assert not (proceed and not model_path)
+
+        losses = {
+            'binary_crossentropy': binary_crossentropy,
+            'dice_loss': dice_loss,
+            'dicesq_loss': dicesq_loss
+        }
+        assert loss in losses.keys() or loss in losses.values()
+        loss = losses[loss] if type(loss) == str else loss
+
+        # Load network from disk.
+        if model_path:
+            lmwnis = load_model_with_new_input_shape
+            model = lmwnis(model_path, shape_trn, compile=proceed,
+                           custom_objects=self.custom_objects)
+            model_val = lmwnis(model_path, shape_val, compile=False,
+                               custom_objects=self.custom_objects)
+
+        # Define, compile network.
         else:
-            loss = 'binary_crossentropy'
+            model = self.net_builder(shape_trn)
+            model_val = self.net_builder(shape_val)
+        model.summary()
 
-        model.compile(optimizer=optimizer, loss=loss,
-                      metrics=[dice_squared, ytpos, yppos, prec, reca])
-        model.summary()
-
-        if weights_path is not None:
-            model.load_weights(weights_path)
-            logger.info('Loaded weights from %s.' % weights_path)
+        if not proceed:
+            model.compile(optimizer=optimizer, loss=loss,
+                          metrics=[F1, prec, reca, dice, dicesq, posyt, posyp])
 
         # Pre-compute summaries once to avoid problems with accessing HDF5.
         S_summ = [self.series_summary_func(ds) for ds in datasets]
         M_summ = [self.mask_summary_func(ds) for ds in datasets]
 
         # Define generators for training and validation data.
-        y_coords_trn = [(0, int(s.shape[0] * prop_trn)) for s in S_summ]
-        gen_trn = self.batch_gen_fit(
-            S_summ, M_summ, y_coords_trn, batch_size_trn, shape_trn, nb_max_augment=15)
+        yctrn = [(0, int(s.shape[0] * prop_trn)) for s in S_summ]
+        gen_trn = self.batch_gen(S_summ, M_summ, yctrn, batch_size_trn,
+                                 shape_trn, nb_max_augment=15)
 
         # Validation setup.
-        y_coords_val = [(s.shape[0] - int(s.shape[0] * prop_val), s.shape[0])
-                        for s in S_summ]
-
+        ycval = [(s.shape[0] - int(s.shape[0] * prop_val), s.shape[0]) for s in S_summ]
         names = [ds.attrs['name'] for ds in datasets]
 
         callbacks = [
-            ValidationMetricsCB(model_val, S_summ, M_summ,
-                                names, y_coords_val, self.cpdir),
+            ValidationMetricsCB(model_val, S_summ, M_summ, names, ycval, self.cpdir),
             CSVLogger('%s/metrics.csv' % self.cpdir),
             MetricsPlotCallback('%s/metrics.png' % self.cpdir,
                                 '%s/metrics.csv' % self.cpdir),
-            ModelCheckpoint('%s/weights_val_nf_f1_mean.hdf5' % self.cpdir, mode='max',
+            ModelCheckpoint('%s/model_val_nf_f1_mean.hdf5' % self.cpdir, mode='max',
                             monitor='val_nf_f1_mean', save_best_only=True, verbose=1),
             EarlyStopping(monitor='val_nf_f1_mean', min_delta=1e-3,
                           patience=10, verbose=1, mode='max'),
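Note: the compile call above swaps the removed inline metrics for F1, prec, reca, dice, dicesq, posyt, posyp imported from deepcalcium.utils.keras_helpers, a file not included in this diff. Judging from the removed dice_squared code, the dice helpers plausibly look like the following sketch (names match the imports; the exact implementations are assumptions):

from keras import backend as K

# Sketch of the dice metrics now imported from keras_helpers. dicesq mirrors
# the removed dice_squared above (squared dice from the V-Net paper,
# https://arxiv.org/abs/1606.04797).

def dice(yt, yp):
    # Soft Dice coefficient on binary masks.
    nmr = 2 * K.sum(yt * yp)
    dnm = K.sum(yt) + K.sum(yp) + K.epsilon()
    return nmr / dnm

def dicesq(yt, yp):
    # Squared-denominator variant, as in the removed dice_squared.
    nmr = 2 * K.sum(yt * yp)
    dnm = K.sum(yt ** 2) + K.sum(yp ** 2) + K.epsilon()
    return nmr / dnm

def dice_loss(yt, yp):
    return 1 - dice(yt, yp)

def dicesq_loss(yt, yp):
    return 1 - dicesq(yt, yp)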
@@ -360,7 +344,7 @@ def dice_squared_loss(yt, yp):
 
         return trained.history
 
-    def batch_gen_fit(self, S_summ, M_summ, y_coords, batch_size, window_shape, nb_max_augment=0):
+    def batch_gen(self, S_summ, M_summ, y_coords, batch_size, window_shape, nb_max_augment=0):
         """Builds and yields batches of image windows and corresponding mask windows for training.
         Includes random data augmentation.
 
@@ -460,94 +444,92 @@ def stretch(a, b):
 
             yield s_batch, m_batch
 
-    def evaluate(self, datasets, weights_path=None, window_shape=(512, 512), save=False):
-        """Evaluates predicted masks vs. true masks for the given sequences."""
-
-        logger = logging.getLogger(funcname())
-
-        model = self.net_builder(window_shape)
-        if weights_path is not None:
-            model.load_weights(weights_path)
-            logger.info('Loaded weights from %s.' % weights_path)
+    def predict(self, datasets, model_path, window_shape=(512, 512), print_scores=False, save=False, augmentation=False):
+        """Make predictions on the given datasets. Currently uses batches of 1.
 
-        # Currently only supporting full-sized windows.
-        assert window_shape == (512, 512), 'TODO: implement variable window sizes.'
-
-        # Padding helper.
-        _, hw, ww = model.input_shape
-        pad = lambda x: np.pad(
-            x, ((0, hw - x.shape[0]), (0, ww - x.shape[1])), mode='reflect')
-
-        # Evaluate each sequence, mask pair.
-        mean_prec, mean_reca, mean_comb = 0., 0., 0.
-        for ds in datasets:
-            name = ds.attrs['name']
-            s = self.series_summary_func(ds)
-            m = self.mask_summary_func(ds)
-            hs, ws = s.shape
-
-            # Pad and make prediction.
-            s_batch = np.zeros((1, ) + window_shape)
-            s_batch[0] = pad(s)
-            mp = model.predict(s_batch)[0, :hs, :ws].round()
-
-            # Track scores.
-            prec, reca, incl, excl, comb = nf_mask_metrics(m, mp)
-            logger.info('%s: prec=%.3lf, reca=%.3lf, incl=%.3lf, excl=%.3lf, comb=%.3lf' % (
-                name, prec, reca, incl, excl, comb))
-            mean_prec += prec / len(datasets)
-            mean_reca += reca / len(datasets)
-            mean_comb += comb / len(datasets)
-
-            # Save mask and prediction.
-            if save:
-                imsave('%s/%s_mp.png' % (self.cpdir, name),
-                       mask_outlines(s, [m, mp], ['blue', 'red']))
-
-        logger.info('Mean prec=%.3lf, reca=%.3lf, comb=%.3lf' %
-                    (mean_prec, mean_reca, mean_comb))
-
-        return mean_comb
+        Arguments:
+            datasets: List of HDF5 datasets. Each of these will be passed to self.series_summary_func and
+                self.mask_summary_func to compute its series and mask summaries, so the HDF5 structure
+                should be compatible with those functions.
+            model_path: Path to the serialized Keras model HDF5 file. This file should include both the
+                architecture and the weights.
+            window_shape: Tuple window shape used for making predictions. Summary images with windows smaller
+                than this are padded up to match this shape.
+            print_scores: Flag to print the Neurofinder evaluation metrics. Only works when the datasets include
+                ground-truth masks.
+            save: Flag to save the predictions as PNGs with outlines around the predicted neurons in red. If
+                the ground-truth masks are given, it will also show outlines around the ground-truth neurons.
+            augmentation: Flag to perform 8x test-time augmentation. A prediction is made for each of the
+                augmentations, each prediction is inverted back to the original orientation, and the average
+                over all augmentations is used as the final prediction. In practice, this improved a
+                Neurofinder submission from 0.5356 to 0.542.
+
+        Returns:
+            Mp: list of the predicted masks stored as Numpy arrays.
 
-    def predict(self, datasets, weights_path=None, window_shape=(512, 512), batch_size=10, save=False):
-        """Predicts masks for the given sequences. Optionally saves the masks. Returns the masks as numpy arrays in order corresponding the given sequences."""
+        """
 
         logger = logging.getLogger(funcname())
-
-        model = self.net_builder(window_shape)
-        if weights_path is not None:
-            model.load_weights(weights_path)
-            logger.info('Loaded weights from %s.' % weights_path)
+        model = load_model_with_new_input_shape(model_path, window_shape, compile=False,
+                                                custom_objects=self.custom_objects)
+        logger.info('Loaded model from %s.' % model_path)
 
         # Currently only supporting full-sized windows.
         assert window_shape == (512, 512), 'TODO: implement variable window sizes.'
 
         # Padding helper.
-        _, hw, ww = model.input_shape
-        pad = lambda x: np.pad(x, ((0, hw - x.shape[0]), (0, ww - x.shape[1])), 'reflect')
+        def pad(x):
+            _, hw, ww = model.input_shape
+            return np.pad(x, ((0, hw - x.shape[0]), (0, ww - x.shape[1])), mode='reflect')
 
-        # Store predictions.
+        # Store predicted masks and scores.
         Mp = []
+        mean_prec, mean_reca, mean_comb = 0., 0., 0.
 
         # Evaluate each sequence, mask pair.
-        mean_prec, mean_reca, mean_comb = 0., 0., 0.
         for ds in datasets:
             name = ds.attrs['name']
             s = self.series_summary_func(ds)
             hs, ws = s.shape
 
-            # Pad and make prediction.
+            # Pad and make prediction(s).
             s_batch = np.zeros((1, ) + window_shape)
             s_batch[0] = pad(s)
-            mp = model.predict(s_batch)[0, :hs, :ws].round()
-            assert mp.shape == s.shape
+
+            if augmentation:
+                mp = np.zeros_like(s)
+                # aug_name, not name: avoids clobbering the dataset name above.
+                for aug_name, aug, inv in INVERTIBLE_2D_AUGMENTATIONS:
+                    mpaug = model.predict(aug(s_batch))
+                    mp += inv(mpaug)[0, :hs, :ws] / len(INVERTIBLE_2D_AUGMENTATIONS)
+                mp = mp.round()
+
+            else:
+                mp = model.predict(s_batch)[0, :hs, :ws].round()
+
             Mp.append(mp)
 
-            # Save prediction.
-            if save:
+            # Track scores.
+            if print_scores:
+                m = self.mask_summary_func(ds)
+                prec, reca, incl, excl, comb = nf_mask_metrics(m, mp)
+                logger.info('%s: prec=%.3lf, reca=%.3lf, incl=%.3lf, excl=%.3lf, comb=%.3lf' % (
+                    name, prec, reca, incl, excl, comb))
+                mean_prec += prec / len(datasets)
+                mean_reca += reca / len(datasets)
+                mean_comb += comb / len(datasets)
+
+            # Save mask and prediction.
+            if save and 'masks' in ds:
+                m = self.mask_summary_func(ds)
+                outlined = mask_outlines(s, [m, mp], ['blue', 'red'])
+                imsave('%s/%s_mp.png' % (self.cpdir, name), outlined)
+
+            elif save:
                 outlined = mask_outlines(s, [mp], ['red'])
                 imsave('%s/%s_mp.png' % (self.cpdir, name), outlined)
 
-            logger.info('%s prediction complete.' % name)
+        if print_scores:
+            logger.info('Mean prec=%.3lf, reca=%.3lf, comb=%.3lf' %
+                        (mean_prec, mean_reca, mean_comb))
 
         return Mp
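Both fit and predict rely on load_model_with_new_input_shape from deepcalcium.utils.keras_helpers, whose implementation is not part of this diff. For a fully-convolutional model like UNet2DS, whose convolution weights are independent of the input's height and width, such a helper can be sketched as follows. Everything below is an assumption about the helper's internals, including the Keras 2 HDF5 layout; restoring the optimizer state for compile=True is omitted from the sketch.

import json
import h5py
from keras.models import model_from_json

def load_model_with_new_input_shape(path, input_shape, compile=False, custom_objects=None):
    # compile=True (recompiling with the saved optimizer state) is omitted here.
    # Read the architecture JSON stored in the saved model's HDF5 attributes.
    with h5py.File(path, 'r') as f:
        config = json.loads(f.attrs['model_config'])
    # Patch the input layer's spatial dims; the batch dim stays None. Valid
    # because conv weights do not depend on the input's height/width.
    config['config']['layers'][0]['config']['batch_input_shape'] = [None] + list(input_shape)
    model = model_from_json(json.dumps(config), custom_objects=custom_objects)
    # The saved weights load cleanly into the reshaped architecture.
    model.load_weights(path)
    return model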

deepcalcium/utils/data_utils.py (+1, -1)
@@ -3,7 +3,7 @@
 # Augmentations that can be applied to a batch of 2D images and inverted.
 # Structure is the augmentation name, the augmentation, and the inverse
 # of the augmentation. Intended for test-time augmentation for segmentation.
-INVERTIBLE_2D_BATCH_AUGMENTATIONS = [
+INVERTIBLE_2D_AUGMENTATIONS = [
     ('identity',
      lambda x: x,
      lambda x: x),
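Only the identity entry of INVERTIBLE_2D_AUGMENTATIONS is visible above the hunk boundary. Given the commit's 8x test-time augmentation, the full list plausibly covers the eight symmetries of the square: four rotations, each with and without a horizontal flip, paired with their inverses. A hypothetical completion operating on batches of shape (batch, height, width), assuming numpy >= 1.12 for np.rot90's axes argument:

import numpy as np

# Hypothetical completion; only the identity entry appears in this hunk.
# Rotations act on the spatial axes (1, 2); the flip is its own inverse.
rot = lambda x, k: np.rot90(x, k, axes=(1, 2))
flip = lambda x: x[:, :, ::-1]

INVERTIBLE_2D_AUGMENTATIONS = [
    ('identity',    lambda x: x,                lambda x: x),
    ('rot90',       lambda x: rot(x, 1),        lambda x: rot(x, -1)),
    ('rot180',      lambda x: rot(x, 2),        lambda x: rot(x, -2)),
    ('rot270',      lambda x: rot(x, 3),        lambda x: rot(x, -3)),
    ('flip',        flip,                       flip),
    ('flip-rot90',  lambda x: rot(flip(x), 1),  lambda x: flip(rot(x, -1))),
    ('flip-rot180', lambda x: rot(flip(x), 2),  lambda x: flip(rot(x, -2))),
    ('flip-rot270', lambda x: rot(flip(x), 3),  lambda x: flip(rot(x, -3))),
]

# Sanity check: every inverse undoes its augmentation.
x = np.random.rand(1, 4, 4)
assert all(np.allclose(inv(aug(x)), x) for _, aug, inv in INVERTIBLE_2D_AUGMENTATIONS)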
