This repository was archived by the owner on Dec 16, 2022. It is now read-only.

Commit 9dabf3f

mulhod and epwalsh authored
Add missing tokenizer/transformer kwargs (#4682)
* Add tokenizer_kwargs in PretrainedTransformerMismatchedIndexer and tokenizer_kwargs/transformer_kwargs in PretrainedTransformerMismatchedEmbedder

* Update allennlp/data/token_indexers/pretrained_transformer_mismatched_indexer.py

  Co-authored-by: Evan Pete Walsh <[email protected]>

* Update CHANGELOG.md

  Co-authored-by: Evan Pete Walsh <[email protected]>

Co-authored-by: Evan Pete Walsh <[email protected]>
1 parent 9ac6c76 commit 9dabf3f

File tree

- CHANGELOG.md
- allennlp/data/token_indexers/pretrained_transformer_mismatched_indexer.py
- allennlp/modules/token_embedders/pretrained_transformer_mismatched_embedder.py

3 files changed: +33 -6 lines changed

CHANGELOG.md

+2

@@ -9,6 +9,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

 ### Added

+- Added `tokenizer_kwargs` argument to `PretrainedTransformerMismatchedIndexer`.
+- Added `tokenizer_kwargs` and `transformer_kwargs` arguments to `PretrainedTransformerMismatchedEmbedder`.
 - Added official support for Python 3.8.
 - Added a script: `scripts/release_notes.py`, which automatically prepares markdown release notes from the
   CHANGELOG and commit history.
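
For context, here is a minimal sketch (not part of the commit) of how the new `tokenizer_kwargs` entry listed in the CHANGELOG could be supplied through a config-style `Params` dict; the registered name `"pretrained_transformer_mismatched"`, the model name, and the `do_lower_case` flag are illustrative assumptions based on the library's conventions.

```python
# Minimal sketch (not from this commit): building the indexer from a Params dict,
# assuming the registered name "pretrained_transformer_mismatched".
from allennlp.common import Params
from allennlp.data.token_indexers import TokenIndexer

indexer = TokenIndexer.from_params(Params({
    "type": "pretrained_transformer_mismatched",
    "model_name": "bert-base-cased",                 # illustrative model name
    "tokenizer_kwargs": {"do_lower_case": False},    # forwarded to AutoTokenizer.from_pretrained
}))
```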

allennlp/data/token_indexers/pretrained_transformer_mismatched_indexer.py

+17 -4

@@ -1,4 +1,4 @@
-from typing import Dict, List
+from typing import Dict, List, Any, Optional
 import logging

 from overrides import overrides
@@ -39,15 +39,28 @@ class PretrainedTransformerMismatchedIndexer(TokenIndexer):
         before feeding into the embedder. The embedder embeds these segments independently and
         concatenate the results to get the original document representation. Should be set to
         the same value as the `max_length` option on the `PretrainedTransformerMismatchedEmbedder`.
-    """
+    tokenizer_kwargs : `Dict[str, Any]`, optional (default = `None`)
+        Dictionary with
+        [additional arguments](https://github.com/huggingface/transformers/blob/155c782a2ccd103cf63ad48a2becd7c76a7d2115/transformers/tokenization_utils.py#L691)
+        for `AutoTokenizer.from_pretrained`.
+    """  # noqa: E501

     def __init__(
-        self, model_name: str, namespace: str = "tags", max_length: int = None, **kwargs
+        self,
+        model_name: str,
+        namespace: str = "tags",
+        max_length: int = None,
+        tokenizer_kwargs: Optional[Dict[str, Any]] = None,
+        **kwargs,
     ) -> None:
         super().__init__(**kwargs)
         # The matched version v.s. mismatched
         self._matched_indexer = PretrainedTransformerIndexer(
-            model_name, namespace, max_length, **kwargs
+            model_name,
+            namespace=namespace,
+            max_length=max_length,
+            tokenizer_kwargs=tokenizer_kwargs,
+            **kwargs,
         )
         self._allennlp_tokenizer = self._matched_indexer._allennlp_tokenizer
         self._tokenizer = self._matched_indexer._tokenizer
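
To illustrate the change above, the following sketch (not part of the commit) constructs the indexer directly; the model name, `max_length`, and the `do_lower_case` value are placeholders, and the dictionary is forwarded through the matched indexer to `AutoTokenizer.from_pretrained`.

```python
# Sketch of the new tokenizer_kwargs argument in use; values are illustrative only.
from allennlp.data.token_indexers import PretrainedTransformerMismatchedIndexer

indexer = PretrainedTransformerMismatchedIndexer(
    model_name="bert-base-cased",
    max_length=512,
    tokenizer_kwargs={"do_lower_case": False},  # passed through to AutoTokenizer.from_pretrained
)
```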

allennlp/modules/token_embedders/pretrained_transformer_mismatched_embedder.py

+14 -2

@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Optional, Dict, Any

 from overrides import overrides
 import torch
@@ -33,7 +33,15 @@ class PretrainedTransformerMismatchedEmbedder(TokenEmbedder):
         is used.
     gradient_checkpointing: `bool`, optional (default = `None`)
         Enable or disable gradient checkpointing.
-    """
+    tokenizer_kwargs: `Dict[str, Any]`, optional (default = `None`)
+        Dictionary with
+        [additional arguments](https://github.com/huggingface/transformers/blob/155c782a2ccd103cf63ad48a2becd7c76a7d2115/transformers/tokenization_utils.py#L691)
+        for `AutoTokenizer.from_pretrained`.
+    transformer_kwargs: `Dict[str, Any]`, optional (default = `None`)
+        Dictionary with
+        [additional arguments](https://github.com/huggingface/transformers/blob/155c782a2ccd103cf63ad48a2becd7c76a7d2115/transformers/modeling_utils.py#L253)
+        for `AutoModel.from_pretrained`.
+    """  # noqa: E501

     def __init__(
         self,
@@ -42,6 +50,8 @@ def __init__(
         train_parameters: bool = True,
         last_layer_only: bool = True,
         gradient_checkpointing: Optional[bool] = None,
+        tokenizer_kwargs: Optional[Dict[str, Any]] = None,
+        transformer_kwargs: Optional[Dict[str, Any]] = None,
     ) -> None:
         super().__init__()
         # The matched version v.s. mismatched
@@ -51,6 +61,8 @@ def __init__(
             train_parameters=train_parameters,
             last_layer_only=last_layer_only,
             gradient_checkpointing=gradient_checkpointing,
+            tokenizer_kwargs=tokenizer_kwargs,
+            transformer_kwargs=transformer_kwargs,
         )

     @overrides
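
Likewise, a minimal sketch (not part of the commit) of the two new embedder arguments; the model name and the kwarg values are illustrative only, and the two dictionaries are forwarded to `AutoTokenizer.from_pretrained` and `AutoModel.from_pretrained` respectively.

```python
# Sketch of the new tokenizer_kwargs / transformer_kwargs arguments; values are illustrative only.
from allennlp.modules.token_embedders import PretrainedTransformerMismatchedEmbedder

embedder = PretrainedTransformerMismatchedEmbedder(
    model_name="bert-base-cased",
    tokenizer_kwargs={"do_lower_case": False},        # -> AutoTokenizer.from_pretrained
    transformer_kwargs={"output_attentions": False},  # -> AutoModel.from_pretrained
)
```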
