
Configure validation frequency #5534

Merged 13 commits on Jan 4, 2022
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -18,6 +18,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Added support to push models directly to the [Hugging Face Hub](https://huggingface.co/) with the command `allennlp push-to-hf`.
- More default tests for the `TextualEntailmentSuite`.
- Added attribute `_should_validate_this_epoch` to `GradientDescentTrainer` that controls whether validation is run at the end of each epoch.
- Added `ShouldValidateCallback` that can be used to configure the frequency of validation during training.

### Changed

1 change: 1 addition & 0 deletions allennlp/training/callbacks/__init__.py
@@ -5,3 +5,4 @@
from allennlp.training.callbacks.track_epoch import TrackEpochCallback
from allennlp.training.callbacks.wandb import WandBCallback
from allennlp.training.callbacks.backward import MixedPrecisionBackwardCallback, OnBackwardException
from allennlp.training.callbacks.should_validate import ShouldValidateCallback
41 changes: 41 additions & 0 deletions allennlp/training/callbacks/should_validate.py
@@ -0,0 +1,41 @@
from typing import Dict, Any, TYPE_CHECKING, Optional

from allennlp.training.callbacks.callback import TrainerCallback

if TYPE_CHECKING:
    from allennlp.training.gradient_descent_trainer import GradientDescentTrainer


@TrainerCallback.register("should_validate_callback")
class ShouldValidateCallback(TrainerCallback):
"""
A callback that you can pass to the `GradientDescentTrainer` to change the frequency of
validation during training. If `validation_start` is not `None`, validation will not occur until
`validation_start` epochs have elapsed. If `validation_interval` is not `None`, validation will
run every `validation_interval` number of epochs epochs.
"""

    def __init__(
        self,
        serialization_dir: str,
        validation_start: Optional[int] = None,
        validation_interval: Optional[int] = None,
    ) -> None:
        super().__init__(serialization_dir)
        self._validation_start = validation_start
        self._validation_interval = validation_interval

    def on_epoch(
        self,
        trainer: "GradientDescentTrainer",
        metrics: Dict[str, Any],
        epoch: int,
        is_primary: bool = True,
        **kwargs,
    ) -> None:
        if self._validation_start is not None and epoch < self._validation_start:
            trainer._should_validate_this_epoch = False
        elif self._validation_interval is not None and epoch % self._validation_interval != 0:
            trainer._should_validate_this_epoch = False
        else:
            trainer._should_validate_this_epoch = True
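
To make the gating rule concrete, here is a minimal, self-contained sketch that mirrors the `on_epoch` logic above. The `should_validate` helper is hypothetical, for illustration only; it shows which epochs trigger validation for a given `validation_start` and `validation_interval`.

```python
from typing import Optional


def should_validate(epoch: int, start: Optional[int], interval: Optional[int]) -> bool:
    # Mirrors ShouldValidateCallback.on_epoch: skip until `start` epochs have
    # elapsed, then validate only on every `interval`-th epoch.
    if start is not None and epoch < start:
        return False
    if interval is not None and epoch % interval != 0:
        return False
    return True


# With validation_start=4 and validation_interval=2 (the values used in the
# test below), an 8-epoch run validates only on epochs 4 and 6.
print([epoch for epoch in range(8) if should_validate(epoch, 4, 2)])  # [4, 6]
```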
3 changes: 2 additions & 1 deletion allennlp/training/gradient_descent_trainer.py
@@ -339,6 +339,7 @@ def __init__(
        self._batches_in_epoch_completed: int = 0
        self._start_after_batches_in_epoch_completed: int = 0
        self._best_model_filename: Optional[str] = None
        self._should_validate_this_epoch: bool = True

        # This is a kind of training state, but it is not serialized with the trainer state, because we can
        # re-create it with `epochs_completed` and `batches_in_epoch_completed`.
@@ -811,7 +812,7 @@ def _try_train(self) -> Tuple[Dict[str, Any], int]:
metrics["peak_" + key] = max(metrics.get("peak_" + key, 0), value)

this_epoch_val_metric: float = 0.0
if self._validation_data_loader is not None:
if self._validation_data_loader is not None and self._should_validate_this_epoch:
with torch.no_grad():
# We have a validation set, so compute all the metrics on it.
val_loss, val_reg_loss, num_batches = self._validation_loss(epoch)
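As a usage sketch, the name registered via `@TrainerCallback.register("should_validate_callback")` lets the callback be built through AllenNLP's `FromParams` machinery and handed to the trainer's `callbacks` argument. This assumes `TrainerCallback` is exported from `allennlp.training.callbacks`; the `Params` values mirror the test below, and the path is illustrative.

```python
from allennlp.common import Params
from allennlp.training.callbacks import TrainerCallback

# "type" matches the registered name; the resulting object can be passed to
# GradientDescentTrainer via its `callbacks` argument.
callback = TrainerCallback.from_params(
    Params({
        "type": "should_validate_callback",
        "validation_start": 4,
        "validation_interval": 2,
    }),
    serialization_dir="/tmp/serialization_dir",  # illustrative path
)
```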
42 changes: 42 additions & 0 deletions tests/training/trainer_test.py
@@ -31,6 +31,7 @@
    ConfidenceChecksCallback,
    ConsoleLoggerCallback,
    OnBackwardException,
    ShouldValidateCallback,
)
from allennlp.training.callbacks.confidence_checks import ConfidenceCheckError
from allennlp.training.learning_rate_schedulers import CosineWithRestarts
@@ -1331,6 +1332,47 @@ def test_console_log_callback(self):
        )
        trainer.train()

    def test_should_validate_callback(self):
        total_instances = 1000
        batch_size = 25

        reader = FakeDatasetReader(total_instances, batch_size)
        data_loader = SimpleDataLoader.from_dataset_reader(
            reader, "fake_path", batch_size=batch_size
        )
        instances = list(data_loader.iter_instances())
        vocab = Vocabulary.from_instances(instances)
        data_loader.index_with(vocab)
        model = FakeModel(vocab)
        optimizer = torch.optim.SGD(model.parameters(), 0.01, momentum=0.9)
        callback = ShouldValidateCallback.from_params(
            Params({"validation_start": 4, "validation_interval": 2}),
            serialization_dir=self.TEST_DIR,
        )

        # Check that training works with the callback
        trainer = GradientDescentTrainer(
            model,
            optimizer,
            data_loader,
            num_epochs=6,
            serialization_dir=self.TEST_DIR,
            callbacks=[callback],
        )
        trainer.train()

        # Doesn't satisfy 'validation_start' or 'validation_interval'
        callback.on_epoch(trainer, metrics={}, epoch=1)
        assert not trainer._should_validate_this_epoch

        # Satisfies 'validation_interval' but not 'validation_start'
        callback.on_epoch(trainer, metrics={}, epoch=2)
        assert not trainer._should_validate_this_epoch

        # Satisfies both 'validation_start' and 'validation_interval'
        callback.on_epoch(trainer, metrics={}, epoch=4)
        assert trainer._should_validate_this_epoch


@requires_gpu
class TestAmpTrainer(TrainerTestBase):