
Commit d7c06fe

mulhod and epwalsh authored
Expose from_pretrained keyword arguments (#4651)
* Add ability to pass through transformers cache-related kwargs such as cache_dir and local_files_only
* Add a couple tests for cached_transformers
* Update CHANGELOG
* Fix formatting
* Apply suggestions
* Add/fix tokenizers_kwargs/transformer_kwargs in a few places; add documentation wherever it occurs
* Update CHANGELOG.md
  Co-authored-by: Evan Pete Walsh <[email protected]>
* Update bert_pooler.py transformer_kwargs documentation
* Apply suggestions from code review
  Co-authored-by: Evan Pete Walsh <[email protected]>
* Remove test_from_pretrained_kwargs_local_files_only_missing_from_cache test
* Use AllenNlpTestCase in cached_transformers_test.py

Co-authored-by: Evan Pete Walsh <[email protected]>
1 parent 175c76b commit d7c06fe

File tree

9 files changed: +112 −23 lines changed

CHANGELOG.md  (+2)

@@ -17,6 +17,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
   by adding a class-level variable called `authorized_missing_keys` to any PyTorch module that a `Model` uses.
   If defined, `authorized_missing_keys` should be a list of regex string patterns.
 - Added `FBetaMultiLabelMeasure`, a multi-label Fbeta metric. This is a subclass of the existing `FBetaMeasure`.
+- Added ability to pass additional key word arguments to `cached_transformers.get()`, which will be passed on to `AutoModel.from_pretrained()`.
+- Added an `overrides` argument to `Predictor.from_path()`.

 ### Changed


allennlp/common/cached_transformers.py  (+14 −3)

@@ -21,6 +21,7 @@ def get(
     make_copy: bool,
     override_weights_file: Optional[str] = None,
     override_weights_strip_prefix: Optional[str] = None,
+    **kwargs,
 ) -> transformers.PreTrainedModel:
     """
     Returns a transformer model from the cache.
@@ -74,9 +75,16 @@ def strip_prefix(s):
                     )
                 override_weights = {strip_prefix(k): override_weights[k] for k in valid_keys}

-            transformer = AutoModel.from_pretrained(model_name, state_dict=override_weights)
+            transformer = AutoModel.from_pretrained(
+                model_name,
+                state_dict=override_weights,
+                **kwargs,
+            )
         else:
-            transformer = AutoModel.from_pretrained(model_name)
+            transformer = AutoModel.from_pretrained(
+                model_name,
+                **kwargs,
+            )
         _model_cache[spec] = transformer
     if make_copy:
         import copy
@@ -95,6 +103,9 @@ def get_tokenizer(model_name: str, **kwargs) -> transformers.PreTrainedTokenizer
     global _tokenizer_cache
     tokenizer = _tokenizer_cache.get(cache_key, None)
     if tokenizer is None:
-        tokenizer = transformers.AutoTokenizer.from_pretrained(model_name, **kwargs)
+        tokenizer = transformers.AutoTokenizer.from_pretrained(
+            model_name,
+            **kwargs,
+        )
         _tokenizer_cache[cache_key] = tokenizer
     return tokenizer
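
For reference, a minimal usage sketch of the forwarded kwargs after this change (not part of the commit); the model name and cache directory are placeholder values:

```python
from allennlp.common import cached_transformers

# Extra keyword arguments are now forwarded to AutoModel.from_pretrained(...).
model = cached_transformers.get(
    "bert-base-uncased",        # placeholder model name
    make_copy=False,
    cache_dir="/tmp/hf-cache",  # placeholder cache directory
    local_files_only=False,
)

# get_tokenizer() already accepted **kwargs; they are forwarded to
# AutoTokenizer.from_pretrained(...).
tokenizer = cached_transformers.get_tokenizer(
    "bert-base-uncased",
    cache_dir="/tmp/hf-cache",
)
```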

allennlp/data/token_indexers/pretrained_transformer_indexer.py  (+15 −4)

@@ -1,4 +1,4 @@
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple, Any
 import logging
 import torch
 from allennlp.common.util import pad_sequence_to_length
@@ -38,14 +38,25 @@ class PretrainedTransformerIndexer(TokenIndexer):
         before feeding into the embedder. The embedder embeds these segments independently and
         concatenate the results to get the original document representation. Should be set to
         the same value as the `max_length` option on the `PretrainedTransformerEmbedder`.
-    """
+    tokenizer_kwargs : `Dict[str, Any]`, optional (default = `None`)
+        Dictionary with
+        [additional arguments](https://github.com/huggingface/transformers/blob/155c782a2ccd103cf63ad48a2becd7c76a7d2115/transformers/tokenization_utils.py#L691)
+        for `AutoTokenizer.from_pretrained`.
+    """  # noqa: E501

     def __init__(
-        self, model_name: str, namespace: str = "tags", max_length: int = None, **kwargs
+        self,
+        model_name: str,
+        namespace: str = "tags",
+        max_length: int = None,
+        tokenizer_kwargs: Optional[Dict[str, Any]] = None,
+        **kwargs,
     ) -> None:
         super().__init__(**kwargs)
         self._namespace = namespace
-        self._allennlp_tokenizer = PretrainedTransformerTokenizer(model_name)
+        self._allennlp_tokenizer = PretrainedTransformerTokenizer(
+            model_name, tokenizer_kwargs=tokenizer_kwargs
+        )
         self._tokenizer = self._allennlp_tokenizer.tokenizer
         self._added_to_vocabulary = False
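
A hedged sketch of the new `tokenizer_kwargs` parameter on the indexer; the model name and the kwargs values are placeholders:

```python
from allennlp.data.token_indexers import PretrainedTransformerIndexer

# tokenizer_kwargs is handed to the underlying PretrainedTransformerTokenizer,
# which passes it on to AutoTokenizer.from_pretrained.
indexer = PretrainedTransformerIndexer(
    model_name="bert-base-uncased",
    tokenizer_kwargs={"cache_dir": "/tmp/hf-cache", "local_files_only": False},
)
```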

allennlp/data/tokenizers/pretrained_transformer_tokenizer.py  (+1 −2)

@@ -51,11 +51,10 @@ class PretrainedTransformerTokenizer(Tokenizer):
         - 'only_first': Only truncate the first sequence
         - 'only_second': Only truncate the second sequence
         - 'do_not_truncate': Do not truncate (raise an error if the input sequence is longer than max_length)
-    tokenizer_kwargs: `Dict[str, Any]`
+    tokenizer_kwargs: `Dict[str, Any]`, optional (default = `None`)
         Dictionary with
         [additional arguments](https://github.com/huggingface/transformers/blob/155c782a2ccd103cf63ad48a2becd7c76a7d2115/transformers/tokenization_utils.py#L691)
         for `AutoTokenizer.from_pretrained`.
-
     """  # noqa: E501

     def __init__(
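
Correspondingly, a small sketch of passing `tokenizer_kwargs` directly to the tokenizer wrapper; the values are illustrative only:

```python
from allennlp.data.tokenizers import PretrainedTransformerTokenizer

# Per the docstring above, tokenizer_kwargs becomes extra arguments to
# AutoTokenizer.from_pretrained.
tokenizer = PretrainedTransformerTokenizer(
    "bert-base-uncased",
    tokenizer_kwargs={"cache_dir": "/tmp/hf-cache"},
)
tokens = tokenizer.tokenize("AllenNLP caches transformers.")
```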

allennlp/models/archival.py  (+3 −2)

@@ -8,6 +8,7 @@
 import tempfile
 import tarfile
 import shutil
+from pathlib import Path

 from torch.nn import Module

@@ -129,7 +130,7 @@ def archive_model(


 def load_archive(
-    archive_file: str,
+    archive_file: Union[str, Path],
     cuda_device: int = -1,
     overrides: str = "",
     weights_file: str = None,
@@ -139,7 +140,7 @@ def load_archive(

     # Parameters

-    archive_file : `str`
+    archive_file : `Union[str, Path]`
         The archive file to load the model from.
     cuda_device : `int`, optional (default = `-1`)
         If `cuda_device` is >= 0, the model will be loaded onto the
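
A quick sketch of the widened `archive_file` type; the archive path below is a placeholder:

```python
from pathlib import Path

from allennlp.models.archival import load_archive

# load_archive now accepts a pathlib.Path as well as a str.
archive = load_archive(Path("/path/to/model.tar.gz"), cuda_device=-1)
model = archive.model
```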

allennlp/modules/seq2vec_encoders/bert_pooler.py  (+13 −4)

@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Optional, Dict, Any

 from overrides import overrides

@@ -31,7 +31,11 @@ class BertPooler(Seq2VecEncoder):
         Otherwise they will not.
     dropout : `float`, optional, (default = `0.0`)
         Amount of dropout to apply after pooling
-    """
+    transformer_kwargs: `Dict[str, Any]`, optional (default = `None`)
+        Dictionary with
+        [additional arguments](https://github.com/huggingface/transformers/blob/155c782a2ccd103cf63ad48a2becd7c76a7d2115/transformers/modeling_utils.py#L253)
+        for `AutoModel.from_pretrained`.
+    """  # noqa: E501

     def __init__(
         self,
@@ -40,14 +44,19 @@ def __init__(
         override_weights_file: Optional[str] = None,
         override_weights_strip_prefix: Optional[str] = None,
         requires_grad: bool = True,
-        dropout: float = 0.0
+        dropout: float = 0.0,
+        transformer_kwargs: Optional[Dict[str, Any]] = None,
     ) -> None:
         super().__init__()

         from allennlp.common import cached_transformers

         model = cached_transformers.get(
-            pretrained_model, False, override_weights_file, override_weights_strip_prefix
+            pretrained_model,
+            False,
+            override_weights_file,
+            override_weights_strip_prefix,
+            **(transformer_kwargs or {}),
         )

         self._dropout = torch.nn.Dropout(p=dropout)
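
An illustrative sketch of `BertPooler` with the new `transformer_kwargs` parameter; the model name and kwargs are placeholders:

```python
from allennlp.modules.seq2vec_encoders import BertPooler

# transformer_kwargs is unpacked into cached_transformers.get(), which forwards
# it to AutoModel.from_pretrained.
pooler = BertPooler(
    "bert-base-uncased",
    dropout=0.1,
    transformer_kwargs={"cache_dir": "/tmp/hf-cache"},
)
```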

allennlp/modules/token_embedders/pretrained_transformer_embedder.py  (+21 −4)

@@ -1,5 +1,5 @@
 import math
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Dict, Any

 from overrides import overrides

@@ -42,7 +42,15 @@ class PretrainedTransformerEmbedder(TokenEmbedder):
         is used.
     gradient_checkpointing: `bool`, optional (default = `None`)
         Enable or disable gradient checkpointing.
-    """
+    tokenizer_kwargs: `Dict[str, Any]`, optional (default = `None`)
+        Dictionary with
+        [additional arguments](https://github.com/huggingface/transformers/blob/155c782a2ccd103cf63ad48a2becd7c76a7d2115/transformers/tokenization_utils.py#L691)
+        for `AutoTokenizer.from_pretrained`.
+    transformer_kwargs: `Dict[str, Any]`, optional (default = `None`)
+        Dictionary with
+        [additional arguments](https://github.com/huggingface/transformers/blob/155c782a2ccd103cf63ad48a2becd7c76a7d2115/transformers/modeling_utils.py#L253)
+        for `AutoModel.from_pretrained`.
+    """  # noqa: E501

     authorized_missing_keys = [r"position_ids$"]

@@ -57,12 +65,18 @@ def __init__(
         override_weights_file: Optional[str] = None,
         override_weights_strip_prefix: Optional[str] = None,
         gradient_checkpointing: Optional[bool] = None,
+        tokenizer_kwargs: Optional[Dict[str, Any]] = None,
+        transformer_kwargs: Optional[Dict[str, Any]] = None,
     ) -> None:
         super().__init__()
         from allennlp.common import cached_transformers

         self.transformer_model = cached_transformers.get(
-            model_name, True, override_weights_file, override_weights_strip_prefix
+            model_name,
+            True,
+            override_weights_file=override_weights_file,
+            override_weights_strip_prefix=override_weights_strip_prefix,
+            **(transformer_kwargs or {}),
         )

         if gradient_checkpointing is not None:
@@ -83,7 +97,10 @@ def __init__(
             self._scalar_mix = ScalarMix(self.config.num_hidden_layers)
             self.config.output_hidden_states = True

-        tokenizer = PretrainedTransformerTokenizer(model_name)
+        tokenizer = PretrainedTransformerTokenizer(
+            model_name,
+            tokenizer_kwargs=tokenizer_kwargs,
+        )
         self._num_added_start_tokens = len(tokenizer.single_sequence_start_tokens)
         self._num_added_end_tokens = len(tokenizer.single_sequence_end_tokens)
         self._num_added_tokens = self._num_added_start_tokens + self._num_added_end_tokens
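
A sketch showing both new dictionaries on the embedder; the keys shown are just examples of arguments the respective `from_pretrained` calls accept:

```python
from allennlp.modules.token_embedders import PretrainedTransformerEmbedder

embedder = PretrainedTransformerEmbedder(
    model_name="bert-base-uncased",
    # Forwarded to AutoTokenizer.from_pretrained via PretrainedTransformerTokenizer.
    tokenizer_kwargs={"cache_dir": "/tmp/hf-cache"},
    # Forwarded to AutoModel.from_pretrained via cached_transformers.get.
    transformer_kwargs={"cache_dir": "/tmp/hf-cache", "local_files_only": False},
)
```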

allennlp/predictors/predictor.py  (+8 −4)

@@ -1,7 +1,8 @@
-from typing import List, Iterator, Dict, Tuple, Any, Type
+from typing import List, Iterator, Dict, Tuple, Any, Type, Union
 import json
 import re
 from contextlib import contextmanager
+from pathlib import Path

 import numpy
 from torch.utils.hooks import RemovableHandle
@@ -232,12 +233,13 @@ def _batch_json_to_instances(self, json_dicts: List[JsonDict]) -> List[Instance]
     @classmethod
     def from_path(
         cls,
-        archive_path: str,
+        archive_path: Union[str, Path],
         predictor_name: str = None,
         cuda_device: int = -1,
         dataset_reader_to_load: str = "validation",
         frozen: bool = True,
         import_plugins: bool = True,
+        overrides: str = "",
     ) -> "Predictor":
         """
         Instantiate a `Predictor` from an archive path.
@@ -247,7 +249,7 @@ def from_path(

         # Parameters

-        archive_path : `str`
+        archive_path : `Union[str, Path]`
             The path to the archive.
         predictor_name : `str`, optional (default=`None`)
             Name that the predictor is registered as, or None to use the
@@ -265,6 +267,8 @@ def from_path(
             This comes with additional overhead, but means you don't need to explicitly
             import the modules that your predictor depends on as long as those modules
             can be found by `allennlp.common.plugins.import_plugins()`.
+        overrides : `str`, optional (default = `""`)
+            JSON overrides to apply to the unarchived `Params` object.

         # Returns
@@ -274,7 +278,7 @@ def from_path(
         if import_plugins:
             plugins.import_plugins()
         return Predictor.from_archive(
-            load_archive(archive_path, cuda_device=cuda_device),
+            load_archive(archive_path, cuda_device=cuda_device, overrides=overrides),
             predictor_name,
             dataset_reader_to_load=dataset_reader_to_load,
             frozen=frozen,
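
A short sketch of the new `overrides` argument; the archive path and the override key are placeholders:

```python
from allennlp.predictors import Predictor

# overrides is a JSON string applied to the archived Params before the model
# is loaded (the same semantics as load_archive's overrides argument).
predictor = Predictor.from_path(
    "/path/to/model.tar.gz",
    overrides='{"dataset_reader.lazy": true}',
)
```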

cached_transformers_test.py  (new file, +35)

@@ -0,0 +1,35 @@
+import pytest
+
+from allennlp.common import cached_transformers
+from allennlp.common.testing import AllenNlpTestCase
+
+
+class TestCachedTransformers(AllenNlpTestCase):
+    def test_get_missing_from_cache_local_files_only(self):
+        with pytest.raises(ValueError) as execinfo:
+            cached_transformers.get(
+                "bert-base-uncased",
+                True,
+                cache_dir=self.TEST_DIR,
+                local_files_only=True,
+            )
+        assert str(execinfo.value) == (
+            "Cannot find the requested files in the cached path and "
+            "outgoing traffic has been disabled. To enable model "
+            "look-ups and downloads online, set 'local_files_only' "
+            "to False."
+        )
+
+    def test_get_tokenizer_missing_from_cache_local_files_only(self):
+        with pytest.raises(ValueError) as execinfo:
+            cached_transformers.get_tokenizer(
+                "bert-base-uncased",
+                cache_dir=self.TEST_DIR,
+                local_files_only=True,
+            )
+        assert str(execinfo.value) == (
+            "Cannot find the requested files in the cached path and "
+            "outgoing traffic has been disabled. To enable model "
+            "look-ups and downloads online, set 'local_files_only' "
+            "to False."
+        )
