7
7
from itertools import cycle
8
8
from keras .callbacks import Callback , ModelCheckpoint , EarlyStopping , CSVLogger , ReduceLROnPlateau
9
9
from keras .optimizers import Adam
10
+ from keras .losses import binary_crossentropy
10
11
from keras_contrib .callbacks import DeadReluDetector
11
12
from math import ceil
12
13
from os import path , mkdir , remove
23
24
24
25
from deepcalcium .utils .runtime import funcname
25
26
from deepcalcium .datasets .nf import nf_mask_metrics
26
- from deepcalcium .utils .keras_helpers import MetricsPlotCallback
27
+ from deepcalcium .utils .keras_helpers import MetricsPlotCallback , F1 , prec , reca , dice , dicesq , dice_loss , dicesq_loss , posyt , posyp , load_model_with_new_input_shape
27
28
from deepcalcium .utils .visuals import mask_outlines
29
+ from deepcalcium .utils .data_utils import INVERTIBLE_2D_AUGMENTATIONS
28
30
29
31
30
32
class ValidationMetricsCB (Callback ):
@@ -64,11 +66,14 @@ def on_epoch_end(self, epoch, logs={}):
64
66
logger .info ('\n ' )
65
67
tic = time ()
66
68
67
- # Save weights from the training model and load them into validation model.
68
- path_weights = '%s/weights.tmp' % self .cpdir
69
- self .model .save_weights (path_weights )
70
- self .model_val .load_weights (path_weights )
71
- remove (path_weights )
69
+ # Transfer weights from the training model to the validation model.
70
+ self .model_val .set_weights (self .model .get_weights ())
71
+
72
+ # # Save weights from the training model and load them into validation model.
73
+ # path_weights = '%s/weights.tmp' % self.cpdir
74
+ # self.model.save_weights(path_weights)
75
+ # self.model_val.load_weights(path_weights)
76
+ # remove(path_weights)
72
77
73
78
# Tracking precision, recall, f1 values.
74
79
pp , rr , ff = [], [], []
@@ -207,7 +212,6 @@ def conv_layer(nb_filters, x):
207
212
x = Conv2D (2 , 1 )(x )
208
213
x = Activation ('softmax' )(x )
209
214
x = Lambda (lambda x : x [:, :, :, - 1 ])(x )
210
- # x = Lambda(lambda x: x[:, :, :, -1], output_shape=window_shape)(x)
211
215
212
216
return Model (inputs = inputs , outputs = x )
213
217
@@ -239,16 +243,21 @@ def __init__(self, cpdir='%s/.deep-calcium-datasets/tmp' % path.expanduser('~'),
239
243
if not path .exists (self .cpdir ):
240
244
mkdir (self .cpdir )
241
245
242
- def fit (self , datasets , weights_path = None , shape_trn = (96 , 96 ), shape_val = (512 , 512 ), batch_size_trn = 32 ,
246
+ cobj = [F1 , prec , reca , dice , dicesq , posyt , posyp , dice_loss , dicesq_loss ]
247
+ self .custom_objects = {x .__name__ : x for x in cobj }
248
+
249
+ def fit (self , datasets , model_path = None , proceed = False , shape_trn = (96 , 96 ), shape_val = (512 , 512 ), batch_size_trn = 32 ,
243
250
batch_size_val = 1 , nb_steps_trn = 200 , nb_epochs = 20 , prop_trn = 0.75 , prop_val = 0.25 , keras_callbacks = [],
244
- optimizer = Adam (0.002 ), loss = ' binary_crossentropy' ):
251
+ optimizer = Adam (0.002 ), loss = binary_crossentropy ):
245
252
"""Constructs network based on parameters and trains with the given data.
246
253
247
254
# Arguments
248
255
datasets: List of HDF5 datasets. Each of these will be passed to self.series_summary_func and
249
256
self.mask_summary_func to compute its series and mask summaries, so the HDF5 structure
250
257
should be compatible with those functions.
251
- weights_path: filesystem path to weights that should be loaded into the network.
258
+ model_path: filesystem path to serialized model that should be loaded into the network.
259
+ proceed: whether to continue training where the model left off or start over. Only relevant when a
260
+ model_path is given because it uses the saved optimizer state.
252
261
shape_trn: (height, width) shape of the windows cropped for training.
253
262
shape_val: (height, width) shape of the windows used for validation.
254
263
batch_size_trn: Batch size used for training.
@@ -258,92 +267,67 @@ def fit(self, datasets, weights_path=None, shape_trn=(96, 96), shape_val=(512, 5
258
267
prop_val: Proportion of each summary image used to validate, cropped from the bottom of the image.
259
268
keras_callbacks: List of callbacks appended to internal callbacks for training.
260
269
optimizer: Instanitated keras optimizer.
261
- loss: Loss function, currently either binary_crossentropy or dice_squared from https://arxiv.org/abs/1606.04797.
270
+ loss: Loss function, one of binary_crossentropy, dice, or dice-squared from https://arxiv.org/abs/1606.04797.
271
+
272
+ # Returns
273
+ history: the Keras training history as a dictionary of metrics and their values after each epoch.
274
+
262
275
"""
263
276
277
+ # Error check.
264
278
assert len (shape_trn ) == 2
265
279
assert len (shape_val ) == 2
266
280
assert shape_trn [0 ] == shape_trn [1 ]
267
281
assert shape_val [0 ] == shape_val [1 ]
268
282
assert 0 < prop_trn < 1
269
283
assert 0 < prop_val < 1
270
- assert loss in {'binary_crossentropy' , 'dice_squared' }
271
-
272
- logger = logging .getLogger (funcname ())
273
-
274
- # Define, compile neural net.
275
- model = self .net_builder (shape_trn )
276
- model_val = self .net_builder (shape_val )
277
- json .dump (model .to_json (), open ('%s/model.json' % self .cpdir , 'w' ), indent = 2 )
278
-
279
- # Metric: True positive proportion.
280
- def ytpos (yt , yp ):
281
- size = K .sum (K .ones_like (yt ))
282
- return K .sum (yt ) / (size + K .epsilon ())
283
-
284
- # Metric: Predicted positive proportion.
285
- def yppos (yt , yp ):
286
- size = K .sum (K .ones_like (yp ))
287
- return K .sum (K .round (yp )) / (size + K .epsilon ())
288
-
289
- # Metric: Binary pixel-wise precision.
290
- def prec (yt , yp ):
291
- yp = K .round (yp )
292
- tp = K .sum (yt * yp )
293
- fp = K .sum (K .clip (yp - yt , 0 , 1 ))
294
- return tp / (tp + fp + K .epsilon ())
295
-
296
- # Metric: Binary pixel-wise recall.
297
- def reca (yt , yp ):
298
- yp = K .round (yp )
299
- tp = K .sum (yt * yp )
300
- fn = K .sum (K .clip (yt - yp , 0 , 1 ))
301
- return tp / (tp + fn + K .epsilon ())
302
-
303
- # Metric: Squared dice coefficient from VNet paper.
304
- def dice_squared (yt , yp ):
305
- nmr = 2 * K .sum (yt * yp )
306
- dnm = K .sum (yt ** 2 ) + K .sum (yp ** 2 ) + K .epsilon ()
307
- return (nmr / dnm )
308
-
309
- def dice_squared_loss (yt , yp ):
310
- return 1 - dice_squared (yt , yp )
311
-
312
- if loss == 'dice_squared' :
313
- loss = dice_squared_loss
284
+ assert not (proceed and not model_path )
285
+
286
+ losses = {
287
+ 'binary_crossentropy' : binary_crossentropy ,
288
+ 'dice_loss' : dice_loss ,
289
+ 'dicesq_loss' : dicesq_loss
290
+ }
291
+ assert loss in losses .keys () or loss in losses .values ()
292
+ loss = losses [loss ] if type (loss ) == str else loss
293
+
294
+ # Load network from disk.
295
+ if model_path :
296
+ lmwnis = load_model_with_new_input_shape
297
+ model = lmwnis (model_path , shape_trn , compile = proceed ,
298
+ custom_objects = self .custom_objects )
299
+ model_val = lmwnis (model_path , shape_val , compile = False ,
300
+ custom_objects = self .custom_objects )
301
+
302
+ # Define, compile network.
314
303
else :
315
- loss = 'binary_crossentropy'
304
+ model = self .net_builder (shape_trn )
305
+ model_val = self .net_builder (shape_val )
306
+ model .summary ()
316
307
317
- model .compile (optimizer = optimizer , loss = loss ,
318
- metrics = [dice_squared , ytpos , yppos , prec , reca ])
319
- model .summary ()
320
-
321
- if weights_path is not None :
322
- model .load_weights (weights_path )
323
- logger .info ('Loaded weights from %s.' % weights_path )
308
+ if not proceed :
309
+ model .compile (optimizer = optimizer , loss = loss ,
310
+ metrics = [F1 , prec , reca , dice , dicesq , posyt , posyp ])
324
311
325
312
# Pre-compute summaries once to avoid problems with accessing HDF5.
326
313
S_summ = [self .series_summary_func (ds ) for ds in datasets ]
327
314
M_summ = [self .mask_summary_func (ds ) for ds in datasets ]
328
315
329
316
# Define generators for training and validation data.
330
- y_coords_trn = [(0 , int (s .shape [0 ] * prop_trn )) for s in S_summ ]
331
- gen_trn = self .batch_gen_fit (
332
- S_summ , M_summ , y_coords_trn , batch_size_trn , shape_trn , nb_max_augment = 15 )
317
+ yctrn = [(0 , int (s .shape [0 ] * prop_trn )) for s in S_summ ]
318
+ gen_trn = self .batch_gen ( S_summ , M_summ , yctrn , batch_size_trn ,
319
+ shape_trn , nb_max_augment = 15 )
333
320
334
321
# Validation setup.
335
- y_coords_val = [(s .shape [0 ] - int (s .shape [0 ] * prop_val ), s .shape [0 ])
336
- for s in S_summ ]
337
-
322
+ ycval = [(s .shape [0 ] - int (s .shape [0 ] * prop_val ), s .shape [0 ]) for s in S_summ ]
338
323
names = [ds .attrs ['name' ] for ds in datasets ]
339
324
340
325
callbacks = [
341
- ValidationMetricsCB (model_val , S_summ , M_summ ,
342
- names , y_coords_val , self .cpdir ),
326
+ ValidationMetricsCB (model_val , S_summ , M_summ , names , ycval , self .cpdir ),
343
327
CSVLogger ('%s/metrics.csv' % self .cpdir ),
344
328
MetricsPlotCallback ('%s/metrics.png' % self .cpdir ,
345
329
'%s/metrics.csv' % self .cpdir ),
346
- ModelCheckpoint ('%s/weights_val_nf_f1_mean .hdf5' % self .cpdir , mode = 'max' ,
330
+ ModelCheckpoint ('%s/model_val_nf_f1_mean .hdf5' % self .cpdir , mode = 'max' ,
347
331
monitor = 'val_nf_f1_mean' , save_best_only = True , verbose = 1 ),
348
332
EarlyStopping (monitor = 'val_nf_f1_mean' , min_delta = 1e-3 ,
349
333
patience = 10 , verbose = 1 , mode = 'max' ),
@@ -360,7 +344,7 @@ def dice_squared_loss(yt, yp):
360
344
361
345
return trained .history
362
346
363
- def batch_gen_fit (self , S_summ , M_summ , y_coords , batch_size , window_shape , nb_max_augment = 0 ):
347
+ def batch_gen (self , S_summ , M_summ , y_coords , batch_size , window_shape , nb_max_augment = 0 ):
364
348
"""Builds and yields batches of image windows and corresponding mask windows for training.
365
349
Includes random data augmentation.
366
350
@@ -460,94 +444,92 @@ def stretch(a, b):
460
444
461
445
yield s_batch , m_batch
462
446
463
- def evaluate (self , datasets , weights_path = None , window_shape = (512 , 512 ), save = False ):
464
- """Evaluates predicted masks vs. true masks for the given sequences."""
465
-
466
- logger = logging .getLogger (funcname ())
467
-
468
- model = self .net_builder (window_shape )
469
- if weights_path is not None :
470
- model .load_weights (weights_path )
471
- logger .info ('Loaded weights from %s.' % weights_path )
447
+ def predict (self , datasets , model_path , window_shape = (512 , 512 ), print_scores = False , save = False , augmentation = False ):
448
+ """Make predictions on the given datasets. Currently uses batches of 1.
472
449
473
- # Currently only supporting full-sized windows.
474
- assert window_shape == (512 , 512 ), 'TODO: implement variable window sizes.'
475
-
476
- # Padding helper.
477
- _ , hw , ww = model .input_shape
478
- pad = lambda x : np .pad (
479
- x , ((0 , hw - x .shape [0 ]), (0 , ww - x .shape [1 ])), mode = 'reflect' )
480
-
481
- # Evaluate each sequence, mask pair.
482
- mean_prec , mean_reca , mean_comb = 0. , 0. , 0.
483
- for ds in datasets :
484
- name = ds .attrs ['name' ]
485
- s = self .series_summary_func (ds )
486
- m = self .mask_summary_func (ds )
487
- hs , ws = s .shape
488
-
489
- # Pad and make prediction.
490
- s_batch = np .zeros ((1 , ) + window_shape )
491
- s_batch [0 ] = pad (s )
492
- mp = model .predict (s_batch )[0 , :hs , :ws ].round ()
493
-
494
- # Track scores.
495
- prec , reca , incl , excl , comb = nf_mask_metrics (m , mp )
496
- logger .info ('%s: prec=%.3lf, reca=%.3lf, incl=%.3lf, excl=%.3lf, comb=%.3lf' % (
497
- name , prec , reca , incl , excl , comb ))
498
- mean_prec += prec / len (datasets )
499
- mean_reca += reca / len (datasets )
500
- mean_comb += comb / len (datasets )
501
-
502
- # Save mask and prediction.
503
- if save :
504
- imsave ('%s/%s_mp.png' % (self .cpdir , name ),
505
- mask_outlines (s , [m , mp ], ['blue' , 'red' ]))
506
-
507
- logger .info ('Mean prec=%.3lf, reca=%.3lf, comb=%.3lf' %
508
- (mean_prec , mean_reca , mean_comb ))
509
-
510
- return mean_comb
450
+ Arguments:
451
+ datasets: List of HDF5 datasets. Each of these will be passed to self.series_summary_func and
452
+ self.mask_summary_func to compute its series and mask summaries, so the HDF5 structure
453
+ should be compatible with those functions.
454
+ model_path: Path to the serialized Keras model HDF5 file. This file should include both the
455
+ architecture and the weights.
456
+ window_shape: Tuple window shape used for making predictions. Summary images with windows smaller
457
+ than this are padded up to match this shape.
458
+ print_scores: Flag to print the Neurofinder evaluation metrics. Only works when the datasets include
459
+ ground-truth masks.
460
+ save: Flag to save the predictions as PNGs with outlines around the predicted neurons in red. If
461
+ the ground-truth masks are given, it will also show outlines around the groun-truth neurons.
462
+ augmentation: Flag to perform 8x test-time augmentation. Predictions are made for each of the
463
+ augmentations, the augmentation is inverted to its original orientation, and the average
464
+ of all the augmentations is used as the prediction. In practice, this improved a
465
+ Neurofinder submission from 0.5356 to 0.542.
466
+
467
+ Returns:
468
+ Mp: list of the predicted masks stored as Numpy arrays.
511
469
512
- def predict (self , datasets , weights_path = None , window_shape = (512 , 512 ), batch_size = 10 , save = False ):
513
- """Predicts masks for the given sequences. Optionally saves the masks. Returns the masks as numpy arrays in order corresponding the given sequences."""
470
+ """
514
471
515
472
logger = logging .getLogger (funcname ())
516
-
517
- model = self .net_builder (window_shape )
518
- if weights_path is not None :
519
- model .load_weights (weights_path )
520
- logger .info ('Loaded weights from %s.' % weights_path )
473
+ model = load_model_with_new_input_shape (model_path , window_shape , compile = False ,
474
+ custom_objects = self .custom_objects )
475
+ logger .info ('Loaded model from %s.' % model_path )
521
476
522
477
# Currently only supporting full-sized windows.
523
478
assert window_shape == (512 , 512 ), 'TODO: implement variable window sizes.'
524
479
525
480
# Padding helper.
526
- _ , hw , ww = model .input_shape
527
- pad = lambda x : np .pad (x , ((0 , hw - x .shape [0 ]), (0 , ww - x .shape [1 ])), 'reflect' )
481
+ def pad (x ):
482
+ _ , hw , ww = model .input_shape
483
+ return np .pad (x , ((0 , hw - x .shape [0 ]), (0 , ww - x .shape [1 ])), mode = 'reflect' )
528
484
529
- # Store predictions .
485
+ # Store predicted masks and scores .
530
486
Mp = []
487
+ mean_prec , mean_reca , mean_comb = 0. , 0. , 0.
531
488
532
489
# Evaluate each sequence, mask pair.
533
- mean_prec , mean_reca , mean_comb = 0. , 0. , 0.
534
490
for ds in datasets :
535
491
name = ds .attrs ['name' ]
536
492
s = self .series_summary_func (ds )
537
493
hs , ws = s .shape
538
494
539
- # Pad and make prediction.
495
+ # Pad and make prediction(s) .
540
496
s_batch = np .zeros ((1 , ) + window_shape )
541
497
s_batch [0 ] = pad (s )
542
- mp = model .predict (s_batch )[0 , :hs , :ws ].round ()
543
- assert mp .shape == s .shape
498
+
499
+ if augmentation :
500
+ mp = np .zeros_like (s )
501
+ for name , aug , inv in INVERTIBLE_2D_AUGMENTATIONS :
502
+ mpaug = model .predict (aug (s_batch ))
503
+ mp += inv (mpaug )[0 , :hs , :ws ] / len (INVERTIBLE_2D_AUGMENTATIONS )
504
+ mp = mp .round ()
505
+
506
+ else :
507
+ mp = model .predict (s_batch )[0 , :hs , :ws ].round ()
508
+
544
509
Mp .append (mp )
545
510
546
- # Save prediction.
547
- if save :
511
+ # Track scores.
512
+ if print_scores :
513
+ m = self .mask_summary_func (ds )
514
+ prec , reca , incl , excl , comb = nf_mask_metrics (m , mp )
515
+ logger .info ('%s: prec=%.3lf, reca=%.3lf, incl=%.3lf, excl=%.3lf, comb=%.3lf' % (
516
+ name , prec , reca , incl , excl , comb ))
517
+ mean_prec += prec / len (datasets )
518
+ mean_reca += reca / len (datasets )
519
+ mean_comb += comb / len (datasets )
520
+
521
+ # Save mask and prediction.
522
+ if save and 'masks' in ds :
523
+ m = self .mask_summary_func (ds )
524
+ outlined = mask_outlines (s , [m , mp ], ['blue' , 'red' ])
525
+ imsave ('%s/%s_mp.png' % (self .cpdir , name ), outlined )
526
+
527
+ elif save :
548
528
outlined = mask_outlines (s , [mp ], ['red' ])
549
529
imsave ('%s/%s_mp.png' % (self .cpdir , name ), outlined )
550
530
551
- logger .info ('%s prediction complete.' % name )
531
+ if print_scores :
532
+ logger .info ('Mean prec=%.3lf, reca=%.3lf, comb=%.3lf' %
533
+ (mean_prec , mean_reca , mean_comb ))
552
534
553
535
return Mp
0 commit comments