Skip to content
This repository was archived by the owner on Dec 16, 2022. It is now read-only.

Rename sanity_checks to confidence_checks #5201

Merged
merged 4 commits into from
May 14, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Use `dist_reduce_sum` in distributed metrics.
- Allow Google Cloud Storage paths in `cached_path` ("gs://...").
- Print the first batch to the console by default.
- Renamed `sanity_checks` to `confidence_checks` (`sanity_checks` is deprecated and will be removed in AllenNLP 3.0).

### Added

- Added `TaskSuite` base class and command line functionality for running [`checklist`](https://github.com/marcotcr/checklist) test suites, along with implementations for `SentimentAnalysisSuite`, `QuestionAnsweringSuite`, and `TextualEntailmentSuite`. These can be found in the `allennlp.sanity_checks.task_checklists` module.
- Added `TaskSuite` base class and command line functionality for running [`checklist`](https://github.com/marcotcr/checklist) test suites, along with implementations for `SentimentAnalysisSuite`, `QuestionAnsweringSuite`, and `TextualEntailmentSuite`. These can be found in the `allennlp.confidence_checks.task_checklists` module.
- Added `allennlp diff` command to compute a diff on model checkpoints, analogous to what `git diff` does on two files.
- Added `allennlp.nn.util.load_state_dict` helper function.
- Added a way to avoid downloading and loading pretrained weights in modules that wrap transformers
Expand Down
6 changes: 3 additions & 3 deletions allennlp/commands/checklist.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""
The `checklist` subcommand allows you to sanity check your
model's predictions using a trained model and its
The `checklist` subcommand allows you to conduct behavioural
testing for your model's predictions using a trained model and its
[`Predictor`](../predictors/predictor.md#predictor) wrapper.
"""

Expand All @@ -15,7 +15,7 @@
from allennlp.common.checks import check_for_gpu, ConfigurationError
from allennlp.models.archival import load_archive
from allennlp.predictors.predictor import Predictor
from allennlp.sanity_checks.task_checklists.task_suite import TaskSuite
from allennlp.confidence_checks.task_checklists.task_suite import TaskSuite


@Subcommand.register("checklist")
Expand Down
2 changes: 1 addition & 1 deletion allennlp/common/testing/checklist_test.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Optional
from checklist.test_suite import TestSuite
from checklist.test_types import MFT as MinimumFunctionalityTest
from allennlp.sanity_checks.task_checklists.task_suite import TaskSuite
from allennlp.confidence_checks.task_checklists.task_suite import TaskSuite


@TaskSuite.register("fake-task-suite")
Expand Down
2 changes: 1 addition & 1 deletion allennlp/common/testing/model_test_case.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from allennlp.data.batch import Batch
from allennlp.models import load_archive, Model
from allennlp.training import GradientDescentTrainer
from allennlp.sanity_checks.normalization_bias_verification import NormalizationBiasVerification
from allennlp.confidence_checks.normalization_bias_verification import NormalizationBiasVerification


class ModelTestCase(AllenNlpTestCase):
Expand Down
2 changes: 2 additions & 0 deletions allennlp/confidence_checks/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from allennlp.confidence_checks.verification_base import VerificationBase
from allennlp.confidence_checks.normalization_bias_verification import NormalizationBiasVerification
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import torch
from torch import nn as nn
from typing import Tuple, List, Callable
from allennlp.sanity_checks.verification_base import VerificationBase
from allennlp.confidence_checks.verification_base import VerificationBase
import logging

logger = logging.getLogger(__name__)
Expand Down
10 changes: 10 additions & 0 deletions allennlp/confidence_checks/task_checklists/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from allennlp.confidence_checks.task_checklists.task_suite import TaskSuite
from allennlp.confidence_checks.task_checklists.sentiment_analysis_suite import (
SentimentAnalysisSuite,
)
from allennlp.confidence_checks.task_checklists.question_answering_suite import (
QuestionAnsweringSuite,
)
from allennlp.confidence_checks.task_checklists.textual_entailment_suite import (
TextualEntailmentSuite,
)
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
from checklist.test_suite import TestSuite
from checklist.test_types import MFT
from checklist.perturb import Perturb
from allennlp.sanity_checks.task_checklists.task_suite import TaskSuite
from allennlp.sanity_checks.task_checklists import utils
from allennlp.confidence_checks.task_checklists.task_suite import TaskSuite
from allennlp.confidence_checks.task_checklists import utils


def _crossproduct(template: CheckListTemplate):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from checklist.test_types import MFT, INV, DIR, Expect
from checklist.editor import Editor
from checklist.perturb import Perturb
from allennlp.sanity_checks.task_checklists.task_suite import TaskSuite
from allennlp.sanity_checks.task_checklists import utils
from allennlp.confidence_checks.task_checklists.task_suite import TaskSuite
from allennlp.confidence_checks.task_checklists import utils
from allennlp.data.instance import Instance


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from allennlp.common.registrable import Registrable
from allennlp.common.file_utils import cached_path
from allennlp.predictors.predictor import Predictor
from allennlp.sanity_checks.task_checklists import utils
from allennlp.confidence_checks.task_checklists import utils

logger = logging.getLogger(__name__)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from checklist.test_suite import TestSuite
from checklist.test_types import MFT
from checklist.perturb import Perturb
from allennlp.sanity_checks.task_checklists.task_suite import TaskSuite
from allennlp.sanity_checks.task_checklists import utils
from allennlp.confidence_checks.task_checklists.task_suite import TaskSuite
from allennlp.confidence_checks.task_checklists import utils


def _wrap_apply_to_each(perturb_fn: Callable, both: bool = False, *args, **kwargs):
Expand Down
11 changes: 9 additions & 2 deletions allennlp/sanity_checks/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,9 @@
from allennlp.sanity_checks.verification_base import VerificationBase
from allennlp.sanity_checks.normalization_bias_verification import NormalizationBiasVerification
from allennlp.confidence_checks.verification_base import VerificationBase
from allennlp.confidence_checks.normalization_bias_verification import NormalizationBiasVerification

import warnings

warnings.warn(
"Module 'sanity_checks' is deprecated, please use 'confidence_checks' instead.",
DeprecationWarning,
)
8 changes: 4 additions & 4 deletions allennlp/sanity_checks/task_checklists/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from allennlp.sanity_checks.task_checklists.task_suite import TaskSuite
from allennlp.sanity_checks.task_checklists.sentiment_analysis_suite import (
from allennlp.confidence_checks.task_checklists.task_suite import TaskSuite
from allennlp.confidence_checks.task_checklists.sentiment_analysis_suite import (
SentimentAnalysisSuite,
)
from allennlp.sanity_checks.task_checklists.question_answering_suite import (
from allennlp.confidence_checks.task_checklists.question_answering_suite import (
QuestionAnsweringSuite,
)
from allennlp.sanity_checks.task_checklists.textual_entailment_suite import (
from allennlp.confidence_checks.task_checklists.textual_entailment_suite import (
TextualEntailmentSuite,
)
2 changes: 1 addition & 1 deletion allennlp/training/callbacks/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from allennlp.training.callbacks.callback import TrainerCallback
from allennlp.training.callbacks.console_logger import ConsoleLoggerCallback
from allennlp.training.callbacks.sanity_checks import SanityChecksCallback
from allennlp.training.callbacks.confidence_checks import ConfidenceChecksCallback
from allennlp.training.callbacks.tensorboard import TensorBoardCallback
from allennlp.training.callbacks.track_epoch import TrackEpochCallback
from allennlp.training.callbacks.wandb import WandBCallback
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,27 @@

from allennlp.training.callbacks.callback import TrainerCallback
from allennlp.data import TensorDict
from allennlp.sanity_checks.normalization_bias_verification import NormalizationBiasVerification
from allennlp.confidence_checks.normalization_bias_verification import NormalizationBiasVerification


if TYPE_CHECKING:
from allennlp.training.trainer import GradientDescentTrainer


# The `sanity_checks` registered name is deprecated and will be removed in AllenNLP 3.0; use `confidence_checks` instead.
@TrainerCallback.register("sanity_checks")
class SanityChecksCallback(TrainerCallback):
@TrainerCallback.register("confidence_checks")
class ConfidenceChecksCallback(TrainerCallback):
"""
Performs model sanity checks.
Performs model confidence checks.

Checks performed:

* `NormalizationBiasVerification` for detecting invalid combinations of
bias and normalization layers.
See `allennlp.sanity_checks.normalization_bias_verification` for more details.
See `allennlp.confidence_checks.normalization_bias_verification` for more details.

Note: Any new sanity checks should also be added to this callback.
Note: Any new confidence checks should also be added to this callback.
"""

def on_start(
Expand Down Expand Up @@ -54,18 +56,18 @@ def on_batch(
self._verification.destroy_hooks()
detected_pairs = self._verification.collect_detections()
if len(detected_pairs) > 0:
raise SanityCheckError(
raise ConfidenceCheckError(
"The NormalizationBiasVerification check failed. See logs for more details."
)


class SanityCheckError(Exception):
class ConfidenceCheckError(Exception):
"""
The error type raised when a sanity check fails.
The error type raised when a confidence check fails.
"""

def __init__(self, message) -> None:
super().__init__(
message
+ "\nYou can disable these checks by setting the trainer parameter `run_sanity_checks` to `False`."
+ "\nYou can disable these checks by setting the trainer parameter `run_confidence_checks` to `False`."
)
37 changes: 28 additions & 9 deletions allennlp/training/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import re
import time
import traceback
import warnings
from contextlib import contextmanager
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, Type

Expand All @@ -23,7 +24,11 @@
from allennlp.data import DataLoader, TensorDict
from allennlp.models.model import Model
from allennlp.training import util as training_util
from allennlp.training.callbacks import TrainerCallback, SanityChecksCallback, ConsoleLoggerCallback
from allennlp.training.callbacks import (
TrainerCallback,
ConfidenceChecksCallback,
ConsoleLoggerCallback,
)
from allennlp.training.checkpointer import Checkpointer
from allennlp.training.learning_rate_schedulers import LearningRateScheduler
from allennlp.training.metric_tracker import MetricTracker
Expand Down Expand Up @@ -263,10 +268,13 @@ class GradientDescentTrainer(Trainer):
addition to any other callbacks listed in the `callbacks` parameter.
When set to `False`, `DEFAULT_CALLBACKS` are not used.

run_confidence_checks : `bool`, optional (default = `True`)
Determines whether model confidence checks, such as
[`NormalizationBiasVerification`](../../confidence_checks/normalization_bias_verification/),
are run.

run_sanity_checks : `bool`, optional (default = `True`)
Determines whether model sanity checks, such as
[`NormalizationBiasVerification`](../../sanity_checks/normalization_bias_verification/),
are ran.
This parameter is deprecated. Please use `run_confidence_checks` instead.

"""

Expand Down Expand Up @@ -294,7 +302,8 @@ def __init__(
num_gradient_accumulation_steps: int = 1,
use_amp: bool = False,
enable_default_callbacks: bool = True,
run_sanity_checks: bool = True,
run_confidence_checks: bool = True,
**kwargs,
) -> None:
super().__init__(
serialization_dir=serialization_dir,
Expand All @@ -304,6 +313,13 @@ def __init__(
world_size=world_size,
)

if "run_sanity_checks" in kwargs:
warnings.warn(
"'run_sanity_checks' is deprecated, please use 'run_confidence_checks' instead.",
DeprecationWarning,
)
run_confidence_checks = kwargs["run_sanity_checks"]

# I am not calling move_to_gpu here, because if the model is
# not already on the GPU then the optimizer is going to be wrong.
self.model = model
Expand Down Expand Up @@ -345,8 +361,9 @@ def __init__(

self._callbacks = callbacks or []
default_callbacks = list(DEFAULT_CALLBACKS) if enable_default_callbacks else []
if run_sanity_checks:
default_callbacks.append(SanityChecksCallback)

if run_confidence_checks:
default_callbacks.append(ConfidenceChecksCallback)
for callback_cls in default_callbacks:
for callback in self._callbacks:
if callback.__class__ == callback_cls:
Expand Down Expand Up @@ -1014,7 +1031,8 @@ def from_partial_objects(
checkpointer: Lazy[Checkpointer] = Lazy(Checkpointer),
callbacks: List[Lazy[TrainerCallback]] = None,
enable_default_callbacks: bool = True,
run_sanity_checks: bool = True,
run_confidence_checks: bool = True,
**kwargs,
) -> "Trainer":
"""
This method exists so that we can have a documented method to construct this class using
Expand Down Expand Up @@ -1106,7 +1124,8 @@ def from_partial_objects(
num_gradient_accumulation_steps=num_gradient_accumulation_steps,
use_amp=use_amp,
enable_default_callbacks=enable_default_callbacks,
run_sanity_checks=run_sanity_checks,
run_confidence_checks=run_confidence_checks,
**kwargs,
)


Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import torch

from allennlp.common.testing import AllenNlpTestCase
from allennlp.common.testing.sanity_check_test import (
from allennlp.common.testing.confidence_check_test import (
FakeModelForTestingNormalizationBiasVerification,
)
from allennlp.sanity_checks.normalization_bias_verification import NormalizationBiasVerification
from allennlp.confidence_checks.normalization_bias_verification import NormalizationBiasVerification


class TestNormalizationBiasVerification(AllenNlpTestCase):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from allennlp.sanity_checks.task_checklists.sentiment_analysis_suite import SentimentAnalysisSuite
from allennlp.confidence_checks.task_checklists.sentiment_analysis_suite import (
SentimentAnalysisSuite,
)
from allennlp.common.testing import AllenNlpTestCase, requires_gpu
from allennlp.models.archival import load_archive
from allennlp.predictors import Predictor
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import pytest
from allennlp.sanity_checks.task_checklists.task_suite import TaskSuite
from allennlp.confidence_checks.task_checklists.task_suite import TaskSuite
from allennlp.common.testing import AllenNlpTestCase
from allennlp.common.checks import ConfigurationError
from allennlp.models.archival import load_archive
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from allennlp.sanity_checks.task_checklists import utils
from allennlp.confidence_checks.task_checklists import utils
from allennlp.common.testing import AllenNlpTestCase


Expand Down
18 changes: 9 additions & 9 deletions tests/training/trainer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,10 @@
TrainerCallback,
TrackEpochCallback,
TensorBoardCallback,
SanityChecksCallback,
ConfidenceChecksCallback,
ConsoleLoggerCallback,
)
from allennlp.training.callbacks.sanity_checks import SanityCheckError
from allennlp.training.callbacks.confidence_checks import ConfidenceCheckError
from allennlp.training.learning_rate_schedulers import CosineWithRestarts
from allennlp.training.learning_rate_schedulers import ExponentialLearningRateScheduler
from allennlp.training.momentum_schedulers import MomentumScheduler
Expand All @@ -49,7 +49,7 @@
TensorField,
)
from allennlp.training.optimizers import Optimizer
from allennlp.common.testing.sanity_check_test import (
from allennlp.common.testing.confidence_check_test import (
FakeModelForTestingNormalizationBiasVerification,
)

Expand Down Expand Up @@ -814,7 +814,7 @@ def test_trainer_can_log_learning_rates_tensorboard(self):

trainer.train()

def test_sanity_check_callback(self):
def test_confidence_check_callback(self):
model_with_bias = FakeModelForTestingNormalizationBiasVerification(use_bias=True)
inst = Instance({"x": TensorField(torch.rand(3, 1, 4))})
data_loader = SimpleDataLoader([inst, inst], 2)
Expand All @@ -824,12 +824,12 @@ def test_sanity_check_callback(self):
data_loader,
num_epochs=1,
serialization_dir=self.TEST_DIR,
callbacks=[SanityChecksCallback(serialization_dir=self.TEST_DIR)],
callbacks=[ConfidenceChecksCallback(serialization_dir=self.TEST_DIR)],
)
with pytest.raises(SanityCheckError):
with pytest.raises(ConfidenceCheckError):
trainer.train()

def test_sanity_check_default(self):
def test_confidence_check_default(self):
model_with_bias = FakeModelForTestingNormalizationBiasVerification(use_bias=True)
inst = Instance({"x": TensorField(torch.rand(3, 1, 4))})
data_loader = SimpleDataLoader([inst, inst], 2)
Expand All @@ -839,15 +839,15 @@ def test_sanity_check_default(self):
data_loader=data_loader,
num_epochs=1,
)
with pytest.raises(SanityCheckError):
with pytest.raises(ConfidenceCheckError):
trainer.train()

trainer = GradientDescentTrainer.from_partial_objects(
model_with_bias,
serialization_dir=self.TEST_DIR,
data_loader=data_loader,
num_epochs=1,
run_sanity_checks=False,
run_confidence_checks=False,
)

# Check is not run, so no failure.
Expand Down