import logging
+import math
import os
import time
import re
...
from allennlp.common import Params
from allennlp.common.checks import ConfigurationError
from allennlp.common.util import (dump_metrics, gpu_memory_mb, parse_cuda_device, peak_memory_mb,
-                                  get_frozen_and_tunable_parameter_names)
+                                  get_frozen_and_tunable_parameter_names, lazy_groups_of)
from allennlp.common.tqdm import Tqdm
from allennlp.data.instance import Instance
-from allennlp.data.iterators.data_iterator import DataIterator
+from allennlp.data.iterators.data_iterator import DataIterator, TensorDict
from allennlp.data.vocabulary import Vocabulary
from allennlp.models.model import Model
from allennlp.nn import util as nn_util
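The newly imported lazy_groups_of is what turns the flat stream of batches into fixed-size groups, one batch per GPU, further down in this diff. As a rough, self-contained sketch of the behaviour relied on here (the real helper lives in allennlp.common.util; the name lazy_groups_of_sketch and this body are our approximation, not the library code):

from itertools import islice
from typing import Iterator, List, TypeVar

T = TypeVar("T")

def lazy_groups_of_sketch(iterator: Iterator[T], group_size: int) -> Iterator[List[T]]:
    # Yield successive lists of up to `group_size` items from the iterator;
    # the final group may be shorter once the iterator runs out.
    while True:
        group = list(islice(iterator, group_size))
        if not group:
            return
        yield group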
@@ -216,14 +217,16 @@ def __init__(self,
    def rescale_gradients(self) -> Optional[float]:
        return training_util.rescale_gradients(self.model, self._grad_norm)

-    def batch_loss(self, batch: torch.Tensor, for_training: bool) -> torch.Tensor:
+    def batch_loss(self, batch_group: List[TensorDict], for_training: bool) -> torch.Tensor:
        """
-        Does a forward pass on the given batch and returns the ``loss`` value in the result.
+        Does a forward pass on the given batches and returns the ``loss`` value in the result.
        If ``for_training`` is `True` also applies regularization penalty.
        """
        if self._multiple_gpu:
-            output_dict = training_util.data_parallel(batch, self.model, self._cuda_devices)
+            output_dict = training_util.data_parallel(batch_group, self.model, self._cuda_devices)
        else:
+            assert len(batch_group) == 1
+            batch = batch_group[0]
            batch = nn_util.move_to_device(batch, self._cuda_devices[0])
            output_dict = self.model(**batch)
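With this change batch_loss takes a list of TensorDicts (one per GPU) rather than a single batch. A hypothetical single-GPU call site, with trainer and batch as illustrative names, would wrap the batch in a one-element list:

# Hypothetical usage sketch: on a single GPU each group holds exactly one batch,
# so the caller passes a one-element list.
loss = trainer.batch_loss([batch], for_training=True)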
@@ -255,11 +258,14 @@ def _train_epoch(self, epoch: int) -> Dict[str, float]:
        # Set the model to "train" mode.
        self.model.train()

+        num_gpus = len(self._cuda_devices)
+
        # Get tqdm for the training batches
-        train_generator = self.iterator(self.train_data,
-                                        num_epochs=1,
-                                        shuffle=self.shuffle)
-        num_training_batches = self.iterator.get_num_batches(self.train_data)
+        raw_train_generator = self.iterator(self.train_data,
+                                            num_epochs=1,
+                                            shuffle=self.shuffle)
+        train_generator = lazy_groups_of(raw_train_generator, num_gpus)
+        num_training_batches = math.ceil(self.iterator.get_num_batches(self.train_data) / num_gpus)
        self._last_log = time.time()
        last_save_time = time.time()
@@ -269,18 +275,20 @@ def _train_epoch(self, epoch: int) -> Dict[str, float]:
        histogram_parameters = set(self.model.get_parameters_for_histogram_tensorboard_logging())

+
        logger.info("Training")
        train_generator_tqdm = Tqdm.tqdm(train_generator,
                                         total=num_training_batches)
        cumulative_batch_size = 0
-        for batch in train_generator_tqdm:
+        for batch_group in train_generator_tqdm:
            batches_this_epoch += 1
            self._batch_num_total += 1
            batch_num_total = self._batch_num_total

            self.optimizer.zero_grad()

-            loss = self.batch_loss(batch, for_training=True)
+            loss = self.batch_loss(batch_group, for_training=True)
+
            if torch.isnan(loss):
                raise ValueError("nan loss encountered")
@@ -329,7 +337,7 @@ def _train_epoch(self, epoch: int) -> Dict[str, float]:
                self._tensorboard.log_histograms(self.model, histogram_parameters)

            if self._log_batch_size_period:
-                cur_batch = training_util.get_batch_size(batch)
+                cur_batch = sum([training_util.get_batch_size(batch) for batch in batch_group])
                cumulative_batch_size += cur_batch
                if (batches_this_epoch - 1) % self._log_batch_size_period == 0:
                    average = cumulative_batch_size / batches_this_epoch
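Since each group now spans all GPUs, the logged batch size is the sum over the group: with illustrative numbers, 4 GPUs each receiving a batch of 32 instances log a step size of 4 * 32 = 128 instances.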
@@ -365,17 +373,20 @@ def _validation_loss(self) -> Tuple[float, int]:
        else:
            val_iterator = self.iterator

-        val_generator = val_iterator(self._validation_data,
-                                     num_epochs=1,
-                                     shuffle=False)
-        num_validation_batches = val_iterator.get_num_batches(self._validation_data)
+        num_gpus = len(self._cuda_devices)
+
+        raw_val_generator = val_iterator(self._validation_data,
+                                         num_epochs=1,
+                                         shuffle=False)
+        val_generator = lazy_groups_of(raw_val_generator, num_gpus)
+        num_validation_batches = math.ceil(val_iterator.get_num_batches(self._validation_data) / num_gpus)
        val_generator_tqdm = Tqdm.tqdm(val_generator,
                                       total=num_validation_batches)
        batches_this_epoch = 0
        val_loss = 0
-        for batch in val_generator_tqdm:
+        for batch_group in val_generator_tqdm:

-            loss = self.batch_loss(batch, for_training=False)
+            loss = self.batch_loss(batch_group, for_training=False)
            if loss is not None:
                # You shouldn't necessarily have to compute a loss for validation, so we allow for
                # `loss` to be None. We need to be careful, though - `batches_this_epoch` is