@@ -81,6 +81,11 @@ def add_subparser(self, name: str, parser: argparse._SubParsersAction) -> argpar
81
81
default = "" ,
82
82
help = 'a JSON structure used to override the experiment configuration' )
83
83
84
+ subparser .add_argument ('--batch-weight-key' ,
85
+ type = str ,
86
+ default = "" ,
87
+ help = 'If non-empty, name of metric used to weight the loss on a per-batch basis.' )
88
+
84
89
subparser .set_defaults (func = evaluate_from_args )
85
90
86
91
return subparser
@@ -89,7 +94,8 @@ def add_subparser(self, name: str, parser: argparse._SubParsersAction) -> argpar
89
94
def evaluate (model : Model ,
90
95
instances : Iterable [Instance ],
91
96
data_iterator : DataIterator ,
92
- cuda_device : int ) -> Dict [str , Any ]:
97
+ cuda_device : int ,
98
+ batch_weight_key : str ) -> Dict [str , Any ]:
93
99
_warned_tqdm_ignores_underscores = False
94
100
check_for_gpu (cuda_device )
95
101
with torch .no_grad ():
@@ -101,21 +107,34 @@ def evaluate(model: Model,
101
107
logger .info ("Iterating over dataset" )
102
108
generator_tqdm = Tqdm .tqdm (iterator , total = data_iterator .get_num_batches (instances ))
103
109
110
+ # Number of batches in instances.
104
111
batch_count = 0
112
+ # Number of batches where the model produces a loss.
105
113
loss_count = 0
114
+ # Cumulative weighted loss
106
115
total_loss = 0.0
116
+ # Cumulative weight across all batches.
117
+ total_weight = 0.0
107
118
108
119
for batch in generator_tqdm :
109
120
batch_count += 1
110
121
batch = util .move_to_device (batch , cuda_device )
111
- loss = model (** batch ).get ("loss" )
122
+ output_dict = model (** batch )
123
+ loss = output_dict .get ("loss" )
112
124
113
125
metrics = model .get_metrics ()
114
126
115
127
if loss is not None :
116
128
loss_count += 1
117
- metrics ["loss" ] = loss .item ()
118
- total_loss += loss .item ()
129
+ if batch_weight_key :
130
+ weight = output_dict [batch_weight_key ].item ()
131
+ else :
132
+ weight = 1.0
133
+
134
+ total_weight += weight
135
+ total_loss += loss .item () * weight
136
+ # Report the average loss so far.
137
+ metrics ["loss" ] = total_loss / total_weight
119
138
120
139
if (not _warned_tqdm_ignores_underscores and
121
140
any (metric_name .startswith ("_" ) for metric_name in metrics )):
@@ -128,10 +147,11 @@ def evaluate(model: Model,
128
147
129
148
final_metrics = model .get_metrics (reset = True )
130
149
if loss_count > 0 :
150
+ # Sanity check
131
151
if loss_count != batch_count :
132
152
raise RuntimeError ("The model you are trying to evaluate only sometimes " +
133
153
"produced a loss!" )
134
- final_metrics ["loss" ] = total_loss / batch_count
154
+ final_metrics ["loss" ] = total_loss / total_weight
135
155
136
156
return final_metrics
137
157
@@ -168,7 +188,7 @@ def evaluate_from_args(args: argparse.Namespace) -> Dict[str, Any]:
168
188
iterator = DataIterator .from_params (iterator_params )
169
189
iterator .index_with (model .vocab )
170
190
171
- metrics = evaluate (model , instances , iterator , args .cuda_device )
191
+ metrics = evaluate (model , instances , iterator , args .cuda_device , args . batch_weight_key )
172
192
173
193
logger .info ("Finished evaluating." )
174
194
logger .info ("Metrics:" )
0 commit comments