1
- from typing import Optional , List
1
+ from typing import Optional , List , Union
2
2
3
3
from overrides import overrides
4
4
import torch
5
+ import torch .distributed as dist
5
6
7
+ from allennlp .common .util import is_distributed
6
8
from allennlp .training .metrics .metric import Metric
7
9
8
10
@@ -57,6 +59,7 @@ def __call__( # type: ignore
57
59
predicted_indices , predicted_labels , gold_indices , gold_labels , mask
58
60
)
59
61
predicted_indices , predicted_labels , gold_indices , gold_labels , mask = detached
62
+ device = predicted_indices .device
60
63
61
64
if mask is None :
62
65
mask = torch .ones_like (predicted_indices ).bool ()
@@ -78,14 +81,30 @@ def __call__( # type: ignore
78
81
correct_labels_and_indices = correct_indices * correct_labels
79
82
labeled_exact_match = (correct_labels_and_indices + ~ mask ).prod (dim = - 1 )
80
83
84
+ if is_distributed ():
85
+ dist .all_reduce (correct_indices , op = dist .ReduceOp .SUM )
86
+ dist .all_reduce (unlabeled_exact_match , op = dist .ReduceOp .SUM )
87
+ dist .all_reduce (correct_labels_and_indices , op = dist .ReduceOp .SUM )
88
+ dist .all_reduce (labeled_exact_match , op = dist .ReduceOp .SUM )
89
+
81
90
self ._unlabeled_correct += correct_indices .sum ()
82
91
self ._exact_unlabeled_correct += unlabeled_exact_match .sum ()
83
92
self ._labeled_correct += correct_labels_and_indices .sum ()
84
93
self ._exact_labeled_correct += labeled_exact_match .sum ()
85
94
self ._total_sentences += correct_indices .size (0 )
86
95
self ._total_words += correct_indices .numel () - (~ mask ).sum ()
87
96
88
- def get_metric (self , reset : bool = False ):
97
+ if is_distributed ():
98
+ _total_sentences = torch .tensor (self ._total_sentences ).to (device )
99
+ _total_words = torch .tensor (self ._total_words ).to (device )
100
+ dist .all_reduce (_total_sentences , op = dist .ReduceOp .SUM )
101
+ dist .all_reduce (_total_words , op = dist .ReduceOp .SUM )
102
+ self ._total_sentences = _total_sentences .item ()
103
+ self ._total_words = _total_words .item ()
104
+
105
+ def get_metric (
106
+ self , reset : bool = False , cuda_device : Union [int , torch .device ] = torch .device ("cpu" ),
107
+ ):
89
108
"""
90
109
# Returns
91
110
@@ -95,6 +114,7 @@ def get_metric(self, reset: bool = False):
95
114
labeled_attachment_score = 0.0
96
115
unlabeled_exact_match = 0.0
97
116
labeled_exact_match = 0.0
117
+
98
118
if self ._total_words > 0.0 :
99
119
unlabeled_attachment_score = float (self ._unlabeled_correct ) / float (self ._total_words )
100
120
labeled_attachment_score = float (self ._labeled_correct ) / float (self ._total_words )
@@ -105,12 +125,13 @@ def get_metric(self, reset: bool = False):
105
125
labeled_exact_match = float (self ._exact_labeled_correct ) / float (self ._total_sentences )
106
126
if reset :
107
127
self .reset ()
108
- return {
128
+ metrics = {
109
129
"UAS" : unlabeled_attachment_score ,
110
130
"LAS" : labeled_attachment_score ,
111
131
"UEM" : unlabeled_exact_match ,
112
132
"LEM" : labeled_exact_match ,
113
133
}
134
+ return metrics
114
135
115
136
@overrides
116
137
def reset (self ):
0 commit comments