
Remove scattering for multi-GPU training. #2200

Merged Jan 18, 2019 (87 commits). Changes shown are from 80 commits.

Commits
5d179a4
Transformer ELMo
brendan-ai2 Nov 21, 2018
2db75b4
wip
brendan-ai2 Dec 1, 2018
f7deed3
Add bidirectional transformer token embedder
brendan-ai2 Dec 4, 2018
c9de1ec
transformer elmo config template
brendan-ai2 Dec 4, 2018
634b4a2
MORE
brendan-ai2 Dec 4, 2018
e4a7b51
Works
brendan-ai2 Dec 5, 2018
9eb6e46
Add broken layer norm.
brendan-ai2 Dec 5, 2018
ac425a4
Address some more comments
brendan-ai2 Dec 5, 2018
f203cde
Merge branch 'lm_without_dataset_modifications_2' into lm_without_dat…
brendan-ai2 Dec 5, 2018
bde39fe
Fix for vidurj
brendan-ai2 Dec 5, 2018
4b3a81c
easy feedback
brendan-ai2 Dec 10, 2018
595b668
Fix norm issue
brendan-ai2 Dec 10, 2018
731e69c
Rename
brendan-ai2 Dec 10, 2018
4522f1c
Start and end tokens in reader
brendan-ai2 Dec 10, 2018
d091cc8
comment fix
brendan-ai2 Dec 10, 2018
971e600
fixes
brendan-ai2 Dec 10, 2018
24b763b
style
brendan-ai2 Dec 10, 2018
71e2cce
fix docs
brendan-ai2 Dec 10, 2018
01a111a
Merge branch 'master' into lm_without_dataset_modifications_2
brendan-ai2 Dec 10, 2018
975060a
Merge branch 'master' into lm_without_dataset_modifications_2
brendan-ai2 Dec 13, 2018
100f07f
Merge branch 'lm_without_dataset_modifications_2' into lm_without_dat…
brendan-ai2 Dec 13, 2018
5dcd700
cleanup
brendan-ai2 Dec 14, 2018
87e6241
Merge branch 'master' into lm_without_dataset_modifications_3
brendan-ai2 Dec 14, 2018
4b3ce38
Bidirectional fixture
brendan-ai2 Dec 14, 2018
1338afb
Test
brendan-ai2 Dec 14, 2018
fa86367
cleanup
brendan-ai2 Dec 14, 2018
7cd29aa
Merge branch 'master' into lm_without_dataset_modifications_3
brendan-ai2 Dec 14, 2018
f6e57d1
Merge branch 'lm_without_dataset_modifications_3' of github.com:brend…
brendan-ai2 Dec 14, 2018
6ec4a6e
works
brendan-ai2 Dec 16, 2018
d7c0208
Model file
brendan-ai2 Dec 16, 2018
c54534d
update parser config
brendan-ai2 Dec 16, 2018
16ca024
Merge branch 'lm_without_dataset_modifications_3' of github.com:brend…
brendan-ai2 Dec 16, 2018
a8e8eb6
fixes
brendan-ai2 Dec 16, 2018
53a283c
formatting
brendan-ai2 Dec 16, 2018
75f03fd
Type fixes
brendan-ai2 Dec 16, 2018
e8ad0c6
Renames
brendan-ai2 Dec 16, 2018
8cfe033
Merge branch 'master' into lm_without_dataset_modifications_3
brendan-ai2 Dec 16, 2018
13f6d83
Merge branch 'lm_without_dataset_modifications_3' of github.com:brend…
brendan-ai2 Dec 17, 2018
726cf13
Merge branch 'master' into lm_without_dataset_modifications_3
brendan-ai2 Dec 17, 2018
8dacca9
another test, jsonnet improvements
brendan-ai2 Dec 17, 2018
c6694b9
Merge branch 'lm_without_dataset_modifications_3' of github.com:brend…
brendan-ai2 Dec 17, 2018
9e617b2
Merge branch 'lm_without_dataset_modifications_3' of github.com:brend…
brendan-ai2 Dec 17, 2018
468300f
docs
brendan-ai2 Dec 17, 2018
30894d0
Merge branch 'lm_without_dataset_modifications_3' into lm_train_fixes
brendan-ai2 Dec 17, 2018
e79119f
Drop scatter
brendan-ai2 Dec 17, 2018
724cf89
Potentially works? On one shard at least.
brendan-ai2 Dec 17, 2018
219d026
Merge branch 'lm_train_fixes' of github.com:brendan-ai2/allennlp into…
brendan-ai2 Dec 17, 2018
8d81d3f
Fix
brendan-ai2 Dec 17, 2018
ae8be54
Merge branch 'lm_train_fixes' of github.com:brendan-ai2/allennlp into…
brendan-ai2 Dec 17, 2018
2cf180e
Added failing test case
matt-gardner Dec 18, 2018
f68e647
Merge branch 'master' into lm_train_fixes
brendan-ai2 Dec 18, 2018
e0d71c4
Merge branch 'master' into lm_train_fixes
brendan-ai2 Dec 18, 2018
2d06b4b
Merge branch 'lm_train_fixes' of github.com:brendan-ai2/allennlp into…
brendan-ai2 Dec 18, 2018
7662c0f
hacks
brendan-ai2 Dec 18, 2018
a48e494
more hacks
brendan-ai2 Dec 18, 2018
560f99f
respond to feedback
brendan-ai2 Dec 19, 2018
4da2ff3
Add todo
brendan-ai2 Dec 19, 2018
50c1e15
Merge branch 'master' into lm_without_dataset_modifications_3
brendan-ai2 Dec 19, 2018
f93b6dc
lint
brendan-ai2 Dec 19, 2018
4f667ab
Add todos
brendan-ai2 Dec 20, 2018
0cbb57a
Merge branch 'master' into lm_without_dataset_modifications_3
brendan-ai2 Dec 20, 2018
64bbfbd
Merge branch 'lm_without_dataset_modifications_3' into lm_train_fixes
brendan-ai2 Dec 21, 2018
255045a
Merge branch 'master' into lm_train_fixes
brendan-ai2 Dec 21, 2018
862b87b
cleanups
brendan-ai2 Dec 21, 2018
cc51bbc
Merge branch 'master' into lm_train_fixes
brendan-ai2 Dec 21, 2018
eb8419c
Fix batch size
brendan-ai2 Dec 21, 2018
4856e77
Try for more
brendan-ai2 Dec 22, 2018
e329e72
3k samples
brendan-ai2 Dec 22, 2018
9104044
2k
brendan-ai2 Dec 22, 2018
4a133b1
log grad stats and learning rate
brendan-ai2 Jan 11, 2019
86c76fb
Merge branch 'pr-2199' into lm_train_fixes
brendan-ai2 Jan 15, 2019
8844663
merge
brendan-ai2 Jan 17, 2019
f4726a6
fix
brendan-ai2 Jan 17, 2019
2323bbf
Fix
brendan-ai2 Jan 17, 2019
98629f1
merge
brendan-ai2 Jan 17, 2019
a68db07
drop some logging
brendan-ai2 Jan 17, 2019
6eda737
stash pop
brendan-ai2 Jan 17, 2019
63644f1
fixes
brendan-ai2 Jan 17, 2019
79ed01a
cleanup
brendan-ai2 Jan 17, 2019
d3e4921
Add todos
brendan-ai2 Jan 17, 2019
34b2adf
fixes
brendan-ai2 Jan 18, 2019
c7a5a96
merge
brendan-ai2 Jan 18, 2019
7c19f04
fixes, delete ScatterableList, scatter_kwargs, etc.
brendan-ai2 Jan 18, 2019
95f5804
More cleanup
brendan-ai2 Jan 18, 2019
ee5df46
cleanup
brendan-ai2 Jan 18, 2019
2e6f990
drop no-op changes
brendan-ai2 Jan 18, 2019
12e62d4
Merge branch 'master' into lm_train_fixes
brendan-ai2 Jan 18, 2019
3 changes: 3 additions & 0 deletions allennlp/data/iterators/bucket_iterator.py
@@ -124,6 +124,9 @@ def _create_batches(self, instances: Iterable[Instance], shuffle: bool) -> Itera
if excess:
batches.append(Batch(excess))

# TODO(brendanr): Add multi-GPU friendly grouping, i.e. group
# num_gpu batches together, shuffle and then expand the groups.
# This guards against imbalanced batches across GPUs.
move_to_front = self._biggest_batch_first and len(batches) > 1
if move_to_front:
# We'll actually pop the last _two_ batches, because the last one might not be full.
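The TODO in bucket_iterator.py above sketches multi-GPU friendly grouping: collect num_gpu batches into a group, shuffle at the group level, then expand the groups back out, so the similarly sized batches produced by bucketing stay together within a single multi-GPU step. A minimal illustration of that idea (not part of this PR; the helper name and signature are hypothetical):

```python
import random
from typing import List, TypeVar

A = TypeVar("A")

def shuffle_in_groups(batches: List[A], num_gpus: int) -> List[A]:
    """Shuffle batches while keeping each run of `num_gpus` consecutive
    batches together, so one training step sees batches of similar size."""
    # Chunk the (bucketed, size-sorted) batch list into groups, one group per step.
    groups = [batches[i:i + num_gpus] for i in range(0, len(batches), num_gpus)]
    # Shuffle whole groups rather than individual batches.
    random.shuffle(groups)
    # Expand the groups back into a flat list of batches.
    return [batch for group in groups for batch in group]
```

Shuffling whole groups is what guards against one GPU in a step receiving a much larger batch than its peers.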
2 changes: 2 additions & 0 deletions allennlp/data/iterators/data_iterator.py
@@ -125,6 +125,8 @@ def __call__(self,
tensor_dicts = self._cache[key]

if shuffle:
# TODO(brendanr): How can we handle this shuffle in a way
# that respects multi-GPU friendly grouping?
random.shuffle(tensor_dicts)
for tensor_dict in tensor_dicts:
if self._track_epoch:
25 changes: 24 additions & 1 deletion allennlp/tests/training/trainer_test.py
@@ -20,7 +20,8 @@
from allennlp.common.params import Params
from allennlp.models.simple_tagger import SimpleTagger
from allennlp.data.iterators import BasicIterator
from allennlp.data.dataset_readers import SequenceTaggingDatasetReader
from allennlp.data.dataset_readers import SequenceTaggingDatasetReader, WikiTablesDatasetReader
from allennlp.models.archival import load_archive
from allennlp.models.model import Model


@@ -91,6 +92,9 @@ def test_trainer_can_run(self):

@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device registered.")
def test_trainer_can_run_cuda(self):
# Trainer expects the model to already be on the correct device.
self.model.cuda(0)

trainer = Trainer(self.model, self.optimizer,
self.iterator, self.instances, num_epochs=2,
cuda_device=0)
@@ -99,6 +103,8 @@ def test_trainer_can_run_cuda(self):
@pytest.mark.skipif(torch.cuda.device_count() < 2,
reason="Need multiple GPUs.")
def test_trainer_can_run_multiple_gpu(self):
# Trainer expects the model to already be on some GPU in the multi-GPU setting.
self.model.cuda(0)

class MetaDataCheckWrapper(Model):
"""
@@ -132,6 +138,23 @@ def forward(self, **kwargs) -> Dict[str, torch.Tensor]: # type: ignore # pylint
assert 'peak_gpu_1_memory_MB' in metrics
assert isinstance(metrics['peak_gpu_1_memory_MB'], int)

@pytest.mark.skipif(torch.cuda.device_count() < 2,
reason="Need multiple GPUs.")
def test_production_rule_field_with_multiple_gpus(self):
wikitables_dir = 'allennlp/tests/fixtures/data/wikitables/'
wikitables_reader = WikiTablesDatasetReader(tables_directory=wikitables_dir,
dpd_output_directory=wikitables_dir + 'dpd_output/')
instances = wikitables_reader.read(wikitables_dir + 'sample_data.examples')
archive_path = self.FIXTURES_ROOT / 'semantic_parsing' / 'wikitables' / 'serialization' / 'model.tar.gz'
model = load_archive(archive_path).model
# Trainer expects the model to already be on some GPU in the multi-GPU setting.
model.cuda(0)

multigpu_iterator = BasicIterator(batch_size=4)
multigpu_iterator.index_with(model.vocab)
trainer = Trainer(model, self.optimizer, multigpu_iterator, instances, num_epochs=2, cuda_device=[0, 1])
trainer.train()

def test_trainer_can_resume_training(self):
trainer = Trainer(self.model, self.optimizer,
self.iterator, self.instances,
47 changes: 29 additions & 18 deletions allennlp/training/trainer.py
@@ -1,5 +1,6 @@

import logging
import math
import os
import time
import re
@@ -13,10 +14,10 @@
from allennlp.common import Params
from allennlp.common.checks import ConfigurationError
from allennlp.common.util import (dump_metrics, gpu_memory_mb, parse_cuda_device, peak_memory_mb,
get_frozen_and_tunable_parameter_names)
get_frozen_and_tunable_parameter_names, lazy_groups_of)
from allennlp.common.tqdm import Tqdm
from allennlp.data.instance import Instance
from allennlp.data.iterators.data_iterator import DataIterator
from allennlp.data.iterators.data_iterator import DataIterator, TensorDict
from allennlp.data.vocabulary import Vocabulary
from allennlp.models.model import Model
from allennlp.nn import util as nn_util
@@ -212,14 +213,16 @@ def __init__(self,
def rescale_gradients(self) -> Optional[float]:
return training_util.rescale_gradients(self.model, self._grad_norm)

def batch_loss(self, batch: torch.Tensor, for_training: bool) -> torch.Tensor:
def batch_loss(self, batch_group: List[TensorDict], for_training: bool) -> torch.Tensor:
"""
Does a forward pass on the given batch and returns the ``loss`` value in the result.
Does a forward pass on the given batches and returns the ``loss`` value in the result.
If ``for_training`` is `True` also applies regularization penalty.
"""
if self._multiple_gpu:
output_dict = training_util.data_parallel(batch, self.model, self._cuda_devices)
output_dict = training_util.data_parallel(batch_group, self.model, self._cuda_devices)
else:
assert len(batch_group) == 1
batch = batch_group[0]
batch = nn_util.move_to_device(batch, self._cuda_devices[0])
output_dict = self.model(**batch)

@@ -251,11 +254,14 @@ def _train_epoch(self, epoch: int) -> Dict[str, float]:
# Set the model to "train" mode.
self.model.train()

num_gpus = len(self._cuda_devices)

# Get tqdm for the training batches
train_generator = self.iterator(self.train_data,
num_epochs=1,
shuffle=self.shuffle)
num_training_batches = self.iterator.get_num_batches(self.train_data)
raw_train_generator = self.iterator(self.train_data,
num_epochs=1,
shuffle=self.shuffle)
train_generator = lazy_groups_of(raw_train_generator, num_gpus)
num_training_batches = math.ceil(self.iterator.get_num_batches(self.train_data)/num_gpus)
self._last_log = time.time()
last_save_time = time.time()

@@ -265,18 +271,20 @@ def _train_epoch(self, epoch: int) -> Dict[str, float]:

histogram_parameters = set(self.model.get_parameters_for_histogram_tensorboard_logging())


logger.info("Training")
train_generator_tqdm = Tqdm.tqdm(train_generator,
total=num_training_batches)
cumulative_batch_size = 0
for batch in train_generator_tqdm:
for batch_group in train_generator_tqdm:
batches_this_epoch += 1
self._batch_num_total += 1
batch_num_total = self._batch_num_total

self.optimizer.zero_grad()

loss = self.batch_loss(batch, for_training=True)
loss = self.batch_loss(batch_group, for_training=True)

if torch.isnan(loss):
raise ValueError("nan loss encountered")

@@ -325,7 +333,7 @@ def _train_epoch(self, epoch: int) -> Dict[str, float]:
self._tensorboard.log_histograms(self.model, histogram_parameters)

if self._log_batch_size_period:
cur_batch = training_util.get_batch_size(batch)
cur_batch = sum([training_util.get_batch_size(batch) for batch in batch_group])
cumulative_batch_size += cur_batch
if (batches_this_epoch - 1) % self._log_batch_size_period == 0:
average = cumulative_batch_size/batches_this_epoch
@@ -361,17 +369,20 @@ def _validation_loss(self) -> Tuple[float, int]:
else:
val_iterator = self.iterator

val_generator = val_iterator(self._validation_data,
num_epochs=1,
shuffle=False)
num_validation_batches = val_iterator.get_num_batches(self._validation_data)
num_gpus = len(self._cuda_devices)

raw_val_generator = val_iterator(self._validation_data,
num_epochs=1,
shuffle=False)
val_generator = lazy_groups_of(raw_val_generator, num_gpus)
num_validation_batches = math.ceil(val_iterator.get_num_batches(self._validation_data)/num_gpus)
val_generator_tqdm = Tqdm.tqdm(val_generator,
total=num_validation_batches)
batches_this_epoch = 0
val_loss = 0
for batch in val_generator_tqdm:
for batch_group in val_generator_tqdm:

loss = self.batch_loss(batch, for_training=False)
loss = self.batch_loss(batch_group, for_training=False)
if loss is not None:
# You shouldn't necessarily have to compute a loss for validation, so we allow for
# `loss` to be None. We need to be careful, though - `batches_this_epoch` is
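A note on the grouping above: `lazy_groups_of` (imported from `allennlp.common.util`) lazily pulls `num_gpus` batches at a time from the raw iterator, so each training or validation step consumes one batch per GPU. A functionally equivalent sketch, for illustration only (not the library's exact code):

```python
from itertools import islice
from typing import Iterable, Iterator, List, TypeVar

A = TypeVar("A")

def lazy_groups_of(iterable: Iterable[A], group_size: int) -> Iterator[List[A]]:
    """Yield lists of up to `group_size` items; the final group may be short."""
    iterator = iter(iterable)
    while True:
        group = list(islice(iterator, group_size))
        if not group:
            return
        yield group

# With 2 GPUs, 5 batches become 3 groups: [[b0, b1], [b2, b3], [b4]].
# This is why the step counts above use math.ceil(num_batches / num_gpus),
# and why data_parallel below only asserts len(batch_group) <= len(cuda_devices).
```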
16 changes: 11 additions & 5 deletions allennlp/training/util.py
@@ -18,6 +18,7 @@
from allennlp.data.dataset_readers import DatasetReader
from allennlp.data import Instance
from allennlp.data.iterators import DataIterator
from allennlp.data.iterators.data_iterator import TensorDict
from allennlp.models.model import Model
from allennlp.models.archival import CONFIG_NAME
from allennlp.nn import util as nn_util
@@ -228,24 +229,29 @@ def create_serialization_dir(
"does not exist. There is nothing to recover from.")
os.makedirs(serialization_dir, exist_ok=True)

def data_parallel(batch, model: Model, cuda_devices: List) -> Dict[str, torch.Tensor]:
def data_parallel(batch_group: List[TensorDict],
model: Model,
cuda_devices: List) -> Dict[str, torch.Tensor]:
"""
Performs a forward pass using multiple GPUs. This is a simplification
of torch.nn.parallel.data_parallel to support the allennlp model
interface.
"""
inputs, module_kwargs = scatter_kwargs((), batch, cuda_devices, 0)
assert len(batch_group) <= len(cuda_devices)

used_device_ids = cuda_devices[:len(inputs)]
inputs = [()] * len(batch_group)
Contributor: inputs is supposed to be a list of empty tuples? This never gets updated before getting passed to parallel_apply.

Contributor Author: Added a comment to clarify. You can see that () was passed to the old scatter_kwargs as well.

# We pass all our arguments as kwargs. Create a list of empty tuples of the
# correct shape to serve as (non-existent) positional arguments.

moved = [nn_util.move_to_device(batch, device)
for batch, device in zip(batch_group, cuda_devices)]

used_device_ids = cuda_devices[:len(moved)]
replicas = replicate(model, used_device_ids)
outputs = parallel_apply(replicas, inputs, module_kwargs, used_device_ids)
outputs = parallel_apply(replicas, inputs, moved, used_device_ids)

# Only the 'loss' is needed.
# a (num_gpu, ) tensor with loss on each GPU
losses = gather([output['loss'].unsqueeze(0) for output in outputs], used_device_ids[0], 0)
return {'loss': losses.mean()}


def enable_gradient_clipping(model: Model, grad_clipping: Optional[float]) -> None:
if grad_clipping is not None:
for parameter in model.parameters():
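For readers less familiar with the `torch.nn.parallel` primitives used in `data_parallel` above: `parallel_apply` takes one tuple of positional arguments and one dict of keyword arguments per replica, so the list of empty tuples plus the list of moved batch dicts means each replica is called as `model(**batch)` on its own device. A small self-contained sketch of the same pattern (the toy model, tensors, and `toy_data_parallel` name are invented for illustration; running it requires at least two GPUs):

```python
import torch
from torch.nn.parallel import gather, parallel_apply, replicate

class ToyModel(torch.nn.Module):
    """Stand-in for an AllenNLP model: forward takes kwargs, returns a dict with 'loss'."""
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(4, 1)

    def forward(self, tokens: torch.Tensor) -> dict:
        return {"loss": self.linear(tokens).mean()}

def toy_data_parallel(batch_group, model, cuda_devices):
    assert len(batch_group) <= len(cuda_devices)
    # One batch per device, moved to that device up front (no scattering).
    moved = [{key: tensor.cuda(device) for key, tensor in batch.items()}
             for batch, device in zip(batch_group, cuda_devices)]
    used_device_ids = cuda_devices[:len(moved)]
    replicas = replicate(model, used_device_ids)
    # Empty positional args; each replica receives its batch as keyword arguments.
    inputs = [()] * len(moved)
    outputs = parallel_apply(replicas, inputs, moved, used_device_ids)
    # Gather the per-GPU losses onto the first device and average them.
    losses = gather([output["loss"].unsqueeze(0) for output in outputs], used_device_ids[0], 0)
    return {"loss": losses.mean()}

# Example usage (>= 2 GPUs):
#   model = ToyModel().cuda(0)
#   batches = [{"tokens": torch.randn(8, 4)}, {"tokens": torch.randn(8, 4)}]
#   print(toy_data_parallel(batches, model, [0, 1]))
```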
9 changes: 5 additions & 4 deletions training_config/bidirectional_language_model.jsonnet
@@ -21,7 +21,7 @@ local BASE_READER = {
"type": "elmo_characters"
}
},
"max_sequence_length": 500,
"max_sequence_length": 400,
"start_tokens": ["<S>"],
"end_tokens": ["</S>"]
};
@@ -34,7 +34,7 @@
// samples in every batch.
"batch_size": 512 * NUM_GPUS,
"sorting_keys": [["source", "num_tokens"]],
"maximum_samples_per_batch": ["num_tokens", NUM_GPUS * 1000]
"maximum_samples_per_batch": ["num_tokens", 2000]
Contributor Author: There's a minor backwards compatibility issue here. We're effectively multiplying the batch size (for multi-GPU users) by the number of GPUs. In practice this will result in some OOMs for users that were running close to their memory limits. Given that we had an experimental warning for that use case I think this is okay, but I'm curious if you have other thoughts.

Contributor: This seems fine to me, too.

Contributor Author: Thanks.

};

{
@@ -117,7 +117,7 @@ local BASE_ITERATOR = {
// The multiprocess dataset reader and iterator use many file descriptors,
// so we need to increase the ulimit depending on the size of this queue.
// See https://pytorch.org/docs/stable/multiprocessing.html#file-descriptor-file-descriptor
// for a description of the underlying issue. `ulimit -n 4096` has sufficed,
// for a description of the underlying issue. `ulimit -n 8192` has sufficed,
// but that number could use tuning.
"output_queue_size": 500
},
@@ -139,6 +139,7 @@
// See https://github.com/allenai/calypso/blob/master/bin/train_transformer_lm1b.py#L51.
// Adjusted based on our sample size relative to Calypso's.
"warmup_steps": 6000
}
},
"should_log_learning_rate": true
}
}