This repository was archived by the owner on Dec 16, 2022. It is now read-only.

Registrable _to_params default functionality #5403

Merged
merged 11 commits on Oct 8, 2021
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -27,6 +27,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

- Added a default implementation of the `_to_params` method to `Registrable` so that, when a child class does not implement it, calling it still produces _a parameter dictionary_ (see the short sketch below).
- Added `_to_params` implementations to all tokenizers.
- Added support for evaluating multiple datasets and producing corresponding output files in the `evaluate` command.
- Added more documentation to the learning rate schedulers to include a sample config object for how to use it.
- Moved the pytorch learning rate schedulers wrappers to their own file called `pytorch_lr_schedulers.py` so that they will have their own documentation page.
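The following is a minimal sketch of that default behavior, not part of the diff; `MyBase` and `MyThing` are hypothetical names used only for illustration. A registered subclass that never defines `_to_params` still serializes to at least a `{"type": ...}` dictionary (and logs a warning that the fallback was used):

from allennlp.common import Registrable


class MyBase(Registrable):
    """Hypothetical base class, only for this sketch."""


@MyBase.register("my-thing")
class MyThing(MyBase):
    # No _to_params here, so Registrable's new default implementation runs.
    pass


params = MyThing().to_params()
print(params.params)  # {'type': 'my-thing'}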
87 changes: 86 additions & 1 deletion allennlp/common/registrable.py
@@ -5,8 +5,21 @@
"""
import importlib
import logging
import inspect
from collections import defaultdict
from typing import Callable, ClassVar, DefaultDict, Dict, List, Optional, Tuple, Type, TypeVar, cast
from typing import (
Callable,
ClassVar,
DefaultDict,
Dict,
List,
Optional,
Tuple,
Type,
TypeVar,
cast,
Any,
)

from allennlp.common.checks import ConfigurationError
from allennlp.common.from_params import FromParams
@@ -218,6 +231,78 @@ def list_available(cls) -> List[str]:
else:
return [default] + [k for k in keys if k != default]

def _to_params(self) -> Dict[str, Any]:
"""
Default behavior for getting a params dictionary from a registrable class
that does NOT have its own `_to_params` implementation. Using this method
directly is NOT recommended; it is a minimal implementation that exists
so that calling `_to_params` does not break.

# Returns

parameter_dict: `Dict[str, Any]`
A minimal parameter dictionary for the registrable class. It contains
the registered name as well as any positional constructor arguments
whose values could be found on the instance.

"""
logger.warning(
f"'{self.__class__.__name__}' does not implement '_to_params`. Will"
f" use Registrable's `_to_params`."
)

# Get the list of parent classes in the MRO in order to check where to
# look for the registered name. Skip the first because that is the
# current class.
mro = inspect.getmro(self.__class__)[1:]

registered_name = None
for parent in mro:
# Check if Parent has any registered classes
try:
registered_classes = self._registry[parent]
except KeyError:
continue

# Found a dict of (name,(class,constructor)) pairs. Check if the
# current class is in it.
for name, registered_value in registered_classes.items():
registered_class, _ = registered_value
if registered_class == self.__class__:
registered_name = name
break

# Extra break to end the top loop.
if registered_name is not None:
break

if registered_name is None:
raise KeyError(f"'{self.__class__.__name__}' is not registered")

parameter_dict = {"type": registered_name}

# Get the parameters from the init function.
for parameter in inspect.signature(self.__class__).parameters.values():
# Skip non-positional arguments. For simplicity, these are arguments
# without a default value as those will be required for the
# `from_params` method.
if parameter.default != inspect.Parameter.empty:
logger.debug(f"Skipping parameter {parameter.name}")
continue

# Try to get the value of the parameter from the instance. Only the
# attributes 'name' and '_name' are tried. If neither exists, the
# parameter is not added to the returned dict.
if hasattr(self, parameter.name):
parameter_dict[parameter.name] = getattr(self, parameter.name)
elif hasattr(self, f"_{parameter.name}"):
parameter_dict[parameter.name] = getattr(self, f"_{parameter.name}")
else:
logger.warning(f"Could not find a value for positional argument {parameter.name}")
continue

return parameter_dict


def _get_suggestion(name: str, available: List[str]) -> Optional[str]:
# First check for simple mistakes like using '-' instead of '_', or vice-versa.
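As a concrete sketch of the positional-argument handling above (hypothetical class names `Extractor` and `Simple`, invented only for illustration): the default `_to_params` keeps a constructor argument only when it has no default value and the instance exposes its value as either `name` or `_name`, mirroring the `PosArguments` test added further down.

from allennlp.common import Registrable


class Extractor(Registrable):
    """Hypothetical base class, only for this sketch."""


@Extractor.register("simple")
class Simple(Extractor):
    def __init__(self, width: int, label: str, verbose: bool = False):
        self._width = width     # no default; found through the '_width' fallback
        self.label = label      # no default; found directly as 'label'
        self.verbose = verbose  # has a default, so the default _to_params skips it


params = Simple(3, "a").to_params().params
# {'type': 'simple', 'width': 3, 'label': 'a'}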
11 changes: 10 additions & 1 deletion allennlp/data/tokenizers/character_tokenizer.py
@@ -1,4 +1,4 @@
from typing import List, Union
from typing import List, Union, Dict, Any

from overrides import overrides

@@ -83,3 +83,12 @@ def __eq__(self, other) -> bool:
if isinstance(self, other.__class__):
return self.__dict__ == other.__dict__
return NotImplemented

def _to_params(self) -> Dict[str, Any]:
return {
"type": "character",
"byte_encoding": self._byte_encoding,
"lowercase_characters": self._lowercase_characters,
"start_tokens": self._start_tokens,
"end_tokens": self._end_tokens,
}
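A short round-trip sketch for the params above (not part of the diff; it assumes the usual `Tokenizer.from_params` registry lookup): the returned dictionary is enough to rebuild an equivalent tokenizer.

from allennlp.data.tokenizers import CharacterTokenizer, Tokenizer

tokenizer = CharacterTokenizer(lowercase_characters=True, start_tokens=["@"], end_tokens=["#"])

# "type": "character" routes from_params back to CharacterTokenizer.
rebuilt = Tokenizer.from_params(tokenizer.to_params())

assert isinstance(rebuilt, CharacterTokenizer)
assert rebuilt == tokenizer  # CharacterTokenizer compares instances via __dict__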
16 changes: 14 additions & 2 deletions allennlp/data/tokenizers/pretrained_transformer_tokenizer.py
@@ -58,13 +58,16 @@ def __init__(
tokenizer_kwargs = {}
else:
tokenizer_kwargs = tokenizer_kwargs.copy()
tokenizer_kwargs.setdefault("use_fast", True)
# Note: Just because we request a fast tokenizer doesn't mean we get one.
tokenizer_kwargs.setdefault("use_fast", True)

self._tokenizer_kwargs = tokenizer_kwargs
self._model_name = model_name

from allennlp.common import cached_transformers

self.tokenizer = cached_transformers.get_tokenizer(
model_name, add_special_tokens=False, **tokenizer_kwargs
self._model_name, add_special_tokens=False, **self._tokenizer_kwargs
)

self._add_special_tokens = add_special_tokens
@@ -452,3 +455,12 @@ def num_special_tokens_for_pair(self) -> int:
+ len(self.sequence_pair_mid_tokens)
+ len(self.sequence_pair_end_tokens)
)

def _to_params(self) -> Dict[str, Any]:
return {
"type": "pretrained_transformer",
"model_name": self._model_name,
"add_special_tokens": self._add_special_tokens,
"max_length": self._max_length,
"tokenizer_kwargs": self._tokenizer_kwargs,
}
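A brief usage sketch for these params (my own illustration rather than part of the diff; it needs to download `bert-base-uncased` the first time it runs): because `__init__` saves `tokenizer_kwargs` after filling in the `use_fast` default, the serialized params capture that default and can recreate an equivalent tokenizer through the registry.

from allennlp.data.tokenizers import PretrainedTransformerTokenizer, Tokenizer

tokenizer = PretrainedTransformerTokenizer("bert-base-uncased", max_length=512)

params = tokenizer.to_params()
assert params.params["tokenizer_kwargs"] == {"use_fast": True}
assert params.params["max_length"] == 512

# "type": "pretrained_transformer" routes from_params back to this class.
rebuilt = Tokenizer.from_params(params)
assert isinstance(rebuilt, PretrainedTransformerTokenizer)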
10 changes: 8 additions & 2 deletions allennlp/data/tokenizers/sentence_splitter.py
@@ -1,4 +1,4 @@
from typing import List
from typing import List, Dict, Any
from overrides import overrides

import spacy
@@ -44,8 +44,11 @@ class SpacySentenceSplitter(SentenceSplitter):
"""

def __init__(self, language: str = "en_core_web_sm", rule_based: bool = False) -> None:
self._language = language
self._rule_based = rule_based

# we need spacy's dependency parser if we're not using rule-based sentence boundary detection.
self.spacy = get_spacy_model(language, parse=not rule_based, ner=False)
self.spacy = get_spacy_model(self._language, parse=not self._rule_based, ner=False)
self._is_version_3 = spacy.__version__ >= "3.0"
if rule_based:
# we use `sentencizer`, a built-in spacy module for rule-based sentence boundary detection.
@@ -77,3 +80,6 @@ def batch_split_sentences(self, texts: List[str]) -> List[List[str]]:
return [
[sentence.string.strip() for sentence in doc.sents] for doc in self.spacy.pipe(texts)
]

def _to_params(self) -> Dict[str, Any]:
return {"type": "spacy", "language": self._language, "rule_based": self._rule_based}
25 changes: 23 additions & 2 deletions allennlp/data/tokenizers/spacy_tokenizer.py
@@ -60,8 +60,16 @@
start_tokens: Optional[List[str]] = None,
end_tokens: Optional[List[str]] = None,
) -> None:
self.spacy = get_spacy_model(language, pos_tags, parse, ner)
if split_on_spaces:
# Save these for use later in the _to_params method
self._language = language
self._pos_tags = pos_tags
self._parse = parse
self._ner = ner
self._split_on_spaces = split_on_spaces

self.spacy = get_spacy_model(self._language, self._pos_tags, self._parse, self._ner)

if self._split_on_spaces:
self.spacy.tokenizer = _WhitespaceSpacyTokenizer(self.spacy.vocab)

self._keep_spacy_tokens = keep_spacy_tokens
@@ -115,6 +123,19 @@ def tokenize(self, text: str) -> List[Token]:
# This works because our Token class matches spacy's.
return self._sanitize(_remove_spaces(self.spacy(text)))

def _to_params(self):
return {
"type": "spacy",
"language": self._language,
"pos_tags": self._pos_tags,
"parse": self._parse,
"ner": self._ner,
"keep_spacy_tokens": self._keep_spacy_tokens,
"split_on_spaces": self._split_on_spaces,
"start_tokens": self._start_tokens,
"end_tokens": self._end_tokens,
}


class _WhitespaceSpacyTokenizer:
"""
5 changes: 4 additions & 1 deletion allennlp/data/tokenizers/whitespace_tokenizer.py
@@ -1,4 +1,4 @@
from typing import List
from typing import List, Dict, Any

from overrides import overrides

@@ -23,3 +23,6 @@ class WhitespaceTokenizer(Tokenizer):
@overrides
def tokenize(self, text: str) -> List[Token]:
return [Token(t) for t in text.split()]

def _to_params(self) -> Dict[str, Any]:
return {"type": "whitespace"}
69 changes: 66 additions & 3 deletions tests/common/registrable_test.py
@@ -1,5 +1,6 @@
import inspect
import os
from typing import List

import pytest

@@ -15,6 +16,14 @@
from allennlp.nn.regularizers.regularizer import Regularizer


@pytest.fixture()
def empty_registrable():
class EmptyRegistrable(Registrable):
pass

yield EmptyRegistrable


class TestRegistrable(AllenNlpTestCase):
def test_registrable_functionality_works(self):
# This function tests the basic `Registrable` functionality:
@@ -33,7 +42,6 @@ def test_registrable_functionality_works(self):

@base_class.register("fake")
class Fake(base_class):

pass

assert base_class.by_name("fake") == Fake
@@ -55,14 +63,12 @@ class Fake(base_class):

@base_class.register("fake")
class FakeAlternate(base_class):

pass

# Registering under a name that already exists should overwrite
# if exist_ok=True.
@base_class.register("fake", exist_ok=True) # noqa
class FakeAlternate2(base_class):

pass

assert base_class.by_name("fake") == FakeAlternate2
@@ -131,6 +137,63 @@ def test_implicit_include_package(self):
)
assert duplicate_reader.__name__ == "TextClassificationJsonReader"

def test_to_params_no_arguments(self, empty_registrable):
# Test the default `_to_params` when the class defines neither an init
# function nor any arguments.
@empty_registrable.register("no-args")
class NoArguments(empty_registrable):
pass

obj = NoArguments()
assert obj.to_params().params == {"type": "no-args"}

def test_to_params_no_pos_arguments(self, empty_registrable):
# Test the default `_to_params` when there is an init function but no
# positional arguments.
@empty_registrable.register("no-pos-args")
class NoPosArguments(empty_registrable):
def __init__(self, A: bool = None):
self.A = A

obj = NoPosArguments()
assert obj.to_params().params == {"type": "no-pos-args"}

def test_to_params_pos_arguments(self, empty_registrable):
# Test the default `_to_params` when there is an init function with
# positional arguments.
@empty_registrable.register("pos-args")
class PosArguments(empty_registrable):
def __init__(self, A: bool, B: int, C: List):
self.A = A
self._B = B
self._msg = C

obj = PosArguments(False, 5, [])
assert obj.to_params().params == {"type": "pos-args", "A": False, "B": 5}

def test_to_params_not_registered(self, empty_registrable):
# Test that Registrable raises an exception when the class being
# serialized is not registered.
class NotRegistered(empty_registrable):
pass

obj = NotRegistered()
with pytest.raises(KeyError):
obj.to_params()

def test_to_params_nested(self, empty_registrable):
# Test the default `_to_params` when there are nested registrables.
class NestedBase(empty_registrable):
pass

@NestedBase.register("nested")
class NestedClass(NestedBase):
pass

obj = NestedClass()
assert obj.to_params().params == {"type": "nested"}


@pytest.mark.parametrize(
"name",
13 changes: 13 additions & 0 deletions tests/data/tokenizers/character_tokenizer_test.py
@@ -1,3 +1,4 @@
from allennlp.common import Params
from allennlp.common.testing import AllenNlpTestCase
from allennlp.data.tokenizers import CharacterTokenizer

@@ -55,3 +56,15 @@ def test_handles_byte_encoding(self):
# Note that we've added one to the utf-8 encoded bytes, to account for masking.
expected_tokens = [259, 196, 166, 196, 185, 196, 163, 196, 162, 98, 99, 102, 260]
assert tokens == expected_tokens

def test_to_params(self):
tokenizer = CharacterTokenizer(byte_encoding="utf-8", start_tokens=[259], end_tokens=[260])
params = tokenizer.to_params()
assert isinstance(params, Params)
assert params.params == {
"type": "character",
"byte_encoding": "utf-8",
"end_tokens": [260],
"start_tokens": [259],
"lowercase_characters": False,
}
14 changes: 14 additions & 0 deletions tests/data/tokenizers/pretrained_transformer_tokenizer_test.py
@@ -327,3 +327,17 @@ def test_from_params_kwargs(self):
PretrainedTransformerTokenizer.from_params(
Params({"model_name": "bert-base-uncased", "tokenizer_kwargs": {"max_len": 10}})
)

def test_to_params(self):
tokenizer = PretrainedTransformerTokenizer.from_params(
Params({"model_name": "bert-base-uncased", "tokenizer_kwargs": {"max_len": 10}})
)
params = tokenizer.to_params()
assert isinstance(params, Params)
assert params.params == {
"type": "pretrained_transformer",
"model_name": "bert-base-uncased",
"add_special_tokens": True,
"max_length": None,
"tokenizer_kwargs": {"max_len": 10, "use_fast": True},
}