This repository was archived by the owner on Dec 16, 2022. It is now read-only.

Commit f27475a

dirkgr and epwalsh authored
Enable multi-process training on CPU (#4272)
* Use torch.device everywhere
* Update changelog
* Run distributed tests even on CPU
* Fix bug when running distributed tests on CPU
* Remove unused imports
* Update CHANGELOG.md

Co-authored-by: Evan Pete Walsh <[email protected]>
Co-authored-by: Evan Pete Walsh <[email protected]>
1 parent 7e683dd commit f27475a

File tree: 10 files changed (+112, -71 lines)

CHANGELOG.md (+1)

@@ -17,6 +17,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ### Added
 
 - Additional CI checks to ensure docstrings are consistently formatted.
+- Ability to train on CPU with multiple processes by setting `cuda_devices` to a list of negative integers in your training config. For example: `"distributed": {"cuda_devices": [-1, -1]}`. This is mainly to make it easier to test and debug distributed training code.
 
 ### Changed
 
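
For illustration, a config fragment using the new option might look like the following Python dict (in the style of the test configs later in this commit); only the "distributed" block is the documented feature, the trainer settings are just placeholders taken from those tests.

    # Sketch of the relevant experiment-config fragment: two CPU worker processes.
    config_fragment = {
        "data_loader": {"batch_size": 2},
        "trainer": {"num_epochs": 2, "optimizer": "adam"},
        "distributed": {"cuda_devices": [-1, -1]},  # negative indices mean CPU workers
    }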

allennlp/commands/train.py (+15, -7)

@@ -402,13 +402,21 @@ def _train_worker(
         params["trainer"]["world_size"] = world_size
         params["trainer"]["distributed"] = True
 
-        torch.cuda.set_device(int(gpu_id))
-        dist.init_process_group(
-            backend="nccl",
-            init_method=f"tcp://{master_addr}:{master_port}",
-            world_size=world_size,
-            rank=global_rank,
-        )
+        if gpu_id >= 0:
+            torch.cuda.set_device(int(gpu_id))
+            dist.init_process_group(
+                backend="nccl",
+                init_method=f"tcp://{master_addr}:{master_port}",
+                world_size=world_size,
+                rank=global_rank,
+            )
+        else:
+            dist.init_process_group(
+                backend="gloo",
+                init_method=f"tcp://{master_addr}:{master_port}",
+                world_size=world_size,
+                rank=global_rank,
+            )
         logging.info(
             f"Process group of world size {world_size} initialized "
             f"for distributed training in worker {global_rank}"

allennlp/common/checks.py (+36, -29)

@@ -8,6 +8,7 @@
 import re
 import subprocess
 
+import torch
 from torch import cuda
 
 logger = logging.getLogger(__name__)
@@ -100,36 +101,42 @@ def from_list(strings):
     return int(cuda_device)  # type: ignore
 
 
-def check_for_gpu(device_id: Union[int, List[int]]):
-    if isinstance(device_id, list):
-        for did in device_id:
+def check_for_gpu(device: Union[int, torch.device, List[Union[int, torch.device]]]):
+    if isinstance(device, list):
+        for did in device:
             check_for_gpu(did)
-    elif device_id is not None and device_id >= 0:
-        num_devices_available = cuda.device_count()
-        if num_devices_available == 0:
-            # Torch will give a more informative exception than ours, so we want to include
-            # that context as well if it's available. For example, if you try to run torch 1.5
-            # on a machine with CUDA10.1 you'll get the following:
-            #
-            # The NVIDIA driver on your system is too old (found version 10010).
-            #
-            torch_gpu_error = ""
-            try:
-                cuda._check_driver()
-            except Exception as e:
-                torch_gpu_error = "\n{0}".format(e)
-
-            raise ConfigurationError(
-                "Experiment specified a GPU but none is available;"
-                " if you want to run on CPU use the override"
-                " 'trainer.cuda_device=-1' in the json config file." + torch_gpu_error
-            )
-        elif device_id >= num_devices_available:
-            raise ConfigurationError(
-                f"Experiment specified GPU device {device_id}"
-                f" but there are only {num_devices_available} devices "
-                f" available."
-            )
+    elif device is None:
+        return
+    else:
+        from allennlp.common.util import int_to_device
+
+        device = int_to_device(device)
+        if device != torch.device("cpu"):
+            num_devices_available = cuda.device_count()
+            if num_devices_available == 0:
+                # Torch will give a more informative exception than ours, so we want to include
+                # that context as well if it's available. For example, if you try to run torch 1.5
+                # on a machine with CUDA10.1 you'll get the following:
+                #
+                # The NVIDIA driver on your system is too old (found version 10010).
+                #
+                torch_gpu_error = ""
+                try:
+                    cuda._check_driver()
+                except Exception as e:
+                    torch_gpu_error = "\n{0}".format(e)
+
+                raise ConfigurationError(
+                    "Experiment specified a GPU but none is available;"
+                    " if you want to run on CPU use the override"
+                    " 'trainer.cuda_device=-1' in the json config file." + torch_gpu_error
+                )
+            elif device.index >= num_devices_available:
+                raise ConfigurationError(
+                    f"Experiment specified GPU device {device.index}"
+                    f" but there are only {num_devices_available} devices "
+                    f" available."
+                )
 
 
 def check_for_java() -> bool:
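
A quick illustration of what the rewritten check now accepts. The calls below all pass on a CPU-only machine; requesting a GPU index when none is available raises `ConfigurationError`, as before.

    import torch
    from allennlp.common.checks import check_for_gpu

    check_for_gpu(-1)                   # negative int means CPU: always passes
    check_for_gpu(torch.device("cpu"))  # torch.device objects are now accepted directly
    check_for_gpu([-1, -1])             # lists are validated element-wise
    # check_for_gpu(0) raises ConfigurationError on a machine with no CUDA device.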

allennlp/common/testing/__init__.py (+7)

@@ -38,3 +38,10 @@ def requires_multi_gpu(test_method):
             test_method
         )
     )
+
+
+def cpu_or_gpu(test_method):
+    """
+    Decorator to indicate that a test should run on both CPU and GPU
+    """
+    return pytest.mark.gpu(test_method)
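
A sketch of how the new marker might be used in a test module. Unlike `requires_gpu`, the decorator only applies the `gpu` pytest marker and adds no skip condition, so (per the commit title) the test also runs on CPU-only machines. The test body below is a placeholder mirroring the pattern used in the tests changed later in this commit.

    import torch
    from allennlp.common.testing import cpu_or_gpu

    @cpu_or_gpu
    def test_distributed_something():
        # Prefer two GPUs when present, otherwise fall back to two CPU processes.
        if torch.cuda.device_count() >= 2:
            devices = [0, 1]
        else:
            devices = [-1, -1]
        assert len(devices) == 2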

allennlp/common/util.py (+8)

@@ -427,6 +427,14 @@ def is_lazy(iterable: Iterable[A]) -> bool:
     return not isinstance(iterable, list)
 
 
+def int_to_device(device: Union[int, torch.device]) -> torch.device:
+    if isinstance(device, torch.device):
+        return device
+    if device < 0:
+        return torch.device("cpu")
+    return torch.device(device)
+
+
 def log_frozen_and_tunable_parameter_names(model: torch.nn.Module) -> None:
     frozen_parameter_names, tunable_parameter_names = get_frozen_and_tunable_parameter_names(model)
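
For reference, a few calls showing the mapping the new helper implements. This assumes PyTorch's usual convention that a bare non-negative integer denotes a CUDA index (constructing the device object does not require CUDA to be available).

    import torch
    from allennlp.common.util import int_to_device

    assert int_to_device(-1) == torch.device("cpu")                   # negative ints mean CPU
    assert int_to_device(0) == torch.device("cuda:0")                 # non-negative ints are CUDA indices
    assert int_to_device(torch.device("cpu")) == torch.device("cpu")  # devices pass through unchanged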

allennlp/nn/util.py (+5, -2)

@@ -34,13 +34,16 @@ def has_tensor(obj) -> bool:
     return False
 
 
-def move_to_device(obj, cuda_device: int):
+def move_to_device(obj, cuda_device: Union[torch.device, int]):
     """
     Given a structure (possibly) containing Tensors on the CPU,
     move all the Tensors to the specified GPU (or do nothing, if they should be on the CPU).
     """
+    from allennlp.common.util import int_to_device
 
-    if cuda_device < 0 or not has_tensor(obj):
+    cuda_device = int_to_device(cuda_device)
+
+    if cuda_device == torch.device("cpu") or not has_tensor(obj):
         return obj
     elif isinstance(obj, torch.Tensor):
         return obj.cuda(cuda_device)
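
A small usage sketch of the updated function: with a CPU target (or any negative int, which `int_to_device` maps to CPU) the structure is returned unchanged, which is what keeps the CPU distributed path cheap.

    import torch
    from allennlp.nn import util

    batch = {"tokens": torch.zeros(2, 3), "lengths": [3, 2]}
    same = util.move_to_device(batch, -1)                         # -1 -> cpu, returned as-is
    also_same = util.move_to_device(batch, torch.device("cpu"))   # torch.device works too
    assert same is batch and also_same is batch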

allennlp/training/trainer.py (+14, -19)

@@ -6,7 +6,9 @@
 import time
 import traceback
 from contextlib import contextmanager
-from typing import Any, Dict, Iterator, List, Optional, Tuple
+from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
+
+from allennlp.common.util import int_to_device
 
 try:
     from apex import amp
@@ -49,7 +51,7 @@ class Trainer(Registrable):
     def __init__(
         self,
         serialization_dir: str,
-        cuda_device: int = -1,
+        cuda_device: Union[int, torch.device] = -1,
         distributed: bool = False,
         local_rank: int = 0,
         world_size: int = 1,
@@ -65,28 +67,19 @@ def __init__(
                 "our Trainer always uses a single GPU per process."
             )
 
-        if not isinstance(cuda_device, int):
-            raise ConfigurationError("Expected an int for cuda_device, got {}".format(cuda_device))
-
         if distributed and world_size <= 1:
             raise ConfigurationError(
-                "Distributed training can be performed only with more than 1 GPU device. Check "
+                "Distributed training can be performed only with more than 1 device. Check "
                 "`cuda_device` key in the experiment configuration."
             )
 
-        self.cuda_device = cuda_device
+        self.cuda_device = int_to_device(cuda_device)
 
         self._distributed = distributed
         self._rank = local_rank
         self._master = self._rank == 0
         self._world_size = world_size
 
-    def _move_to_gpu(self, model: Model) -> Model:
-        if self.cuda_device != -1:
-            return model.cuda(self.cuda_device)
-        else:
-            return model
-
     def train(self) -> Dict[str, Any]:
         """
         Train a model and return the results.
@@ -383,7 +376,9 @@ def __init__(
         # these places: `model.__call__`, `model.train` and `model.eval`.
         if self._distributed:
             self._pytorch_model = DistributedDataParallel(
-                self.model, device_ids=[self.cuda_device], find_unused_parameters=True
+                self.model,
+                device_ids=None if self.cuda_device == torch.device("cpu") else [self.cuda_device],
+                find_unused_parameters=True,
             )
         else:
             self._pytorch_model = self.model
@@ -556,7 +551,7 @@ def _train_epoch(self, epoch: int) -> Dict[str, float]:
                 train_reg_loss,
                 batches_this_epoch,
                 world_size=self._world_size,
-                cuda_device=[self.cuda_device],
+                cuda_device=self.cuda_device,
             )
 
             if self._master:
@@ -600,7 +595,7 @@ def _train_epoch(self, epoch: int) -> Dict[str, float]:
             batches_this_epoch,
             reset=True,
             world_size=self._world_size,
-            cuda_device=[self.cuda_device],
+            cuda_device=self.cuda_device,
         )
         metrics["cpu_memory_MB"] = peak_cpu_usage
         for (gpu_num, memory) in gpu_usage:
@@ -672,7 +667,7 @@ def _validation_loss(self, epoch: int) -> Tuple[float, float, int]:
                     val_reg_loss,
                     batches_this_epoch,
                     world_size=self._world_size,
-                    cuda_device=[self.cuda_device],
+                    cuda_device=self.cuda_device,
                 )
                 description = training_util.description_from_metrics(val_metrics)
                 val_generator_tqdm.set_description(description, refresh=False)
@@ -693,7 +688,7 @@ def _validation_loss(self, epoch: int) -> Tuple[float, float, int]:
                 f"Worker {torch.distributed.get_rank()} completed its entire epoch (validation)."
             )
             # Indicate that we're done so that any workers that have remaining data stop validation early.
-            done = torch.tensor(1, device=self.cuda_device if self.cuda_device >= 0 else None)
+            done = torch.tensor(1, device=self.cuda_device)
             torch.distributed.all_reduce(done, torch.distributed.ReduceOp.SUM)
             assert done.item()
 
@@ -764,7 +759,7 @@ def train(self) -> Dict[str, Any]:
                 num_batches,
                 reset=True,
                 world_size=self._world_size,
-                cuda_device=[self.cuda_device],
+                cuda_device=self.cuda_device,
             )
 
             # Check validation metric for early stopping
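
The `DistributedDataParallel` change is the key piece for CPU workers: with a Gloo process group and `device_ids=None`, DDP wraps a CPU model directly. A minimal standalone sketch outside AllenNLP, using a world of size 1 for brevity (the address/port values are arbitrary placeholders):

    import os
    import torch
    import torch.distributed as dist
    from torch.nn.parallel import DistributedDataParallel

    # Single-process Gloo group, just to show the device_ids=None path on CPU.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group(backend="gloo", world_size=1, rank=0)

    model = torch.nn.Linear(4, 2)          # stays on CPU
    cuda_device = torch.device("cpu")
    ddp_model = DistributedDataParallel(
        model,
        device_ids=None if cuda_device == torch.device("cpu") else [cuda_device],
        find_unused_parameters=True,
    )
    dist.destroy_process_group()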

allennlp/training/util.py (+3, -6)

@@ -5,7 +5,7 @@
 import logging
 import os
 import shutil
-from typing import Any, Dict, Iterable, List, Optional, Union
+from typing import Any, Dict, Iterable, Optional, Union
 
 import torch
 import torch.distributed as dist
@@ -284,7 +284,7 @@ def get_metrics(
     num_batches: int,
     reset: bool = False,
     world_size: int = 1,
-    cuda_device: Union[int, List] = 0,
+    cuda_device: Union[int, torch.device] = torch.device("cpu"),
 ) -> Dict[str, float]:
     """
     Gets the metrics but sets `"loss"` to
@@ -299,10 +299,7 @@ def get_metrics(
         # In distributed mode, average out all metrics across GPUs
        aggregated_metrics = {}
         for metric_name, metric_val in metrics.items():
-            if isinstance(cuda_device, list):
-                metric_tensor = torch.tensor(metric_val).to(torch.device(cuda_device[0]))
-            else:
-                metric_tensor = torch.tensor(metric_val).to(torch.device(cuda_device))
+            metric_tensor = torch.tensor(metric_val).to(cuda_device)
             dist.all_reduce(metric_tensor, op=dist.ReduceOp.SUM)
             reduced_metric = metric_tensor.item() / world_size
             aggregated_metrics[metric_name] = reduced_metric
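
The simplification works because `torch.tensor(x).to(device)` accepts a `torch.device` directly. Averaging a scalar metric across workers then reduces to the pattern below (a standalone sketch that assumes a process group is already initialized, as it is inside the trainer):

    import torch
    import torch.distributed as dist

    def average_metric(metric_val: float, cuda_device: torch.device, world_size: int) -> float:
        # Same pattern as get_metrics above: sum across workers, then divide.
        metric_tensor = torch.tensor(metric_val).to(cuda_device)
        dist.all_reduce(metric_tensor, op=dist.ReduceOp.SUM)
        return metric_tensor.item() / world_size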

tests/commands/train_test.py (+22, -7)

@@ -15,7 +15,7 @@
 from allennlp.commands.train import Train, train_model, train_model_from_args, TrainModel
 from allennlp.common import Params
 from allennlp.common.checks import ConfigurationError
-from allennlp.common.testing import AllenNlpTestCase, requires_gpu, requires_multi_gpu
+from allennlp.common.testing import AllenNlpTestCase, cpu_or_gpu
 from allennlp.data import DatasetReader, Instance, Vocabulary
 from allennlp.data.dataloader import TensorDict
 from allennlp.models import load_archive, Model
@@ -111,8 +111,13 @@ def test_train_model(self):
             recover=True,
         )
 
-    @requires_gpu
+    @cpu_or_gpu
     def test_train_model_distributed(self):
+        if torch.cuda.device_count() >= 2:
+            devices = [0, 1]
+        else:
+            devices = [-1, -1]
+
         params = lambda: Params(
             {
                 "model": {
@@ -127,7 +132,7 @@ def test_train_model_distributed(self):
                 "validation_data_path": SEQUENCE_TAGGING_DATA_PATH,
                 "data_loader": {"batch_size": 2},
                 "trainer": {"num_epochs": 2, "optimizer": "adam"},
-                "distributed": {"cuda_devices": [0, 1]},
+                "distributed": {"cuda_devices": devices},
             }
         )
 
@@ -146,9 +151,14 @@ def test_train_model_distributed(self):
         # Check we can load the serialized model
         assert load_archive(out_dir).model
 
-    @requires_multi_gpu
+    @cpu_or_gpu
     @pytest.mark.parametrize("lazy", [True, False])
     def test_train_model_distributed_with_sharded_reader(self, lazy):
+        if torch.cuda.device_count() >= 2:
+            devices = [0, 1]
+        else:
+            devices = [-1, -1]
+
         params = lambda: Params(
             {
                 "model": {
@@ -167,7 +177,7 @@ def test_train_model_distributed_with_sharded_reader(self, lazy):
                 "validation_data_path": SEQUENCE_TAGGING_SHARDS_PATH,
                 "data_loader": {"batch_size": 2},
                 "trainer": {"num_epochs": 2, "optimizer": "adam"},
-                "distributed": {"cuda_devices": [0, 1]},
+                "distributed": {"cuda_devices": devices},
             }
         )
 
@@ -232,9 +242,14 @@ def test_train_model_distributed_with_sharded_reader(self, lazy):
         assert train_complete in worker1_log
         assert validation_complete in worker1_log
 
-    @requires_multi_gpu
+    @cpu_or_gpu
     @pytest.mark.parametrize("lazy", [True, False])
     def test_train_model_distributed_without_sharded_reader(self, lazy: bool):
+        if torch.cuda.device_count() >= 2:
+            devices = [0, 1]
+        else:
+            devices = [-1, -1]
+
         num_epochs = 2
         params = lambda: Params(
             {
@@ -256,7 +271,7 @@ def test_train_model_distributed_without_sharded_reader(self, lazy: bool):
                         "tests.commands.train_test.TrainingDataLoggerBatchCallback"
                     ],
                 },
-                "distributed": {"cuda_devices": [0, 1]},
+                "distributed": {"cuda_devices": devices},
             }
         )
 
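
The same four-line GPU/CPU fallback appears in all three tests. If one wanted to factor it out, a tiny helper (hypothetical, not part of this commit) could look like:

    import torch

    def distributed_test_devices(num_workers: int = 2):
        # Use the first GPUs when enough are available, otherwise fall back to
        # CPU worker processes via the new negative-index convention.
        if torch.cuda.device_count() >= num_workers:
            return list(range(num_workers))
        return [-1] * num_workers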

tests/nn/util_test.py (+1, -1)

@@ -1422,7 +1422,7 @@ class A(NamedTuple):
             "b": FakeTensor(),
             "c": (1, FakeTensor()),
         }
-        new_device = 4
+        new_device = torch.device(4)
         moved_obj = util.move_to_device(structured_obj, new_device)
         assert moved_obj["a"][0].a == 1
         assert moved_obj["a"][0].b._device == new_device
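
The test now constructs the expected device explicitly. In the PyTorch versions this commit targets, a bare integer passed to `torch.device` is shorthand for a CUDA index, so the assertion compares against a real `torch.device` rather than an int:

    import torch

    assert torch.device(4) == torch.device("cuda", 4) == torch.device("cuda:4")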
