
Configure validation frequency #5534

Merged (13 commits) on Jan 4, 2022
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -35,6 +35,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Added support to push models directly to the [Hugging Face Hub](https://huggingface.co/) with the command `allennlp push-to-hf`.
- More default tests for the `TextualEntailmentSuite`.
- Added attribute `_should_validate_this_epoch` to `GradientDescentTrainer` that controls whether validation is run at the end of each epoch.
- Added `ShouldValidateCallback` that can be used to configure the frequency of validation during training.

### Changed

1 change: 1 addition & 0 deletions allennlp/training/callbacks/__init__.py
@@ -5,3 +5,4 @@
from allennlp.training.callbacks.track_epoch import TrackEpochCallback
from allennlp.training.callbacks.wandb import WandBCallback
from allennlp.training.callbacks.backward import MixedPrecisionBackwardCallback, OnBackwardException
from allennlp.training.callbacks.should_validate import ShouldValidateCallback
41 changes: 41 additions & 0 deletions allennlp/training/callbacks/should_validate.py
@@ -0,0 +1,41 @@
from typing import Dict, Any, TYPE_CHECKING, Optional

from allennlp.training.callbacks.callback import TrainerCallback

if TYPE_CHECKING:
    from allennlp.training.gradient_descent_trainer import GradientDescentTrainer


@TrainerCallback.register("should_validate_callback")
class ShouldValidateCallback(TrainerCallback):
    """
    A callback that you can pass to the `GradientDescentTrainer` to change the frequency of
    validation during training. If `validation_start` is not `None`, validation will not occur
    until `validation_start` epochs have elapsed. If `validation_interval` is not `None`,
    validation will run every `validation_interval` epochs.
    """

    def __init__(
        self,
        serialization_dir: str,
        validation_start: Optional[int] = None,
        validation_interval: Optional[int] = None,
    ) -> None:
        super().__init__(serialization_dir)
        self._validation_start = validation_start
        self._validation_interval = validation_interval

    def on_epoch(
        self,
        trainer: "GradientDescentTrainer",
        metrics: Dict[str, Any],
        epoch: int,
        is_primary: bool = True,
        **kwargs,
    ) -> None:
        if self._validation_start is not None and epoch < self._validation_start:
            trainer._should_validate_this_epoch = False
        elif self._validation_interval is not None and epoch % self._validation_interval != 0:
            trainer._should_validate_this_epoch = False
        else:
            trainer._should_validate_this_epoch = True
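Because the class is registered under the name `should_validate_callback`, it can also be constructed from configuration via `from_params`. A minimal sketch, assuming the standard AllenNLP registry mechanism; the serialization directory path is illustrative, not part of this diff:

```python
from allennlp.common import Params
from allennlp.training.callbacks.callback import TrainerCallback

# Construct the callback through the registry, the way a config file would.
# With validation_start=4 and validation_interval=2, validation is skipped
# for epochs 0-3 and for odd epochs, and runs on epochs 4, 6, 8, ...
callback = TrainerCallback.from_params(
    Params(
        {
            "type": "should_validate_callback",
            "validation_start": 4,
            "validation_interval": 2,
        }
    ),
    serialization_dir="/tmp/should_validate_demo",  # illustrative path
)
```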
13 changes: 9 additions & 4 deletions allennlp/training/gradient_descent_trainer.py
@@ -340,6 +340,7 @@ def __init__(
self._batches_in_epoch_completed: int = 0
self._start_after_batches_in_epoch_completed: int = 0
self._best_model_filename: Optional[str] = None
self._should_validate_this_epoch: bool = True

# This is a kind of training state, but it is not serialized with the trainer state, because we can
# re-create it with `epochs_completed` and `batches_in_epoch_completed`.
@@ -811,8 +812,8 @@ def _try_train(self) -> Tuple[Dict[str, Any], int]:
elif key.startswith("worker_") and key.endswith("_memory_MB"):
metrics["peak_" + key] = max(metrics.get("peak_" + key, 0), value)

this_epoch_val_metric: float = 0.0
if self._validation_data_loader is not None:
this_epoch_val_metric: Optional[float] = None
if self._should_validate_this_epoch and self._validation_data_loader is not None:
with torch.no_grad():
# We have a validation set, so compute all the metrics on it.
val_loss, val_reg_loss, num_batches = self._validation_loss(epoch)
@@ -851,7 +852,7 @@ def _try_train(self) -> Tuple[Dict[str, Any], int]:
for key, value in val_metrics.items():
metrics["validation_" + key] = value

if self._metric_tracker.is_best_so_far():
if self._should_validate_this_epoch and self._metric_tracker.is_best_so_far():
# Update all the best_ metrics.
# (Otherwise they just stay the same as they were.)
metrics["best_epoch"] = epoch
@@ -891,7 +892,7 @@ def _try_train(self) -> Tuple[Dict[str, Any], int]:
if self._distributed:
dist.barrier()

if self._serialization_dir and self._metric_tracker.is_best_so_far():
if (
self._should_validate_this_epoch
and self._serialization_dir
and self._metric_tracker.is_best_so_far()
):
should_save_model_state: bool
if self._ddp_wrapped_model is not None and self._ddp_wrapped_model.is_sharded:
# Each worker saves its own shard for now (we combine the shards later).
42 changes: 42 additions & 0 deletions tests/training/trainer_test.py
@@ -31,6 +31,7 @@
ConfidenceChecksCallback,
ConsoleLoggerCallback,
OnBackwardException,
ShouldValidateCallback,
)
from allennlp.training.callbacks.confidence_checks import ConfidenceCheckError
from allennlp.training.learning_rate_schedulers import CosineWithRestarts
@@ -1331,6 +1332,47 @@ def test_console_log_callback(self):
)
trainer.train()

    def test_should_validate_callback(self):
        total_instances = 1000
        batch_size = 25

        reader = FakeDatasetReader(total_instances, batch_size)
        data_loader = SimpleDataLoader.from_dataset_reader(
            reader, "fake_path", batch_size=batch_size
        )
        instances = list(data_loader.iter_instances())
        vocab = Vocabulary.from_instances(instances)
        data_loader.index_with(vocab)
        model = FakeModel(vocab)
        optimizer = torch.optim.SGD(model.parameters(), 0.01, momentum=0.9)
        callback = ShouldValidateCallback.from_params(
            Params({"validation_start": 4, "validation_interval": 2}),
            serialization_dir=self.TEST_DIR,
        )

        # Check that training works with the callback
        trainer = GradientDescentTrainer(
            model,
            optimizer,
            data_loader,
            num_epochs=6,
            serialization_dir=self.TEST_DIR,
            callbacks=[callback],
        )
        trainer.train()

        # Doesn't satisfy 'validation_start' or 'validation_interval'
        callback.on_epoch(trainer, metrics={}, epoch=1)
        assert not trainer._should_validate_this_epoch

        # Satisfies 'validation_interval' but not 'validation_start'
        callback.on_epoch(trainer, metrics={}, epoch=2)
        assert not trainer._should_validate_this_epoch

        # Satisfies both 'validation_start' and 'validation_interval'
        callback.on_epoch(trainer, metrics={}, epoch=4)
        assert trainer._should_validate_this_epoch


@requires_gpu
class TestAmpTrainer(TrainerTestBase):
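For readers checking the assertions in `test_should_validate_callback` above, here is a standalone sketch of the decision rule the callback applies; the `should_validate` helper below is illustrative only and not part of this PR:

```python
# Mirrors ShouldValidateCallback.on_epoch for validation_start=4, validation_interval=2:
# validation is skipped before epoch 4 and on epochs that are not multiples of 2.
validation_start, validation_interval = 4, 2


def should_validate(epoch: int) -> bool:
    if epoch < validation_start:
        return False
    if epoch % validation_interval != 0:
        return False
    return True


assert [epoch for epoch in range(7) if should_validate(epoch)] == [4, 6]
```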