
Evaluator #5445

Merged
merged 19 commits on Jan 27, 2022
Changes from 12 commits
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -7,6 +7,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## Unreleased

- Added an `Evaluator` class to make comparing source, target, and predictions easier.

## [v2.8.0](https://github.com/allenai/allennlp/releases/tag/v2.8.0) - 2021-11-01

### Added
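As a rough illustration of the CHANGELOG entry above, the sketch below mirrors how the updated evaluate command builds an `Evaluator` from an optional `evaluation` section of the archive config (see `allennlp/commands/evaluate.py` further down). The config contents here are hypothetical, and it is assumed that `Evaluator` falls back to a default implementation when no `type` is given, which the command below relies on.

```python
# Hedged sketch: turning an (optional) "evaluation" config section into an Evaluator,
# mirroring the logic added to evaluate.py below. The config fragment is made up.
from allennlp.common import Params
from allennlp.evaluation import Evaluator

config = Params({"evaluation": {}})               # hypothetical archive config fragment
evaluator_params = config.pop("evaluation", {})   # same default the command uses
evaluator_params["cuda_device"] = -1              # the command injects the requested device
evaluator = Evaluator.from_params(evaluator_params)
```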
215 changes: 181 additions & 34 deletions allennlp/commands/evaluate.py
@@ -7,8 +7,9 @@
import argparse
import json
import logging
from typing import Any, Dict

from pathlib import Path
from os import PathLike
from typing import Union, Dict, Any, Optional
from copy import deepcopy

from overrides import overrides
@@ -18,7 +19,7 @@
from allennlp.common.util import prepare_environment
from allennlp.data import DataLoader
from allennlp.models.archival import load_archive
from allennlp.training.util import evaluate
from allennlp.evaluation import Evaluator

logger = logging.getLogger(__name__)

@@ -39,7 +40,7 @@ def add_subparser(self, parser: argparse._SubParsersAction) -> argparse.Argument
type=str,
help=(
"path to the file containing the evaluation data"
' (for mutiple files, put ":" between filenames e.g., input1.txt:input2.txt)'
' (for multiple files, put ":" between filenames e.g., input1.txt:input2.txt)'
),
)

@@ -48,7 +49,7 @@ def add_subparser(self, parser: argparse._SubParsersAction) -> argparse.Argument
type=str,
help=(
"optional path to write the metrics to as JSON"
' (for mutiple files, put ":" between filenames e.g., output1.txt:output2.txt)'
' (for multiple files, put ":" between filenames e.g., output1.txt:output2.txt)'
),
)

@@ -57,7 +58,7 @@ def add_subparser(self, parser: argparse._SubParsersAction) -> argparse.Argument
type=str,
help=(
"optional path to write the predictions to as JSON lines"
' (for mutiple files, put ":" between filenames e.g., output1.jsonl:output2.jsonl)'
' (for multiple files, put ":" between filenames e.g., output1.jsonl:output2.jsonl)'
),
)

@@ -118,13 +119,113 @@ def add_subparser(self, parser: argparse._SubParsersAction) -> argparse.Argument
help="outputs tqdm status on separate lines and slows tqdm refresh rate",
)

subparser.add_argument(
"--auto-names",
default="NONE",
help="Automatically create output names for each evaluation file.",
choices=["NONE", "METRICS", "PREDS", "ALL"],
)

subparser.set_defaults(func=evaluate_from_args)

return subparser


def evaluate_from_args(args: argparse.Namespace) -> Dict[str, Any]:
common_logging.FILE_FRIENDLY_LOGGING = args.file_friendly_logging
return evaluate_from_archive(
archive_file=args.archive_file,
input_file=args.input_file,
output_file=args.output_file,
predictions_output_file=args.predictions_output_file,
batch_size=args.batch_size,
cmd_overrides=args.overrides,
cuda_device=args.cuda_device,
embedding_sources_mapping=args.embedding_sources_mapping,
extend_vocab=args.extend_vocab,
weights_file=args.weights_file,
file_friendly_logging=args.file_friendly_logging,
batch_weight_key=args.batch_weight_key,
auto_names=args.auto_names,
)


def evaluate_from_archive(
archive_file: Union[str, PathLike],
input_file: str,
output_file: Optional[str] = None,
predictions_output_file: Optional[str] = None,
batch_size: Optional[int] = None,
cmd_overrides: Union[str, Dict[str, Any]] = "",
cuda_device: int = -1,
embedding_sources_mapping: str = None,
extend_vocab: bool = False,
weights_file: str = None,
file_friendly_logging: bool = False,
batch_weight_key: str = None,
auto_names: str = "NONE",
) -> Dict[str, Any]:
"""

# Parameters

archive_file: `Union[str, PathLike]`
Path to an archived trained model.

input_file: `str`
path to the file containing the evaluation data (for multiple files,
put ":" between filenames e.g., input1.txt:input2.txt)

output_file: `str`, optional (default=`None`)
optional path to write the metrics to as JSON (for multiple files, put
":" between filenames e.g., output1.txt:output2.txt)

predictions_output_file: `str`, optional (default=`None`)
"optional path to write the predictions to (for multiple files, put ":"
between filenames e.g., output1.jsonl:output2.jsonl)

batch_size: `int`, optional (default=`None`)
If non-empty, the batch size to use during evaluation.

cmd_overrides: `str`, optional (default=`""`)
a json(net) structure used to override the experiment configuration,
e.g., '{\"iterator.batch_size\": 16}'. Nested parameters can be
specified either with nested dictionaries or with dot syntax.

cuda_device: `int`, optional (default=`-1`)
id of GPU to use (if any)

embedding_sources_mapping: `str`, optional (default=`None`)
a JSON dict defining mapping from embedding module path to embedding
pretrained-file used during training. If not passed, and embedding
needs to be extended, we will try to use the original file paths used
during training. If they are not available we will use random vectors
for embedding extension.

extend_vocab: `bool`, optional (default=`False`)
if specified, we will use the instances in your new dataset to extend
your vocabulary. If pretrained-file was used to initialize embedding
layers, you may also need to pass --embedding-sources-mapping.

weights_file: `str`, optional (default=`None`)
A path that overrides which weights file to use

file_friendly_logging : `bool`, optional (default=`False`)
If `True`, we add newlines to tqdm output, even on an interactive terminal, and we slow
down tqdm's output to only once every 10 seconds.

batch_weight_key: `str`, optional (default=`None`)
If non-empty, name of metric used to weight the loss on a per-batch basis.

auto_names: `str`, optional (default=`"NONE"`)
Automatically create output names for each evaluation file.

# Returns

all_metrics: `Dict[str, Any]`
The metrics from every evaluation file passed.

"""
common_logging.FILE_FRIENDLY_LOGGING = file_friendly_logging

# Disable some of the more verbose logging statements
logging.getLogger("allennlp.common.params").disabled = True
@@ -133,77 +234,123 @@ def evaluate_from_args(args: argparse.Namespace) -> Dict[str, Any]:

# Load from archive
archive = load_archive(
args.archive_file,
weights_file=args.weights_file,
cuda_device=args.cuda_device,
overrides=args.overrides,
archive_file,
weights_file=weights_file,
cuda_device=cuda_device,
overrides=cmd_overrides,
)
config = deepcopy(archive.config)
prepare_environment(config)
model = archive.model
model.eval()

# Load the evaluator from the config key `Evaluator`
evaluator_params = config.pop("evaluation", {})
evaluator_params["cuda_device"] = cuda_device
evaluator = Evaluator.from_params(evaluator_params)

# Load the evaluation data
dataset_reader = archive.validation_dataset_reader

# split files
evaluation_data_path_list = args.input_file.split(":")
if args.output_file is not None:
output_file_list = args.output_file.split(":")
assert len(output_file_list) == len(
evaluation_data_path_list
), "The number of `output_file` paths must be equal to the number of datasets being evaluated."
if args.predictions_output_file is not None:
predictions_output_file_list = args.predictions_output_file.split(":")
assert len(predictions_output_file_list) == len(evaluation_data_path_list), (
"The number of `predictions_output_file` paths must be equal"
+ "to the number of datasets being evaluated. "
)
evaluation_data_path_list = input_file.split(":")

# TODO(gabeorlanski): Is it safe to always default to .outputs and .preds?
# TODO(gabeorlanski): Add in way to save to specific output directory
if output_file is not None:
if auto_names == "METRICS" or auto_names == "ALL":
logger.warning(
f"Passed output_files will be ignored, auto_names is" f" set to {auto_names}"
)

# Keep the path of the parent otherwise it will write to the CWD
output_file_list = [
p.parent.joinpath(f"{p.stem}.outputs") for p in map(Path, evaluation_data_path_list)
]
else:
output_file_list = output_file.split(":") # type: ignore
assert len(output_file_list) == len(
evaluation_data_path_list
), "The number of `output_file` paths must be equal to the number of datasets being evaluated."
if predictions_output_file is not None:
if auto_names == "PREDS" or auto_names == "ALL":
logger.warning(
f"Passed predictions files will be ignored, auto_names is" f" set to {auto_names}"
)

# Keep the path of the parent otherwise it will write to the CWD
predictions_output_file_list = [
p.parent.joinpath(f"{p.stem}.preds") for p in map(Path, evaluation_data_path_list)
]
else:
predictions_output_file_list = predictions_output_file.split(":") # type: ignore
assert len(predictions_output_file_list) == len(evaluation_data_path_list), (
"The number of `predictions_output_file` paths must be equal"
+ "to the number of datasets being evaluated. "
)

# output file
output_file_path = None
predictions_output_file_path = None

# embedding sources
if args.extend_vocab:
if extend_vocab:
logger.info("Vocabulary is being extended with embedding sources.")
embedding_sources = (
json.loads(args.embedding_sources_mapping) if args.embedding_sources_mapping else {}
json.loads(embedding_sources_mapping) if embedding_sources_mapping else {}
)

all_metrics = {}
for index in range(len(evaluation_data_path_list)):
config = deepcopy(archive.config)
evaluation_data_path = evaluation_data_path_list[index]
if args.output_file is not None:

# Get the eval file name so we can save each metric by file name in the
# output dictionary.
eval_file_name = Path(evaluation_data_path).stem

if output_file is not None:
# noinspection PyUnboundLocalVariable
output_file_path = output_file_list[index]
if args.predictions_output_file is not None:

if predictions_output_file is not None:
# noinspection PyUnboundLocalVariable
predictions_output_file_path = predictions_output_file_list[index]

logger.info("Reading evaluation data from %s", evaluation_data_path)
data_loader_params = config.get("validation_data_loader", None)
if data_loader_params is None:
data_loader_params = config.get("data_loader")
if args.batch_size:
data_loader_params["batch_size"] = args.batch_size
if batch_size:
data_loader_params["batch_size"] = batch_size
data_loader = DataLoader.from_params(
params=data_loader_params, reader=dataset_reader, data_path=evaluation_data_path
)

if args.extend_vocab:
if extend_vocab:
logger.info("Vocabulary is being extended with test instances.")
model.vocab.extend_from_instances(instances=data_loader.iter_instances())
# noinspection PyUnboundLocalVariable
model.extend_embedder_vocab(embedding_sources)

data_loader.index_with(model.vocab)

metrics = evaluate(
metrics = evaluator(
model,
data_loader,
args.cuda_device,
args.batch_weight_key,
batch_weight_key,
output_file=output_file_path,
predictions_output_file=predictions_output_file_path,
)

# Add the metric prefixed by the file it came from.
for name, value in metrics.items():
if len(evaluation_data_path_list) > 1:
key = f"{eval_file_name}_"
else:
key = ""
all_metrics[f"{key}{name}"] = value

logger.info("Finished evaluating.")

return metrics
return all_metrics
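To make the two naming behaviours above concrete (paths derived under `--auto-names`, and per-file metric prefixes), here is a small self-contained sketch that mirrors the pathlib logic in the diff; the input file names and the `loss` value are made up.

```python
from pathlib import Path

# Hypothetical ":"-separated input argument, as accepted by the command.
evaluation_data_paths = "data/dev1.jsonl:data/dev2.jsonl".split(":")

# --auto-names METRICS (or ALL): metrics paths are derived next to each input file.
metric_files = [p.parent.joinpath(f"{p.stem}.outputs") for p in map(Path, evaluation_data_paths)]
# -> data/dev1.outputs, data/dev2.outputs

# --auto-names PREDS (or ALL): prediction paths get a .preds suffix instead.
pred_files = [p.parent.joinpath(f"{p.stem}.preds") for p in map(Path, evaluation_data_paths)]
# -> data/dev1.preds, data/dev2.preds

# With more than one evaluation file, every metric is keyed by the file stem,
# so a "loss" metric comes back as "dev1_loss" and "dev2_loss".
all_metrics = {}
for path in evaluation_data_paths:
    eval_file_name = Path(path).stem
    metrics = {"loss": 0.0}  # stand-in for the evaluator's output
    for name, value in metrics.items():
        key = f"{eval_file_name}_" if len(evaluation_data_paths) > 1 else ""
        all_metrics[f"{key}{name}"] = value
```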
5 changes: 5 additions & 0 deletions allennlp/commands/train.py
@@ -239,6 +239,11 @@ def train_model(
training_util.create_serialization_dir(params, serialization_dir, recover, force)
params.to_file(os.path.join(serialization_dir, CONFIG_NAME))

# Change Author: Gabe Orlanski
# Placeholder for the time being to make sure no errors are raised b/c of
# the evaluator.
params.pop("evaluation", None)

meta = Meta.new()
meta.to_file(os.path.join(serialization_dir, META_NAME))

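A hedged sketch of what the `params.pop("evaluation", None)` placeholder above guards against: training configs are consumed key by key, and a leftover top-level `evaluation` section would otherwise be reported as an unused parameter. The config fragment below is illustrative only.

```python
# Illustrative assumption: an "evaluation" section sitting in a training config.
from allennlp.common import Params

params = Params({"evaluation": {"type": "simple"}})  # hypothetical fragment of a training config
params.pop("evaluation", None)        # consume the key before the rest of training reads the config
params.assert_empty("train command")  # passes now that nothing unexpected is left over
```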
1 change: 0 additions & 1 deletion allennlp/common/testing/test_case.py
@@ -32,7 +32,6 @@ def setup_method(self):
logging.getLogger("allennlp.modules.token_embedders.embedding").setLevel(logging.INFO)
logging.getLogger("urllib3.connectionpool").disabled = True
log_pytorch_version_info()

self.TEST_DIR = pathlib.Path(TEST_DIR)

os.makedirs(self.TEST_DIR, exist_ok=True)
2 changes: 2 additions & 0 deletions allennlp/evaluation/__init__.py
@@ -0,0 +1,2 @@
from allennlp.evaluation.evaluator import Evaluator, SimpleEvaluator
from allennlp.evaluation.postprocessors.postprocessor import Postprocessor
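The exports above can also be used programmatically; below is a minimal, hedged sketch of that path. It assumes `SimpleEvaluator` can be constructed with its defaults, the archive path and data file are placeholders, and the final call mirrors the `evaluator(...)` call added to `evaluate.py` above.

```python
# Hedged sketch: calling an evaluator directly instead of going through the CLI.
from allennlp.data import DataLoader
from allennlp.evaluation import SimpleEvaluator
from allennlp.models.archival import load_archive

archive = load_archive("model.tar.gz", cuda_device=-1)  # placeholder archive path
model = archive.model
model.eval()

data_loader = DataLoader.from_params(
    params=archive.config.get("data_loader"),           # reuse the training-time loader config
    reader=archive.validation_dataset_reader,
    data_path="dev.jsonl",                               # placeholder evaluation file
)
data_loader.index_with(model.vocab)

evaluator = SimpleEvaluator()                            # assumes the defaults are sufficient
metrics = evaluator(
    model,
    data_loader,
    None,                                                # batch_weight_key, as in evaluate.py
    output_file="dev-metrics.json",
    predictions_output_file="dev.preds",
)
```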