Skip to content
This repository was archived by the owner on Dec 16, 2022. It is now read-only.

Commit 6c7c807

Browse files
hzeng-otteraimatt-gardner
authored andcommitted
Adding decoder to bimpm and improve demo server. (#1665)
* Adding decoder to bimpm and add model weights and overrides to flask server. * Refine comments for pylint.
1 parent 9caac66 commit 6c7c807

File tree

3 files changed

+26
-2
lines changed

3 files changed

+26
-2
lines changed

allennlp/models/bimpm.py

+15-1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from overrides import overrides
88
import torch
9+
import numpy
910

1011
from allennlp.common.checks import check_dimensions_match
1112
from allennlp.data import Vocabulary
@@ -185,8 +186,9 @@ def add_matching_result(matcher, encoded_premise, encoded_hypothesis):
185186

186187
# the final forward layer
187188
logits = self.classifier_feedforward(torch.cat([aggregated_premise, aggregated_hypothesis], dim=-1))
189+
probs = torch.nn.functional.softmax(logits, dim=-1)
188190

189-
output_dict = {'logits': logits}
191+
output_dict = {'logits': logits, "probs": probs}
190192
if label is not None:
191193
loss = self.loss(logits, label)
192194
for metric in self.metrics.values():
@@ -195,6 +197,18 @@ def add_matching_result(matcher, encoded_premise, encoded_hypothesis):
195197

196198
return output_dict
197199

200+
@overrides
201+
def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
202+
"""
203+
Converts indices to string labels, and adds a ``"label"`` key to the result.
204+
"""
205+
predictions = output_dict["probs"].cpu().data.numpy()
206+
argmax_indices = numpy.argmax(predictions, axis=-1)
207+
labels = [self.vocab.get_token_from_index(x, namespace="labels")
208+
for x in argmax_indices]
209+
output_dict['label'] = labels
210+
return output_dict
211+
198212
@overrides
199213
def get_metrics(self, reset: bool = False) -> Dict[str, float]:
200214
return {metric_name: metric.get_metric(reset) for metric_name, metric in self.metrics.items()}

allennlp/service/server_simple.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,10 @@ def main(args):
133133

134134
parser.add_argument('--archive-path', type=str, required=True, help='path to trained archive file')
135135
parser.add_argument('--predictor', type=str, required=True, help='name of predictor')
136+
parser.add_argument('--weights-file', type=str,
137+
help='a path that overrides which weights file to use')
138+
parser.add_argument('-o', '--overrides', type=str, default="",
139+
help='a JSON structure used to override the experiment configuration')
136140
parser.add_argument('--static-dir', type=str, help='serve index.html from this directory')
137141
parser.add_argument('--title', type=str, help='change the default page title', default="AllenNLP Demo")
138142
parser.add_argument('--field-name', type=str, action='append',
@@ -151,7 +155,7 @@ def main(args):
151155
for package_name in args.include_package:
152156
import_submodules(package_name)
153157

154-
archive = load_archive(args.archive_path)
158+
archive = load_archive(args.archive_path, weights_file=args.weights_file, overrides=args.overrides)
155159
predictor = Predictor.from_archive(archive, args.predictor)
156160
field_names = args.field_name
157161

allennlp/tests/models/bimpm_test.py

+6
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,9 @@ def test_model_can_train_save_and_load(self):
1919

2020
def test_batch_predictions_are_consistent(self):
2121
self.ensure_batch_predictions_are_consistent()
22+
23+
def test_decode_runs_correctly(self):
24+
training_tensors = self.dataset.as_tensor_dict()
25+
output_dict = self.model(**training_tensors)
26+
decode_output_dict = self.model.decode(output_dict)
27+
assert "label" in decode_output_dict

0 commit comments

Comments
 (0)