Commit 7525c61: Remove scattering for multi-GPU training (#2200)

- Instead just pull off a batch for each GPU.
- Enables increasing the effective batch size for `bidirectional_language_model.jsonnet` by 2x, giving a 1.5x speedup.
1 parent d0a5a40 commit 7525c61

11 files changed: +86 -139 lines

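The heart of the change is the `lazy_groups_of` helper imported from `allennlp.common.util` in the diffs below: instead of scattering a single batch across devices, the trainer pulls `num_gpus` whole batches off the iterator and hands one to each GPU. The helper's body isn't part of this diff; a minimal sketch of such a grouping function, using the two-argument `iter(callable, sentinel)` form, could look like this:

```python
from itertools import islice
from typing import Iterator, List, TypeVar

A = TypeVar('A')

def lazy_groups_of(iterator: Iterator[A], group_size: int) -> Iterator[List[A]]:
    """Lazily yield lists of up to ``group_size`` items from ``iterator``."""
    # ``iter(callable, sentinel)`` keeps calling the lambda until it returns
    # the sentinel ``[]``, i.e. until the underlying iterator is exhausted.
    return iter(lambda: list(islice(iterator, 0, group_size)), [])
```

With two GPUs, an iterator yielding batches b1, b2, b3 produces the groups [b1, b2] and [b3]; the trailing group may be smaller than `group_size`. The 2x effective batch size from the commit message falls out directly: each optimizer step now consumes one full batch per GPU.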
allennlp/commands/find_learning_rate.py (+8 -5)

```diff
@@ -58,7 +58,7 @@
 from allennlp.commands.subcommand import Subcommand
 from allennlp.common.checks import ConfigurationError, check_for_gpu
 from allennlp.common import Params, Tqdm
-from allennlp.common.util import prepare_environment
+from allennlp.common.util import prepare_environment, lazy_groups_of
 from allennlp.data import Vocabulary, DataIterator
 from allennlp.models import Model
 from allennlp.training import Trainer
@@ -263,8 +263,11 @@ def search_learning_rate(trainer: Trainer,

     trainer.model.train()

-    train_generator = trainer.iterator(trainer.train_data,
-                                       shuffle=trainer.shuffle)
+    num_gpus = len(trainer._cuda_devices)  # pylint: disable=protected-access
+
+    raw_train_generator = trainer.iterator(trainer.train_data,
+                                           shuffle=trainer.shuffle)
+    train_generator = lazy_groups_of(raw_train_generator, num_gpus)
     train_generator_tqdm = Tqdm.tqdm(train_generator,
                                      total=num_batches)

@@ -276,7 +279,7 @@ def search_learning_rate(trainer: Trainer,
     else:
         lr_update_factor = (end_lr / start_lr) ** (1.0 / num_batches)

-    for i, batch in enumerate(train_generator_tqdm):
+    for i, batch_group in enumerate(train_generator_tqdm):

         if linear_steps:
             current_lr = start_lr + (lr_update_factor * i)
@@ -287,7 +290,7 @@ def search_learning_rate(trainer: Trainer,
             param_group['lr'] = current_lr

         trainer.optimizer.zero_grad()
-        loss = trainer.batch_loss(batch, for_training=True)
+        loss = trainer.batch_loss(batch_group, for_training=True)
         loss.backward()
         loss = loss.detach().cpu().item()

```

allennlp/common/util.py (-95)

```diff
@@ -1,7 +1,6 @@
 """
 Various utilities that don't fit anwhere else.
 """
-from ctypes import sizeof, c_void_p, c_int64, cast, py_object, c_uint64
 from itertools import zip_longest, islice
 from typing import Any, Callable, Dict, List, Tuple, TypeVar, Iterable, Iterator, Union
 import importlib
@@ -14,8 +13,6 @@
 import os
 import re

-from torch.nn.parallel._functions import Scatter
-
 try:
     import resource
 except ImportError:
@@ -392,98 +389,6 @@ def from_list(strings):
     # TODO(brendanr): Determine why mypy can't tell that this matches the Union.
     return int(cuda_device)  # type: ignore

-class ScatterableList(list):
-    """
-    A normal list, but one that should be scattered like a tensor.
-    """
-
-    # Ensure pointers will fit in a torch.LongTensor. "64 bits ought to be enough for anybody."
-    assert sizeof(c_void_p) <= sizeof(c_int64)
-
-    def to_pointer_tensor(self) -> torch.LongTensor:
-        """
-        Converts the elements to pointers, casts them to ``int64`` and then returns them in a tensor. This cast is
-        important as ``id`` gives back unsigned integers while ``torch.LongTensor`` is signed.
-
-        See:
-        https://github.com/python/cpython/blob/6ec5cf24b7f38ea72bb42d5cd60dca0d3ee332f9/Python/bltinmodule.c#L1118
-        https://github.com/python/cpython/blob/6ec5cf24b7f38ea72bb42d5cd60dca0d3ee332f9/Objects/longobject.c#L990
-        """
-        pointers = [c_int64(id(element)).value for element in self]
-        return torch.LongTensor(pointers)
-
-    @classmethod
-    def from_pointer_tensor(cls, pointers: torch.LongTensor) -> list:
-        """
-        The inverse of ``to_pointer_tensor`` except that a plain ``list`` is returned. Typically this will be
-        called on a single chunk of the scattered tensor.
-
-        Parameters
-        ----------
-        pointers : ``torch.LongTensor``, required.
-            A tensor of shape (list_length,).
-        """
-        return [cast(c_uint64(pointer.item()).value, py_object).value for pointer in pointers]
-
-def scatter(inputs, target_gpus, dim=0):
-    """
-    Slices tensors and ScatterableLists into approximately equal chunks and distributes them across given GPUs.
-    Duplicates references to objects that are not tensors or ScatterableLists.
-
-    Adapted from `scatter` at:
-    https://github.com/pytorch/pytorch/blob/1d406c04ae56255e58dcec85e3479bb2b3dbd75e/torch/nn/parallel/scatter_gather.py#L5-L30.
-
-    Please see the LICENSE and NOTICE files as well:
-    https://github.com/pytorch/pytorch/blob/1d406c04ae56255e58dcec85e3479bb2b3dbd75e/LICENSE
-    https://github.com/pytorch/pytorch/blob/1d406c04ae56255e58dcec85e3479bb2b3dbd75e/NOTICE
-    """
-    def scatter_map(obj):
-        if isinstance(obj, torch.Tensor):
-            return Scatter.apply(target_gpus, None, dim, obj)
-        if isinstance(obj, ScatterableList):
-            # In order to have precisely the same method of scattering as PyTorch we scatter
-            # a tensor of pointers.
-            pointers = scatter_map(obj.to_pointer_tensor())
-            # Then we reconstruct the lists from the pointer tensors.
-            return [obj.from_pointer_tensor(chunk) for chunk in pointers]
-        if isinstance(obj, tuple) and obj:
-            return list(zip(*map(scatter_map, obj)))
-        if isinstance(obj, list) and obj:
-            return list(map(list, zip(*map(scatter_map, obj))))
-        if isinstance(obj, dict) and obj:
-            return list(map(type(obj), zip(*map(scatter_map, obj.items()))))
-        return [obj for _ in target_gpus]
-
-    # After scatter_map is called, a scatter_map cell will exist. This cell
-    # has a reference to the actual function scatter_map, which has references
-    # to a closure that has a reference to the scatter_map cell (because the
-    # fn is recursive). To avoid this reference cycle, we set the function to
-    # None, clearing the cell
-    try:
-        return scatter_map(inputs)
-    finally:
-        scatter_map = None
-
-def scatter_kwargs(inputs, kwargs, target_gpus, dim=0):
-    """Scatter with support for kwargs dictionary.
-
-    Adapted from `scatter_kwargs` at:
-    https://github.com/pytorch/pytorch/blob/1d406c04ae56255e58dcec85e3479bb2b3dbd75e/torch/nn/parallel/scatter_gather.py#L33-L43
-
-    Please see the LICENSE and NOTICE files as well:
-    https://github.com/pytorch/pytorch/blob/1d406c04ae56255e58dcec85e3479bb2b3dbd75e/LICENSE
-    https://github.com/pytorch/pytorch/blob/1d406c04ae56255e58dcec85e3479bb2b3dbd75e/NOTICE
-    """
-    inputs = scatter(inputs, target_gpus, dim) if inputs else []
-    kwargs = scatter(kwargs, target_gpus, dim) if kwargs else []
-    if len(inputs) < len(kwargs):
-        inputs.extend([() for _ in range(len(kwargs) - len(inputs))])
-    elif len(kwargs) < len(inputs):
-        kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))])
-    inputs = tuple(inputs)
-    kwargs = tuple(kwargs)
-    return inputs, kwargs
-
 def get_frozen_and_tunable_parameter_names(model: torch.nn.Module) -> List:
     frozen_parameter_names = []
     tunable_parameter_names = []
```

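The `ScatterableList` removed above leaned on a CPython implementation detail: `id()` returns the object's memory address, so a list of arbitrary Python objects could be smuggled through the same tensor-scattering path as real tensors and dereferenced on the other side. A self-contained illustration of that round trip (not part of the codebase; CPython-specific, and valid only while the original object stays alive):

```python
from ctypes import c_int64, c_uint64, cast, py_object

obj = {"some": "metadata"}

# ``id`` gives an unsigned address in CPython; reinterpret it as a signed
# 64-bit value so it fits in a ``torch.LongTensor`` slot.
signed_pointer = c_int64(id(obj)).value

# Undo the signed cast and dereference the address. Nothing here keeps
# ``obj`` alive, which is part of why this trick is fragile.
recovered = cast(c_uint64(signed_pointer).value, py_object).value
assert recovered is obj
```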
allennlp/data/fields/metadata_field.py (+2 -3)

```diff
@@ -3,7 +3,6 @@

 from overrides import overrides

-from allennlp.common.util import ScatterableList
 from allennlp.data.fields.field import DataArray, Field


@@ -61,8 +60,8 @@ def empty_field(self) -> 'MetadataField':

     @classmethod
     @overrides
-    def batch_tensors(cls, tensor_list: List[DataArray]) -> ScatterableList:  # type: ignore
-        return ScatterableList(tensor_list)
+    def batch_tensors(cls, tensor_list: List[DataArray]) -> List[DataArray]:  # type: ignore
+        return tensor_list


     def __str__(self) -> str:
```

allennlp/data/fields/production_rule_field.py (+2 -3)

```diff
@@ -3,7 +3,6 @@
 import torch
 from overrides import overrides

-from allennlp.common.util import ScatterableList
 from allennlp.data.fields.field import Field
 from allennlp.data.vocabulary import Vocabulary

@@ -114,9 +113,9 @@ def empty_field(self):  # pylint: disable=no-self-use
         return ProductionRuleField(rule='', is_global_rule=False)

     @overrides
-    def batch_tensors(self, tensor_list: List[ProductionRule]) -> ScatterableList:  # type: ignore
+    def batch_tensors(self, tensor_list: List[ProductionRule]) -> List[ProductionRule]:  # type: ignore
         # pylint: disable=no-self-use
-        return ScatterableList(tensor_list)
+        return tensor_list

     def __str__(self) -> str:
         return f"ProductionRuleField with rule: {self.rule} (is_global_rule: " \
```

allennlp/data/iterators/bucket_iterator.py (+3)

```diff
@@ -124,6 +124,9 @@ def _create_batches(self, instances: Iterable[Instance], shuffle: bool) -> Itera
         if excess:
             batches.append(Batch(excess))

+        # TODO(brendanr): Add multi-GPU friendly grouping, i.e. group
+        # num_gpu batches together, shuffle and then expand the groups.
+        # This guards against imbalanced batches across GPUs.
         move_to_front = self._biggest_batch_first and len(batches) > 1
         if move_to_front:
             # We'll actually pop the last _two_ batches, because the last one might not be full.
```

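The TODO above describes grouping `num_gpus` adjacent batches, shuffling the groups, and then expanding them again, so that each group handed to the GPUs contains similarly sized batches. A hypothetical sketch of that idea (`shuffle_in_groups` is an invented name, not code from this commit):

```python
import random
from typing import List, TypeVar

Batch = TypeVar('Batch')

def shuffle_in_groups(batches: List[Batch], num_gpus: int) -> List[Batch]:
    """Shuffle whole groups of ``num_gpus`` adjacent batches, then flatten.

    Because the bucket iterator sorts by length, adjacent batches have
    similar sizes, so the batches within each group stay balanced.
    """
    groups = [batches[i:i + num_gpus] for i in range(0, len(batches), num_gpus)]
    random.shuffle(groups)
    return [batch for group in groups for batch in group]
```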
allennlp/data/iterators/data_iterator.py (+2)

```diff
@@ -125,6 +125,8 @@ def __call__(self,
             tensor_dicts = self._cache[key]

             if shuffle:
+                # TODO(brendanr): How can we handle this shuffle in a way
+                # that respects multi-GPU friendly grouping?
                 random.shuffle(tensor_dicts)
             for tensor_dict in tensor_dicts:
                 if self._track_epoch:
```

allennlp/tests/models/simple_tagger_test.py (+4 -4)

```diff
@@ -64,8 +64,8 @@ def test_regularization(self):
         training_batch = next(iterator(self.instances, num_epochs=1))
         validation_batch = next(iterator(self.instances, num_epochs=1))

-        training_loss = trainer.batch_loss(training_batch, for_training=True).item()
-        validation_loss = trainer.batch_loss(validation_batch, for_training=False).item()
+        training_loss = trainer.batch_loss([training_batch], for_training=True).item()
+        validation_loss = trainer.batch_loss([validation_batch], for_training=False).item()

         # Training loss should have the regularization penalty, but validation loss should not.
         numpy.testing.assert_almost_equal(training_loss, validation_loss)
@@ -116,8 +116,8 @@ def test_regularization(self):
         training_batch = next(self.iterator(self.instances, num_epochs=1))
         validation_batch = next(self.iterator(self.instances, num_epochs=1))

-        training_loss = self.trainer.batch_loss(training_batch, for_training=True).data
-        validation_loss = self.trainer.batch_loss(validation_batch, for_training=False).data
+        training_loss = self.trainer.batch_loss([training_batch], for_training=True).data
+        validation_loss = self.trainer.batch_loss([validation_batch], for_training=False).data

         # Training loss should have the regularization penalty, but validation loss should not.
         assert (training_loss != validation_loss).all()
```

allennlp/tests/training/trainer_test.py (+18 -1)

```diff
@@ -20,7 +20,8 @@
 from allennlp.common.params import Params
 from allennlp.models.simple_tagger import SimpleTagger
 from allennlp.data.iterators import BasicIterator
-from allennlp.data.dataset_readers import SequenceTaggingDatasetReader
+from allennlp.data.dataset_readers import SequenceTaggingDatasetReader, WikiTablesDatasetReader
+from allennlp.models.archival import load_archive
 from allennlp.models.model import Model


@@ -133,6 +134,22 @@ def forward(self, **kwargs) -> Dict[str, torch.Tensor]:  # type: ignore # pylint
         assert 'peak_gpu_1_memory_MB' in metrics
         assert isinstance(metrics['peak_gpu_1_memory_MB'], int)

+    @pytest.mark.skipif(torch.cuda.device_count() < 2,
+                        reason="Need multiple GPUs.")
+    def test_production_rule_field_with_multiple_gpus(self):
+        wikitables_dir = 'allennlp/tests/fixtures/data/wikitables/'
+        wikitables_reader = WikiTablesDatasetReader(tables_directory=wikitables_dir,
+                                                    dpd_output_directory=wikitables_dir + 'dpd_output/')
+        instances = wikitables_reader.read(wikitables_dir + 'sample_data.examples')
+        archive_path = self.FIXTURES_ROOT / 'semantic_parsing' / 'wikitables' / 'serialization' / 'model.tar.gz'
+        model = load_archive(archive_path).model
+        model.cuda()
+
+        multigpu_iterator = BasicIterator(batch_size=4)
+        multigpu_iterator.index_with(model.vocab)
+        trainer = Trainer(model, self.optimizer, multigpu_iterator, instances, num_epochs=2, cuda_device=[0, 1])
+        trainer.train()
+
     def test_trainer_can_resume_training(self):
         trainer = Trainer(self.model, self.optimizer,
                           self.iterator, self.instances,
```

allennlp/training/trainer.py (+29 -18)

```diff
@@ -1,5 +1,6 @@

 import logging
+import math
 import os
 import time
 import re
@@ -13,10 +14,10 @@
 from allennlp.common import Params
 from allennlp.common.checks import ConfigurationError
 from allennlp.common.util import (dump_metrics, gpu_memory_mb, parse_cuda_device, peak_memory_mb,
-                                  get_frozen_and_tunable_parameter_names)
+                                  get_frozen_and_tunable_parameter_names, lazy_groups_of)
 from allennlp.common.tqdm import Tqdm
 from allennlp.data.instance import Instance
-from allennlp.data.iterators.data_iterator import DataIterator
+from allennlp.data.iterators.data_iterator import DataIterator, TensorDict
 from allennlp.data.vocabulary import Vocabulary
 from allennlp.models.model import Model
 from allennlp.nn import util as nn_util
@@ -216,14 +217,16 @@ def __init__(self,
     def rescale_gradients(self) -> Optional[float]:
         return training_util.rescale_gradients(self.model, self._grad_norm)

-    def batch_loss(self, batch: torch.Tensor, for_training: bool) -> torch.Tensor:
+    def batch_loss(self, batch_group: List[TensorDict], for_training: bool) -> torch.Tensor:
         """
-        Does a forward pass on the given batch and returns the ``loss`` value in the result.
+        Does a forward pass on the given batches and returns the ``loss`` value in the result.
         If ``for_training`` is `True` also applies regularization penalty.
         """
         if self._multiple_gpu:
-            output_dict = training_util.data_parallel(batch, self.model, self._cuda_devices)
+            output_dict = training_util.data_parallel(batch_group, self.model, self._cuda_devices)
         else:
+            assert len(batch_group) == 1
+            batch = batch_group[0]
             batch = nn_util.move_to_device(batch, self._cuda_devices[0])
             output_dict = self.model(**batch)

@@ -255,11 +258,14 @@ def _train_epoch(self, epoch: int) -> Dict[str, float]:
         # Set the model to "train" mode.
         self.model.train()

+        num_gpus = len(self._cuda_devices)
+
         # Get tqdm for the training batches
-        train_generator = self.iterator(self.train_data,
-                                        num_epochs=1,
-                                        shuffle=self.shuffle)
-        num_training_batches = self.iterator.get_num_batches(self.train_data)
+        raw_train_generator = self.iterator(self.train_data,
+                                            num_epochs=1,
+                                            shuffle=self.shuffle)
+        train_generator = lazy_groups_of(raw_train_generator, num_gpus)
+        num_training_batches = math.ceil(self.iterator.get_num_batches(self.train_data)/num_gpus)
         self._last_log = time.time()
         last_save_time = time.time()

@@ -269,18 +275,20 @@ def _train_epoch(self, epoch: int) -> Dict[str, float]:

         histogram_parameters = set(self.model.get_parameters_for_histogram_tensorboard_logging())

+
         logger.info("Training")
         train_generator_tqdm = Tqdm.tqdm(train_generator,
                                          total=num_training_batches)
         cumulative_batch_size = 0
-        for batch in train_generator_tqdm:
+        for batch_group in train_generator_tqdm:
             batches_this_epoch += 1
             self._batch_num_total += 1
             batch_num_total = self._batch_num_total

             self.optimizer.zero_grad()

-            loss = self.batch_loss(batch, for_training=True)
+            loss = self.batch_loss(batch_group, for_training=True)
+
             if torch.isnan(loss):
                 raise ValueError("nan loss encountered")

@@ -329,7 +337,7 @@ def _train_epoch(self, epoch: int) -> Dict[str, float]:
             self._tensorboard.log_histograms(self.model, histogram_parameters)

         if self._log_batch_size_period:
-            cur_batch = training_util.get_batch_size(batch)
+            cur_batch = sum([training_util.get_batch_size(batch) for batch in batch_group])
             cumulative_batch_size += cur_batch
             if (batches_this_epoch - 1) % self._log_batch_size_period == 0:
                 average = cumulative_batch_size/batches_this_epoch
@@ -365,17 +373,20 @@ def _validation_loss(self) -> Tuple[float, int]:
         else:
             val_iterator = self.iterator

-        val_generator = val_iterator(self._validation_data,
-                                     num_epochs=1,
-                                     shuffle=False)
-        num_validation_batches = val_iterator.get_num_batches(self._validation_data)
+        num_gpus = len(self._cuda_devices)
+
+        raw_val_generator = val_iterator(self._validation_data,
+                                         num_epochs=1,
+                                         shuffle=False)
+        val_generator = lazy_groups_of(raw_val_generator, num_gpus)
+        num_validation_batches = math.ceil(val_iterator.get_num_batches(self._validation_data)/num_gpus)
         val_generator_tqdm = Tqdm.tqdm(val_generator,
                                        total=num_validation_batches)
         batches_this_epoch = 0
         val_loss = 0
-        for batch in val_generator_tqdm:
+        for batch_group in val_generator_tqdm:

-            loss = self.batch_loss(batch, for_training=False)
+            loss = self.batch_loss(batch_group, for_training=False)
             if loss is not None:
                 # You shouldn't necessarily have to compute a loss for validation, so we allow for
                 # `loss` to be None. We need to be careful, though - `batches_this_epoch` is
```

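`training_util.data_parallel`, called from `batch_loss` above, isn't shown in this commit. Assuming it mirrors PyTorch's replicate/parallel_apply machinery but with one whole batch per device instead of a scattered batch, a rough sketch might look like the following (the loss averaging at the end is an assumption, not verified against the actual helper):

```python
from typing import Dict, List

import torch
from torch.nn.parallel import parallel_apply, replicate

from allennlp.nn import util as nn_util

def data_parallel(batch_group: List[Dict[str, torch.Tensor]],
                  model: torch.nn.Module,
                  cuda_devices: List[int]) -> Dict[str, torch.Tensor]:
    """Run one whole batch on each GPU; no scattering of individual batches."""
    # The final group of an epoch may hold fewer batches than devices.
    moved = [nn_util.move_to_device(batch, device)
             for batch, device in zip(batch_group, cuda_devices)]
    used_devices = cuda_devices[:len(moved)]

    # One model replica per used device; each batch is passed as kwargs.
    replicas = replicate(model, used_devices)
    outputs = parallel_apply(replicas, [() for _ in moved], moved, used_devices)

    # Combine the per-GPU losses on the first device.
    losses = [output['loss'].to(cuda_devices[0]) for output in outputs]
    return {'loss': sum(losses) / len(losses)}
```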