Skip to content
This repository was archived by the owner on Dec 16, 2022. It is now read-only.

Commit 9d8862a

Browse files
authored
Move default predictors (#4154)
* Move DecomposableAttention test * Move default predictors * Fix test * Productivity through formatting
1 parent b0c7ac7 commit 9d8862a

File tree

7 files changed

+40
-102
lines changed

7 files changed

+40
-102
lines changed

allennlp/models/basic_classifier.py

+2
Original file line numberDiff line numberDiff line change
@@ -179,3 +179,5 @@ def make_output_human_readable(
179179
def get_metrics(self, reset: bool = False) -> Dict[str, float]:
180180
metrics = {"accuracy": self._accuracy.get_metric(reset)}
181181
return metrics
182+
183+
default_predictor = "text_classifier"

allennlp/models/model.py

+1
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ class Model(torch.nn.Module, Registrable):
6767
"""
6868

6969
_warn_for_unseparable_batches: Set[str] = set()
70+
default_predictor: Optional[str] = None
7071

7172
def __init__(self, vocab: Vocabulary, regularizer: RegularizerApplicator = None) -> None:
7273
super().__init__()

allennlp/models/simple_tagger.py

+2
Original file line numberDiff line numberDiff line change
@@ -215,3 +215,5 @@ def get_metrics(self, reset: bool = False) -> Dict[str, float]:
215215
else:
216216
metrics_to_return.update({x: y for x, y in f1_dict.items() if "overall" in x})
217217
return metrics_to_return
218+
219+
default_predictor = "sentence-tagger"

allennlp/predictors/__init__.py

-1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,5 @@
77
a `Predictor` that wraps it.
88
"""
99
from allennlp.predictors.predictor import Predictor
10-
from allennlp.predictors.decomposable_attention import DecomposableAttentionPredictor
1110
from allennlp.predictors.sentence_tagger import SentenceTaggerPredictor
1211
from allennlp.predictors.text_classifier import TextClassifierPredictor

allennlp/predictors/decomposable_attention.py

-58
This file was deleted.

allennlp/predictors/predictor.py

+2-20
Original file line numberDiff line numberDiff line change
@@ -15,24 +15,6 @@
1515
from allennlp.models.archival import Archive, load_archive
1616
from allennlp.nn import util
1717

18-
# a mapping from model `type` to the default Predictor for that type
19-
DEFAULT_PREDICTORS = {
20-
"atis_parser": "atis-parser",
21-
"basic_classifier": "text_classifier",
22-
"biaffine_parser": "biaffine-dependency-parser",
23-
"bimpm": "textual-entailment",
24-
"constituency_parser": "constituency-parser",
25-
"coref": "coreference-resolution",
26-
"crf_tagger": "sentence-tagger",
27-
"decomposable_attention": "textual-entailment",
28-
"event2mind": "event2mind",
29-
"simple_tagger": "sentence-tagger",
30-
"srl": "semantic-role-labeling",
31-
"srl_bert": "semantic-role-labeling",
32-
"quarel_parser": "quarel-parser",
33-
"wikitables_mml_parser": "wikitables-parser",
34-
}
35-
3618

3719
class Predictor(Registrable):
3820
"""
@@ -299,8 +281,8 @@ def from_archive(
299281

300282
if not predictor_name:
301283
model_type = config.get("model").get("type")
302-
if model_type in DEFAULT_PREDICTORS:
303-
predictor_name = DEFAULT_PREDICTORS[model_type]
284+
model_class, _ = Model.resolve_class_name(model_type)
285+
predictor_name = model_class.default_predictor
304286
predictor_class: Type[Predictor] = Predictor.by_name( # type: ignore
305287
predictor_name
306288
) if predictor_name is not None else cls

allennlp/tests/commands/predict_test.py

+33-23
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
from allennlp.data.dataset_readers import DatasetReader, TextClassificationJsonReader
1919
from allennlp.models.archival import load_archive
2020
from allennlp.predictors import Predictor, TextClassifierPredictor
21-
from allennlp.predictors.predictor import DEFAULT_PREDICTORS
2221

2322

2423
class TestPredict(AllenNlpTestCase):
@@ -224,31 +223,42 @@ def test_base_predictor(self):
224223
model_path = str(self.classifier_model_path)
225224
archive = load_archive(model_path)
226225
model_type = archive.config.get("model").get("type")
227-
# Makes sure that we don't have a DEFAULT_PREDICTOR for it. Otherwise the base class
226+
# Makes sure that we don't have a default_predictor for it. Otherwise the base class
228227
# implementation wouldn't be used
229-
del DEFAULT_PREDICTORS["basic_classifier"]
230-
assert model_type not in DEFAULT_PREDICTORS
228+
from allennlp.models import Model
231229

232-
# Doesn't use a --predictor
233-
sys.argv = [
234-
"__main__.py", # executable
235-
"predict", # command
236-
model_path,
237-
str(self.classifier_data_path), # input_file
238-
"--output-file",
239-
str(self.outfile),
240-
"--silent",
241-
"--use-dataset-reader",
242-
]
243-
main()
244-
assert os.path.exists(self.outfile)
245-
with open(self.outfile, "r") as f:
246-
results = [json.loads(line) for line in f]
230+
model_class, _ = Model.resolve_class_name(model_type)
231+
saved_default_predictor = model_class.default_predictor
232+
model_class.default_predictor = None
233+
try:
234+
# Doesn't use a --predictor
235+
sys.argv = [
236+
"__main__.py", # executable
237+
"predict", # command
238+
model_path,
239+
str(self.classifier_data_path), # input_file
240+
"--output-file",
241+
str(self.outfile),
242+
"--silent",
243+
"--use-dataset-reader",
244+
]
245+
main()
246+
assert os.path.exists(self.outfile)
247+
with open(self.outfile, "r") as f:
248+
results = [json.loads(line) for line in f]
247249

248-
assert len(results) == 3
249-
for result in results:
250-
assert set(result.keys()) == {"logits", "probs", "label", "loss", "tokens", "token_ids"}
251-
DEFAULT_PREDICTORS["basic_classifier"] = "text_classifier"
250+
assert len(results) == 3
251+
for result in results:
252+
assert set(result.keys()) == {
253+
"logits",
254+
"probs",
255+
"label",
256+
"loss",
257+
"tokens",
258+
"token_ids",
259+
}
260+
finally:
261+
model_class.default_predictor = saved_default_predictor
252262

253263
def test_batch_prediction_works_with_known_model(self):
254264
with open(self.infile, "w") as f:

0 commit comments

Comments
 (0)