
Commit f8b10a9 (parent 9e72ee0)

Add a no-op trainer. (#2610)

- Simply loads a model, creates the vocab, and serializes, without doing any training.
- Intended principally for untrained baselines such as a majority-class model.
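For context, a minimal sketch of the kind of configuration this enables. This sketch is not part of the diff: the "majority_class" model type and the data path are hypothetical stand-ins for any registered Model and dataset reader.

    from allennlp.commands.train import train_model
    from allennlp.common import Params

    params = Params({
            "model": {"type": "majority_class"},             # hypothetical registered Model
            "dataset_reader": {"type": "sequence_tagging"},
            "train_data_path": "path/to/train.tsv",          # hypothetical path
            "iterator": {"type": "basic", "batch_size": 2},
            "trainer": {"type": "no_op"},                    # selects the new NoOpTrainer
    })

    # Builds the vocabulary from the data, serializes the untrained weights,
    # and writes model.tar.gz -- with no gradient updates, as the tests below verify.
    train_model(params, serialization_dir="serialization_directory")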

File tree: 7 files changed, +122 −1 lines
@@ -0,0 +1,39 @@
from typing import Dict

import torch

from allennlp.commands.train import train_model
from allennlp.common import Params
from allennlp.common.testing import AllenNlpTestCase
from allennlp.models import load_archive, Model

SEQUENCE_TAGGING_DATA_PATH = str(AllenNlpTestCase.FIXTURES_ROOT / 'data' / 'sequence_tagging.tsv')


@Model.register('constant')
class ConstantModel(Model):
    def forward(self, *inputs) -> Dict[str, torch.Tensor]:
        return {"class": torch.tensor(98)}  # pylint: disable=not-callable


class TestTrain(AllenNlpTestCase):

    def test_train_model(self):
        params = lambda: Params({
                "model": {
                        "type": "constant"
                },
                "dataset_reader": {"type": "sequence_tagging"},
                "train_data_path": SEQUENCE_TAGGING_DATA_PATH,
                "validation_data_path": SEQUENCE_TAGGING_DATA_PATH,
                "iterator": {"type": "basic", "batch_size": 2},
                "trainer": {
                        "type": "no_op"
                }
        })

        serialization_dir = self.TEST_DIR / 'serialization_directory'
        train_model(params(), serialization_dir=serialization_dir)
        archive = load_archive(str(serialization_dir / "model.tar.gz"))
        model = archive.model
        assert model.forward(torch.tensor([1, 2, 3]))["class"] == torch.tensor(98)  # pylint: disable=not-callable
        assert model.vocab.get_vocab_size() == 9
@@ -0,0 +1,31 @@
import os
from typing import Dict

import torch

from allennlp.common.testing import AllenNlpTestCase
from allennlp.data import Vocabulary
from allennlp.data.dataset_readers import SequenceTaggingDatasetReader
from allennlp.models.model import Model
from allennlp.training import NoOpTrainer


class ConstantModel(Model):
    def forward(self, *inputs) -> Dict[str, torch.Tensor]:
        return {"class": torch.tensor(98)}  # pylint: disable=not-callable


class TestNoOpTrainer(AllenNlpTestCase):
    def setUp(self):
        super().setUp()
        self.instances = SequenceTaggingDatasetReader().read(self.FIXTURES_ROOT / 'data' / 'sequence_tagging.tsv')
        vocab = Vocabulary.from_instances(self.instances)
        self.vocab = vocab
        self.model = ConstantModel(vocab)

    def test_trainer_serializes(self):
        serialization_dir = self.TEST_DIR / "serialization_dir"
        trainer = NoOpTrainer(serialization_dir=serialization_dir, model=self.model)
        metrics = trainer.train()
        assert metrics == {}
        assert os.path.exists(serialization_dir / "best.th")
        assert os.path.exists(serialization_dir / "vocabulary")

allennlp/training/__init__.py

+1
@@ -1,2 +1,3 @@
+from allennlp.training.no_op_trainer import NoOpTrainer
 from allennlp.training.trainer import Trainer
 from allennlp.training.trainer_base import TrainerBase

allennlp/training/no_op_trainer.py

+38
@@ -0,0 +1,38 @@
import os
from typing import Dict, Any

from allennlp.common import Params
from allennlp.models import Model
from allennlp.training.checkpointer import Checkpointer
from allennlp.training.trainer import TrainerPieces
from allennlp.training.trainer_base import TrainerBase


@TrainerBase.register("no_op")
class NoOpTrainer(TrainerBase):
    def __init__(self, serialization_dir: str, model: Model) -> None:
        """
        A trivial trainer to assist in making model archives for models that do not actually
        require training. For instance, a majority class baseline.
        """

        super().__init__(serialization_dir, cuda_device=-1)
        self.model = model

    @classmethod
    def from_params(cls,  # type: ignore
                    params: Params,
                    serialization_dir: str,
                    recover: bool = False):
        # pylint: disable=arguments-differ
        pieces = TrainerPieces.from_params(params, serialization_dir, recover)  # pylint: disable=no-member
        return NoOpTrainer(serialization_dir, pieces.model)

    def train(self) -> Dict[str, Any]:
        self.model.vocab.save_to_files(os.path.join(self._serialization_dir, "vocabulary"))

        checkpointer = Checkpointer(self._serialization_dir)
        checkpointer.save_checkpoint(epoch=0,
                                     model_state=self.model.state_dict(),
                                     training_states={},
                                     is_best_so_far=True)
        return {}
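As a quick illustration, not part of the diff, of how the registration above is consumed: the @TrainerBase.register("no_op") decorator adds the class to the TrainerBase registry, which is exactly the lookup TrainerBase.from_params performs on the config's "type" key (see the trainer_base.py change below).

    from allennlp.training import NoOpTrainer
    from allennlp.training.trainer_base import TrainerBase

    # The decorator registered NoOpTrainer under the name "no_op", so
    # config-driven construction can resolve it by name:
    assert TrainerBase.by_name("no_op") is NoOpTrainer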

allennlp/training/trainer_base.py

+5-1
@@ -79,4 +79,8 @@ def from_params(cls,  # type: ignore
                                        params=pieces.params,
                                        validation_iterator=pieces.validation_iterator)
         else:
-            return TrainerBase.by_name(typ3).from_params(params, serialization_dir, recover)
+            klass = TrainerBase.by_name(typ3)
+            # Explicit check to prevent recursion.
+            is_overriden = klass.from_params.__func__ != TrainerBase.from_params.__func__  # type: ignore
+            assert is_overriden, f"Class {klass.__name__} must override `from_params`."
+            return klass.from_params(params, serialization_dir, recover)
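For illustration, a self-contained toy sketch (not part of the diff) of the failure mode this guard prevents: a registered subclass that forgets to override from_params would otherwise resolve back to the inherited TrainerBase.from_params and recurse forever.

    class Base:
        @classmethod
        def from_params(cls):
            # Mirrors the guard above: compare the underlying functions to see
            # whether `cls` actually supplied its own override.
            is_overridden = cls.from_params.__func__ != Base.from_params.__func__
            assert is_overridden, f"Class {cls.__name__} must override `from_params`."
            return cls.from_params()       # would recurse without the assert

    class Good(Base):
        @classmethod
        def from_params(cls):
            return cls()

    class Bad(Base):                       # inherits Base.from_params unchanged
        pass

    Good.from_params()                     # fine: dispatches to the override
    try:
        Bad.from_params()                  # assert fires instead of infinite recursion
    except AssertionError as err:
        print(err)                         # Class Bad must override `from_params`.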
@@ -0,0 +1,7 @@
allennlp.training.no_op_trainer
======================================

.. automodule:: allennlp.training.no_op_trainer
   :members:
   :undoc-members:
   :show-inheritance:

doc/api/allennlp.training.rst

+1
@@ -13,6 +13,7 @@ for training AllenNLP models.
    allennlp.training.metric_tracker
    allennlp.training.metrics
    allennlp.training.moving_average
+   allennlp.training.no_op_trainer
    allennlp.training.optimizers
    allennlp.training.tensorboard_writer
    allennlp.training.trainer
