Skip to content
This repository was archived by the owner on Dec 16, 2022. It is now read-only.

Commit 44d2847

Browse files
AkshitaBdirkgr
andauthored
Metrics in distributed setting (#4525)
* initial commit to ensure that metrics work correctly in distributed setting * updating global_distributed_metric to take metric object * adding distributed f1 score * adding distributed attachment scores * bug fix * adding distributed boolean accuracy * adding distributed entropy * adding distributed evalb * adding distributed mean_absolute_error * adding distributed sequence accuracy * adding distributed unigram recall * making models compatible with distributed metrics * adding distributed auc * adding distributed bleu * adding missing argument * initial commit to ensure that metrics work correctly in distributed setting * updating global_distributed_metric to take metric object * adding distributed f1 score * adding distributed attachment scores * bug fix * adding distributed boolean accuracy * adding distributed entropy * adding distributed evalb * adding distributed mean_absolute_error * adding distributed sequence accuracy * adding distributed unigram recall * making models compatible with distributed metrics * adding distributed auc * adding distributed bleu * adding missing argument * changing start method * removing unnecessary argument * adding remaining metrics, removing extra argument * allowing float values * bug fix * more bug fixes * changing average to return float * adding timeout for distributed test * testing unequal batches * adding distributed auc * adding distributed spearman correlation * adding distributed covariance and pearson correlation * changing distributed test to function, misc changes * checking batch lengths explicitly to raise errors Co-authored-by: Dirk Groeneveld <[email protected]>
1 parent 1d61965 commit 44d2847

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

41 files changed

+1116
-77
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
2929
- Fixed testing models that only return a loss when they are in training mode.
3030
- Fixed a bug in `FromParams` that caused silent failure in case of the parameter type being `Optional[Union[...]]`.
3131
- Fixed a bug where the program crashes if `evaluation_data_loader` is a `AllennlpLazyDataset`.
32+
- Fixed evaluation of all metrics when using distributed training.
3233

3334
### Added
3435

allennlp/common/testing/__init__.py

+55
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,16 @@
11
"""
22
Utilities and helpers for writing tests.
33
"""
4+
from typing import Dict, Any, Optional, Union, Tuple, List
45
import torch
6+
from torch.testing import assert_allclose
57
import pytest
68

79
from allennlp.common.testing.test_case import AllenNlpTestCase
810
from allennlp.common.testing.model_test_case import ModelTestCase
11+
from allennlp.common.testing.distributed_test import run_distributed_test
12+
13+
from allennlp.training.metrics import Metric
914

1015

1116
_available_devices = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
@@ -45,3 +50,53 @@ def cpu_or_gpu(test_method):
4550
Decorator to indicate that a test should run on both CPU and GPU
4651
"""
4752
return pytest.mark.gpu(test_method)
53+
54+
55+
# Helpers for testing distributed metrics
56+
57+
58+
def assert_metrics_values(
59+
metrics: Dict[str, Any],
60+
desired_values: Dict[str, Any],
61+
rtol: float = 0.0001,
62+
atol: float = 1e-05,
63+
):
64+
for key in metrics:
65+
assert_allclose(metrics[key], desired_values[key], rtol=rtol, atol=atol)
66+
67+
68+
def global_distributed_metric(
69+
global_rank: int,
70+
world_size: int,
71+
gpu_id: Union[int, torch.device],
72+
metric: Metric,
73+
metric_kwargs: Dict[str, List[Any]],
74+
desired_values: Dict[str, Any],
75+
exact: Union[bool, Tuple[float, float]] = True,
76+
):
77+
kwargs = {}
78+
79+
# Use the arguments meant for the process with rank `global_rank`.
80+
for argname in metric_kwargs:
81+
kwargs[argname] = metric_kwargs[argname][global_rank]
82+
83+
metric(**kwargs)
84+
85+
metrics = metric.get_metric(False)
86+
if not isinstance(metrics, Dict) and not isinstance(desired_values, Dict):
87+
metrics = {"metric_value": metrics}
88+
desired_values = {"metric_value": desired_values}
89+
90+
# Call `assertion_metrics_values` to check if the metrics have the desired values.
91+
if isinstance(exact, bool):
92+
if exact:
93+
rtol = 0.0
94+
atol = 0.0
95+
else:
96+
rtol = 0.0001
97+
atol = 1e-05
98+
else:
99+
rtol = exact[0]
100+
atol = exact[1]
101+
102+
assert_metrics_values(metrics, desired_values, rtol, atol) # type: ignore
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
import datetime
2+
from typing import List, Dict, Any, Tuple, Callable
3+
import torch
4+
import torch.distributed as dist
5+
import torch.multiprocessing as mp
6+
7+
from allennlp.common.checks import check_for_gpu
8+
9+
10+
def init_process(
11+
process_rank: int,
12+
distributed_device_ids: List[int] = None,
13+
world_size: int = 1,
14+
func: Callable = None,
15+
func_args: Tuple = None,
16+
func_kwargs: Dict[str, Any] = None,
17+
master_addr: str = "127.0.0.1",
18+
master_port: int = 29500,
19+
):
20+
assert world_size > 1
21+
22+
global_rank = process_rank
23+
24+
gpu_id = distributed_device_ids[process_rank] # type: ignore
25+
26+
if gpu_id >= 0:
27+
torch.cuda.set_device(int(gpu_id))
28+
dist.init_process_group(
29+
backend="nccl",
30+
init_method=f"tcp://{master_addr}:{master_port}",
31+
world_size=world_size,
32+
rank=global_rank,
33+
)
34+
else:
35+
dist.init_process_group(
36+
backend="gloo",
37+
init_method=f"tcp://{master_addr}:{master_port}",
38+
world_size=world_size,
39+
rank=global_rank,
40+
timeout=datetime.timedelta(seconds=120),
41+
)
42+
43+
func(global_rank, world_size, gpu_id, *func_args, **func_kwargs)
44+
45+
dist.barrier()
46+
47+
48+
def run_distributed_test(
49+
device_ids: List[int] = [-1, -1], func: Callable = None, *args, **kwargs,
50+
):
51+
"""
52+
This runs the `func` in a simulated distributed environment.
53+
54+
# Parameters
55+
56+
device_ids: `List[int]`
57+
List of devices. There need to be at least 2 devices. Default is [-1, -1].
58+
59+
func: `Callable`
60+
`func` needs to be global for spawning the processes, so that it can be pickled.
61+
"""
62+
63+
check_for_gpu(device_ids)
64+
nprocs = world_size = len(device_ids)
65+
mp.start_processes(
66+
init_process,
67+
args=(device_ids, world_size, func, args, kwargs),
68+
nprocs=nprocs,
69+
start_method="fork",
70+
)

allennlp/models/simple_tagger.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,7 @@ def get_metrics(self, reset: bool = False) -> Dict[str, float]:
213213
}
214214

215215
if self.calculate_span_f1:
216-
f1_dict = self._f1_metric.get_metric(reset=reset)
216+
f1_dict = self._f1_metric.get_metric(reset)
217217
if self._verbose_metrics:
218218
metrics_to_return.update(f1_dict)
219219
else:

allennlp/training/metrics/attachment_scores.py

+24-3
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
1-
from typing import Optional, List
1+
from typing import Optional, List, Union
22

33
from overrides import overrides
44
import torch
5+
import torch.distributed as dist
56

7+
from allennlp.common.util import is_distributed
68
from allennlp.training.metrics.metric import Metric
79

810

@@ -57,6 +59,7 @@ def __call__( # type: ignore
5759
predicted_indices, predicted_labels, gold_indices, gold_labels, mask
5860
)
5961
predicted_indices, predicted_labels, gold_indices, gold_labels, mask = detached
62+
device = predicted_indices.device
6063

6164
if mask is None:
6265
mask = torch.ones_like(predicted_indices).bool()
@@ -78,14 +81,30 @@ def __call__( # type: ignore
7881
correct_labels_and_indices = correct_indices * correct_labels
7982
labeled_exact_match = (correct_labels_and_indices + ~mask).prod(dim=-1)
8083

84+
if is_distributed():
85+
dist.all_reduce(correct_indices, op=dist.ReduceOp.SUM)
86+
dist.all_reduce(unlabeled_exact_match, op=dist.ReduceOp.SUM)
87+
dist.all_reduce(correct_labels_and_indices, op=dist.ReduceOp.SUM)
88+
dist.all_reduce(labeled_exact_match, op=dist.ReduceOp.SUM)
89+
8190
self._unlabeled_correct += correct_indices.sum()
8291
self._exact_unlabeled_correct += unlabeled_exact_match.sum()
8392
self._labeled_correct += correct_labels_and_indices.sum()
8493
self._exact_labeled_correct += labeled_exact_match.sum()
8594
self._total_sentences += correct_indices.size(0)
8695
self._total_words += correct_indices.numel() - (~mask).sum()
8796

88-
def get_metric(self, reset: bool = False):
97+
if is_distributed():
98+
_total_sentences = torch.tensor(self._total_sentences).to(device)
99+
_total_words = torch.tensor(self._total_words).to(device)
100+
dist.all_reduce(_total_sentences, op=dist.ReduceOp.SUM)
101+
dist.all_reduce(_total_words, op=dist.ReduceOp.SUM)
102+
self._total_sentences = _total_sentences.item()
103+
self._total_words = _total_words.item()
104+
105+
def get_metric(
106+
self, reset: bool = False, cuda_device: Union[int, torch.device] = torch.device("cpu"),
107+
):
89108
"""
90109
# Returns
91110
@@ -95,6 +114,7 @@ def get_metric(self, reset: bool = False):
95114
labeled_attachment_score = 0.0
96115
unlabeled_exact_match = 0.0
97116
labeled_exact_match = 0.0
117+
98118
if self._total_words > 0.0:
99119
unlabeled_attachment_score = float(self._unlabeled_correct) / float(self._total_words)
100120
labeled_attachment_score = float(self._labeled_correct) / float(self._total_words)
@@ -105,12 +125,13 @@ def get_metric(self, reset: bool = False):
105125
labeled_exact_match = float(self._exact_labeled_correct) / float(self._total_sentences)
106126
if reset:
107127
self.reset()
108-
return {
128+
metrics = {
109129
"UAS": unlabeled_attachment_score,
110130
"LAS": labeled_attachment_score,
111131
"UEM": unlabeled_exact_match,
112132
"LEM": labeled_exact_match,
113133
}
134+
return metrics
114135

115136
@overrides
116137
def reset(self):

allennlp/training/metrics/auc.py

+34
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,10 @@
22

33
from overrides import overrides
44
import torch
5+
import torch.distributed as dist
56
from sklearn import metrics
67

8+
from allennlp.common.util import is_distributed
79
from allennlp.common.checks import ConfigurationError
810
from allennlp.training.metrics.metric import Metric
911

@@ -82,7 +84,38 @@ def __call__(
8284
[self._all_gold_labels, torch.masked_select(gold_labels, mask).long()], dim=0
8385
)
8486

87+
if is_distributed():
88+
world_size = dist.get_world_size()
89+
device = gold_labels.device
90+
91+
# Check if batch lengths are equal.
92+
_all_batch_lengths = [torch.tensor(0) for i in range(world_size)]
93+
dist.all_gather(
94+
_all_batch_lengths, torch.tensor(len(self._all_predictions), device=device)
95+
)
96+
_all_batch_lengths = [batch_length.item() for batch_length in _all_batch_lengths]
97+
98+
if len(set(_all_batch_lengths)) > 1:
99+
# Subsequent dist.all_gather() calls currently do not handle tensors of different length.
100+
raise RuntimeError(
101+
"Distributed aggregation for AUC is currently not supported for batches of unequal length."
102+
)
103+
104+
_all_predictions = [
105+
torch.zeros(self._all_predictions.shape, device=device) for i in range(world_size)
106+
]
107+
108+
_all_gold_labels = [
109+
torch.zeros(self._all_gold_labels.shape, device=device, dtype=torch.long)
110+
for i in range(world_size)
111+
]
112+
dist.all_gather(_all_predictions, self._all_predictions)
113+
dist.all_gather(_all_gold_labels, self._all_gold_labels)
114+
self._all_predictions = torch.cat(_all_predictions, dim=0)
115+
self._all_gold_labels = torch.cat(_all_gold_labels, dim=0)
116+
85117
def get_metric(self, reset: bool = False):
118+
86119
if self._all_gold_labels.shape[0] == 0:
87120
return 0.5
88121
false_positive_rates, true_positive_rates, _ = metrics.roc_curve(
@@ -91,6 +124,7 @@ def get_metric(self, reset: bool = False):
91124
pos_label=self._positive_label,
92125
)
93126
auc = metrics.auc(false_positive_rates, true_positive_rates)
127+
94128
if reset:
95129
self.reset()
96130
return auc

allennlp/training/metrics/average.py

+13
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
from overrides import overrides
22

3+
import torch
4+
import torch.distributed as dist
5+
6+
from allennlp.common.util import is_distributed
37
from allennlp.training.metrics.metric import Metric
48

59

@@ -26,6 +30,14 @@ def __call__(self, value):
2630
"""
2731
self._total_value += list(self.detach_tensors(value))[0]
2832
self._count += 1
33+
if is_distributed():
34+
device = torch.device("cpu")
35+
_count = torch.tensor(self._count).to(device)
36+
_total_value = torch.tensor(self._total_value).to(device)
37+
dist.all_reduce(_count, op=dist.ReduceOp.SUM)
38+
dist.all_reduce(_total_value, op=dist.ReduceOp.SUM)
39+
self._count = _count.item()
40+
self._total_value = _total_value.item()
2941

3042
@overrides
3143
def get_metric(self, reset: bool = False):
@@ -34,6 +46,7 @@ def get_metric(self, reset: bool = False):
3446
3547
The average of all values that were passed to `__call__`.
3648
"""
49+
3750
average_value = self._total_value / self._count if self._count > 0 else 0
3851
if reset:
3952
self.reset()

allennlp/training/metrics/bleu.py

+23
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@
44

55
from overrides import overrides
66
import torch
7+
import torch.distributed as dist
78

9+
from allennlp.common.util import is_distributed
810
from allennlp.training.metrics.metric import Metric
911

1012

@@ -116,10 +118,21 @@ def __call__(
116118
None
117119
"""
118120
predictions, gold_targets = self.detach_tensors(predictions, gold_targets)
121+
device = gold_targets.device
122+
if is_distributed():
123+
world_size = dist.get_world_size()
124+
119125
for ngram_size, _ in enumerate(self._ngram_weights, start=1):
120126
precision_matches, precision_totals = self._get_modified_precision_counts(
121127
predictions, gold_targets, ngram_size
122128
)
129+
if is_distributed():
130+
_precision_matches = torch.tensor(precision_matches).to(device)
131+
_precision_totals = torch.tensor(precision_totals).to(device)
132+
dist.all_reduce(_precision_matches, op=dist.ReduceOp.SUM)
133+
dist.all_reduce(_precision_totals, op=dist.ReduceOp.SUM)
134+
precision_matches = _precision_matches.item() / world_size
135+
precision_totals = _precision_totals.item() / world_size
123136
self._precision_matches[ngram_size] += precision_matches
124137
self._precision_totals[ngram_size] += precision_totals
125138
if not self._exclude_indices:
@@ -133,8 +146,17 @@ def __call__(
133146
valid_gold_targets_mask = get_valid_tokens_mask(gold_targets, self._exclude_indices)
134147
self._reference_lengths += valid_gold_targets_mask.sum().item()
135148

149+
if is_distributed():
150+
_prediction_lengths = torch.tensor(self._prediction_lengths).to(device)
151+
_reference_lengths = torch.tensor(self._reference_lengths).to(device)
152+
dist.all_reduce(_prediction_lengths, op=dist.ReduceOp.SUM)
153+
dist.all_reduce(_reference_lengths, op=dist.ReduceOp.SUM)
154+
self._prediction_lengths = _prediction_lengths.item()
155+
self._reference_lengths = _reference_lengths.item()
156+
136157
@overrides
137158
def get_metric(self, reset: bool = False) -> Dict[str, float]:
159+
138160
brevity_penalty = self._get_brevity_penalty()
139161
ngram_scores = (
140162
weight
@@ -145,6 +167,7 @@ def get_metric(self, reset: bool = False) -> Dict[str, float]:
145167
for n, weight in enumerate(self._ngram_weights, start=1)
146168
)
147169
bleu = brevity_penalty * math.exp(sum(ngram_scores))
170+
148171
if reset:
149172
self.reset()
150173
return {"BLEU": bleu}

0 commit comments

Comments
 (0)