Skip to content
This repository was archived by the owner on Dec 16, 2022. It is now read-only.

Commit 07b5749

Browse files
gurunathparasaram authored and matt-gardner committed
Enable multi-gpu training in find_learning_rate.py (#2045)
* Enable multi-gpu training in find_learning_rate.py
* Added test for multi-gpu training
* Minor changes in tests
* Change in indentation
* Changes in testing multi-gpu usage: removed the new class for testing multi-gpu usage and moved the testing function to the existing class `TestFindLearningRate`
* Changes in find_learning_rate_test.py
* Minor changes in find_learning_rate_test.py
* Removed redundant code
1 parent 43243ac commit 07b5749

File tree

2 files changed

+35
-11
lines changed

2 files changed

+35
-11
lines changed

allennlp/commands/find_learning_rate.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,12 @@ def find_learning_rate_model(params: Params, serialization_dir: str,
170170

171171
prepare_environment(params)
172172

173-
check_for_gpu(params.get('trainer').get('cuda_device', -1))
173+
cuda_device = params.params.get('trainer').get('cuda_device', -1)
174+
if isinstance(cuda_device, list):
175+
for device in cuda_device:
176+
check_for_gpu(device)
177+
else:
178+
check_for_gpu(cuda_device)
174179

175180
all_datasets = datasets_from_params(params)
176181
datasets_for_vocab_creation = set(params.pop("datasets_for_vocab_creation", all_datasets))

allennlp/tests/commands/find_learning_rate_test.py

+29-10
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
import os
44
import pytest
55

6+
import torch
7+
68
from allennlp.common import Params
79
from allennlp.data import Vocabulary, DataIterator
810
from allennlp.models import Model
@@ -12,6 +14,7 @@
1214
from allennlp.commands.find_learning_rate import search_learning_rate, \
1315
find_learning_rate_from_args, find_learning_rate_model, FindLearningRate
1416

17+
1518
class TestFindLearningRate(AllenNlpTestCase):
1619

1720
def setUp(self):
@@ -44,7 +47,8 @@ def setUp(self):
4447
})
4548

4649
def test_find_learning_rate(self):
47-
find_learning_rate_model(self.params(), os.path.join(self.TEST_DIR, 'test_find_learning_rate'),
50+
find_learning_rate_model(self.params(),
51+
os.path.join(self.TEST_DIR, 'test_find_learning_rate'),
4852
start_lr=1e-5,
4953
end_lr=1,
5054
num_batches=100,
@@ -89,7 +93,6 @@ def test_find_learning_rate(self):
8993
stopping_factor=None,
9094
force=True)
9195

92-
9396
def test_find_learning_rate_args(self):
9497
parser = argparse.ArgumentParser(description="Testing")
9598
subparsers = parser.add_subparsers(title='Commands', metavar='')
@@ -115,6 +118,21 @@ def test_find_learning_rate_args(self):
115118
assert cm.exception.code == 2 # argparse code for incorrect usage
116119

117120

121+
@pytest.mark.skipif(torch.cuda.device_count() < 2,
122+
reason="Need multiple GPUs.")
123+
def test_find_learning_rate_multi_gpu(self):
124+
params = self.params()
125+
params["trainer"]["cuda_device"] = [0, 1]
126+
find_learning_rate_model(params,
127+
os.path.join(self.TEST_DIR, 'test_find_learning_rate_multi_gpu'),
128+
start_lr=1e-5,
129+
end_lr=1,
130+
num_batches=100,
131+
linear_steps=True,
132+
stopping_factor=None,
133+
force=False)
134+
135+
118136
class TestSearchLearningRate(AllenNlpTestCase):
119137

120138
def setUp(self):
@@ -144,7 +162,7 @@ def setUp(self):
144162
"num_epochs": 2,
145163
"optimizer": "adam"
146164
}
147-
})
165+
})
148166
all_datasets = datasets_from_params(params)
149167
vocab = Vocabulary.from_params(
150168
params.pop("vocabulary", {}),
@@ -159,12 +177,12 @@ def setUp(self):
159177
serialization_dir = os.path.join(self.TEST_DIR, 'test_search_learning_rate')
160178

161179
self.trainer = Trainer.from_params(model,
162-
serialization_dir,
163-
iterator,
164-
train_data,
165-
params=trainer_params,
166-
validation_data=None,
167-
validation_iterator=None)
180+
serialization_dir,
181+
iterator,
182+
train_data,
183+
params=trainer_params,
184+
validation_data=None,
185+
validation_iterator=None)
168186

169187
def test_search_learning_rate_with_num_batches_less_than_ten(self):
170188
with pytest.raises(ConfigurationError):
@@ -175,6 +193,7 @@ def test_search_learning_rate_linear_steps(self):
175193
assert len(learning_rates_losses) > 1
176194

177195
def test_search_learning_rate_without_stopping_factor(self):
178-
learning_rates, losses = search_learning_rate(self.trainer, num_batches=100, stopping_factor=None)
196+
learning_rates, losses = search_learning_rate(self.trainer, num_batches=100,
197+
stopping_factor=None)
179198
assert len(learning_rates) == 101
180199
assert len(losses) == 101

0 commit comments

Comments (0)