Skip to content
This repository was archived by the owner on Dec 16, 2022. It is now read-only.

Commit ec96bef

Browse files
committed
create dummy command if checklist not available
1 parent 19e43c4 commit ec96bef

File tree

4 files changed

+240
-221
lines changed

4 files changed

+240
-221
lines changed

allennlp/commands/__init__.py

+1-12
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from allennlp import __version__
88
from allennlp.commands.build_vocab import BuildVocab
99
from allennlp.commands.cached_path import CachedPath
10+
from allennlp.commands.checklist import CheckList
1011
from allennlp.commands.diff import Diff
1112
from allennlp.commands.evaluate import Evaluate
1213
from allennlp.commands.find_learning_rate import FindLearningRate
@@ -22,18 +23,6 @@
2223

2324
logger = logging.getLogger(__name__)
2425

25-
try:
26-
"""
27-
The `allennlp checklist` command requires installation of the optional dependency `checklist`.
28-
If you're using conda, it can be installed with `conda install allennlp-checklist`,
29-
otherwise use `pip install allennlp[checklist]`.
30-
"""
31-
with warnings.catch_warnings():
32-
warnings.simplefilter("ignore")
33-
from allennlp.commands.checklist import CheckList
34-
except ImportError:
35-
pass
36-
3726

3827
class ArgumentParserWithDefaults(argparse.ArgumentParser):
3928
"""
+204
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,204 @@
1+
"""
2+
The `checklist` subcommand allows you to conduct behavioural
3+
testing for your model's predictions using a trained model and its
4+
[`Predictor`](../predictors/predictor.md#predictor) wrapper.
5+
"""
6+
7+
from typing import Optional, Dict, Any, List
8+
import argparse
9+
import sys
10+
import json
11+
import logging
12+
13+
14+
from allennlp.commands.subcommand import Subcommand
15+
from allennlp.common.checks import check_for_gpu, ConfigurationError
16+
from allennlp.models.archival import load_archive
17+
from allennlp.predictors.predictor import Predictor
18+
19+
logger = logging.getLogger(__name__)
20+
21+
try:
22+
from allennlp.confidence_checks.task_checklists.task_suite import TaskSuite
23+
except ImportError:
24+
raise
25+
26+
27+
@Subcommand.register("checklist")
28+
class CheckList(Subcommand):
29+
def add_subparser(self, parser: argparse._SubParsersAction) -> argparse.ArgumentParser:
30+
31+
description = """Run the specified model through a checklist suite."""
32+
subparser = parser.add_parser(
33+
self.name,
34+
description=description,
35+
help="Run a trained model through a checklist suite.",
36+
)
37+
38+
subparser.add_argument(
39+
"archive_file", type=str, help="The archived model to make predictions with"
40+
)
41+
42+
subparser.add_argument("task", type=str, help="The name of the task suite")
43+
44+
subparser.add_argument("--checklist-suite", type=str, help="The checklist suite path")
45+
46+
subparser.add_argument(
47+
"--capabilities",
48+
nargs="+",
49+
default=[],
50+
help=('An optional list of strings of capabilities. Eg. "[Vocabulary, Robustness]"'),
51+
)
52+
53+
subparser.add_argument(
54+
"--max-examples",
55+
type=int,
56+
default=None,
57+
help="Maximum number of examples to check per test.",
58+
)
59+
60+
subparser.add_argument(
61+
"--task-suite-args",
62+
type=str,
63+
default="",
64+
help=(
65+
"An optional JSON structure used to provide additional parameters to the task suite"
66+
),
67+
)
68+
69+
subparser.add_argument(
70+
"--print-summary-args",
71+
type=str,
72+
default="",
73+
help=(
74+
"An optional JSON structure used to provide additional "
75+
"parameters for printing test summary"
76+
),
77+
)
78+
79+
subparser.add_argument("--output-file", type=str, help="Path to output file")
80+
81+
subparser.add_argument(
82+
"--cuda-device", type=int, default=-1, help="ID of GPU to use (if any)"
83+
)
84+
85+
subparser.add_argument(
86+
"--predictor", type=str, help="Optionally specify a specific predictor to use"
87+
)
88+
89+
subparser.add_argument(
90+
"--predictor-args",
91+
type=str,
92+
default="",
93+
help=(
94+
"An optional JSON structure used to provide additional parameters to the predictor"
95+
),
96+
)
97+
98+
subparser.set_defaults(func=_run_suite)
99+
100+
return subparser
101+
102+
103+
def _get_predictor(args: argparse.Namespace) -> Predictor:
104+
check_for_gpu(args.cuda_device)
105+
archive = load_archive(
106+
args.archive_file,
107+
cuda_device=args.cuda_device,
108+
)
109+
110+
predictor_args = args.predictor_args.strip()
111+
if len(predictor_args) <= 0:
112+
predictor_args = {}
113+
else:
114+
predictor_args = json.loads(predictor_args)
115+
116+
return Predictor.from_archive(
117+
archive,
118+
args.predictor,
119+
extra_args=predictor_args,
120+
)
121+
122+
123+
def _get_task_suite(args: argparse.Namespace) -> TaskSuite:
124+
available_tasks = TaskSuite.list_available()
125+
if args.task in available_tasks:
126+
suite_name = args.task
127+
else:
128+
raise ConfigurationError(
129+
f"'{args.task}' is not a recognized task suite. "
130+
f"Available tasks are: {available_tasks}."
131+
)
132+
133+
file_path = args.checklist_suite
134+
135+
task_suite_args = args.task_suite_args.strip()
136+
if len(task_suite_args) <= 0:
137+
task_suite_args = {}
138+
else:
139+
task_suite_args = json.loads(task_suite_args)
140+
141+
return TaskSuite.constructor(
142+
name=suite_name,
143+
suite_file=file_path,
144+
extra_args=task_suite_args,
145+
)
146+
147+
148+
class _CheckListManager:
149+
def __init__(
150+
self,
151+
task_suite: TaskSuite,
152+
predictor: Predictor,
153+
capabilities: Optional[List[str]] = None,
154+
max_examples: Optional[int] = None,
155+
output_file: Optional[str] = None,
156+
print_summary_args: Optional[Dict[str, Any]] = None,
157+
) -> None:
158+
self._task_suite = task_suite
159+
self._predictor = predictor
160+
self._capabilities = capabilities
161+
self._max_examples = max_examples
162+
self._output_file = None if output_file is None else open(output_file, "w")
163+
self._print_summary_args = print_summary_args or {}
164+
165+
if capabilities:
166+
self._print_summary_args["capabilities"] = capabilities
167+
168+
def run(self) -> None:
169+
self._task_suite.run(
170+
self._predictor, capabilities=self._capabilities, max_examples=self._max_examples
171+
)
172+
173+
# We pass in an IO object.
174+
output_file = self._output_file or sys.stdout
175+
self._task_suite.summary(file=output_file, **self._print_summary_args)
176+
177+
# If `_output_file` was None, there would be nothing to close.
178+
if self._output_file is not None:
179+
self._output_file.close()
180+
181+
182+
def _run_suite(args: argparse.Namespace) -> None:
183+
184+
task_suite = _get_task_suite(args)
185+
predictor = _get_predictor(args)
186+
187+
print_summary_args = args.print_summary_args.strip()
188+
if len(print_summary_args) <= 0:
189+
print_summary_args = {}
190+
else:
191+
print_summary_args = json.loads(print_summary_args)
192+
193+
capabilities = args.capabilities
194+
max_examples = args.max_examples
195+
196+
manager = _CheckListManager(
197+
task_suite,
198+
predictor,
199+
capabilities,
200+
max_examples,
201+
args.output_file,
202+
print_summary_args,
203+
)
204+
manager.run()

0 commit comments

Comments
 (0)