
Commit 217022f

Adding a PretrainedTransformerTokenizer (#3145)

* Adding a PretrainedTransformerTokenizer
* pylint
* doc

1 parent f9e2029, commit 217022f

8 files changed: +91 / -6 lines

allennlp/data/tokenizers/__init__.py  (+1)

@@ -5,5 +5,6 @@

 from allennlp.data.tokenizers.tokenizer import Token, Tokenizer
 from allennlp.data.tokenizers.word_tokenizer import WordTokenizer
+from allennlp.data.tokenizers.pretrained_transformer_tokenizer import PretrainedTransformerTokenizer
 from allennlp.data.tokenizers.character_tokenizer import CharacterTokenizer
 from allennlp.data.tokenizers.sentence_splitter import SentenceSplitter
allennlp/data/tokenizers/character_tokenizer.py  (-4)

@@ -46,10 +46,6 @@ def __init__(self,
         self._start_tokens.reverse()
         self._end_tokens = end_tokens or []

-    @overrides
-    def batch_tokenize(self, texts: List[str]) -> List[List[Token]]:
-        return [self.tokenize(text) for text in texts]
-
     @overrides
     def tokenize(self, text: str) -> List[Token]:
         if self._lowercase_characters:
allennlp/data/tokenizers/pretrained_transformer_tokenizer.py  (new file, +64)

@@ -0,0 +1,64 @@
+import logging
+from typing import List, Tuple
+
+from overrides import overrides
+from pytorch_transformers.tokenization_auto import AutoTokenizer
+
+from allennlp.data.tokenizers.token import Token
+from allennlp.data.tokenizers.tokenizer import Tokenizer
+
+logger = logging.getLogger(__name__)
+
+
+@Tokenizer.register("pretrained_transformer")
+class PretrainedTransformerTokenizer(Tokenizer):
+    """
+    A ``PretrainedTransformerTokenizer`` uses a model from HuggingFace's
+    ``pytorch_transformers`` library to tokenize some input text. This often means wordpieces
+    (where ``'AllenNLP is awesome'`` might get split into ``['Allen', '##NL', '##P', 'is',
+    'awesome']``), but it could also use byte-pair encoding, or some other tokenization, depending
+    on the pretrained model that you're using.
+
+    We take a model name as an input parameter, which we will pass to
+    ``AutoTokenizer.from_pretrained``.
+
+    Parameters
+    ----------
+    model_name : ``str``
+        The name of the pretrained wordpiece tokenizer to use.
+    start_tokens : ``List[str]``, optional
+        If given, these tokens will be added to the beginning of every string we tokenize. We try
+        to be a little bit smart about defaults here - e.g., if your model name contains ``bert``,
+        we by default add ``[CLS]`` at the beginning and ``[SEP]`` at the end.
+    end_tokens : ``List[str]``, optional
+        If given, these tokens will be added to the end of every string we tokenize.
+    """
+    def __init__(self,
+                 model_name: str,
+                 do_lowercase: bool,
+                 start_tokens: List[str] = None,
+                 end_tokens: List[str] = None) -> None:
+        if model_name.endswith("-cased") and do_lowercase:
+            logger.warning("Your pretrained model appears to be cased, "
+                           "but your tokenizer is lowercasing tokens.")
+        elif model_name.endswith("-uncased") and not do_lowercase:
+            logger.warning("Your pretrained model appears to be uncased, "
+                           "but your tokenizer is not lowercasing tokens.")
+        self._tokenizer = AutoTokenizer.from_pretrained(model_name, do_lower_case=do_lowercase)
+        default_start_tokens, default_end_tokens = _guess_start_and_end_token_defaults(model_name)
+        self._start_tokens = start_tokens if start_tokens is not None else default_start_tokens
+        self._end_tokens = end_tokens if end_tokens is not None else default_end_tokens
+
+    @overrides
+    def tokenize(self, text: str) -> List[Token]:
+        # TODO(mattg): track character offsets. Might be too challenging to do it here, given that
+        # pytorch-transformers is dealing with the whitespace...
+        token_strings = self._start_tokens + self._tokenizer.tokenize(text) + self._end_tokens
+        return [Token(t) for t in token_strings]
+
+
+def _guess_start_and_end_token_defaults(model_name: str) -> Tuple[List[str], List[str]]:
+    if 'bert' in model_name:
+        return (['[CLS]'], ['[SEP]'])
+    else:
+        return ([], [])
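For reference, a minimal usage sketch of the new class (not part of this commit). It assumes the pytorch-transformers pin added in requirements.txt below is installed and that the bert-base-cased vocabulary can be downloaded; the expected output follows the wordpiece example in the docstring above.

# Minimal usage sketch; the exact wordpieces depend on which pretrained
# vocabulary is available in your environment.
from allennlp.data.tokenizers import PretrainedTransformerTokenizer

tokenizer = PretrainedTransformerTokenizer("bert-base-cased", do_lowercase=False)
tokens = tokenizer.tokenize("AllenNLP is awesome")
print([t.text for t in tokens])
# With the guessed BERT defaults adding [CLS] / [SEP], this should resemble:
# ['[CLS]', 'Allen', '##NL', '##P', 'is', 'awesome', '[SEP]']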

allennlp/data/tokenizers/tokenizer.py  (+4 / -1)

@@ -27,8 +27,11 @@ def batch_tokenize(self, texts: List[str]) -> List[List[Token]]:
         """
         Batches together tokenization of several texts, in case that is faster for particular
         tokenizers.
+
+        By default we just do this without batching. Override this in your tokenizer if you have a
+        good way of doing batched computation.
         """
-        raise NotImplementedError
+        return [self.tokenize(text) for text in texts]

     def tokenize(self, text: str) -> List[Token]:
         """
allennlp/tests/data/tokenizers/pretrained_transformer_tokenizer_test.py  (new file, +12)

@@ -0,0 +1,12 @@
+# pylint: disable=no-self-use,invalid-name
+
+from allennlp.common.testing import AllenNlpTestCase
+from allennlp.data.tokenizers import PretrainedTransformerTokenizer
+
+class TestPretrainedTransformerTokenizer(AllenNlpTestCase):
+    def test_splits_into_wordpieces(self):
+        tokenizer = PretrainedTransformerTokenizer('bert-base-cased', do_lowercase=False)
+        sentence = "A, [MASK] AllenNLP sentence."
+        tokens = [t.text for t in tokenizer.tokenize(sentence)]
+        expected_tokens = ["[CLS]", "A", ",", "[MASK]", "Allen", "##NL", "##P", "sentence", ".", "[SEP]"]
+        assert tokens == expected_tokens
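Because the class is registered as "pretrained_transformer", it should also be constructible from configuration; a hedged sketch (not shown in this commit), with the config keys mirroring the constructor arguments above:

from allennlp.common import Params
from allennlp.data.tokenizers import Tokenizer

# Build the tokenizer the same way a dataset reader's config section would.
tokenizer = Tokenizer.from_params(Params({
        "type": "pretrained_transformer",
        "model_name": "bert-base-uncased",
        "do_lowercase": True,
}))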

doc/api/allennlp.data.tokenizers.rst  (+8 / -1)

@@ -14,6 +14,7 @@ allennlp.data.tokenizers
 * :ref:`Tokenizer<tokenizer>`
 * :ref:`WordTokenizer<word-tokenizer>`
 * :ref:`CharacterTokenizer<character-tokenizer>`
+* :ref:`PretrainedTransformerTokenizer<pretrained-transformer-tokenizer>`
 * :ref:`WordFilter<word-filter>`
 * :ref:`WordSplitter<word-splitter>`
 * :ref:`WordStemmer<word-stemmer>`

@@ -36,6 +37,12 @@ allennlp.data.tokenizers
    :undoc-members:
    :show-inheritance:

+.. _pretrained-transformer-tokenizer:
+.. automodule:: allennlp.data.tokenizers.pretrained_transformer_tokenizer
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
 .. _word-filter:
 .. automodule:: allennlp.data.tokenizers.word_filter
    :members:

@@ -58,4 +65,4 @@ allennlp.data.tokenizers
 .. automodule:: allennlp.data.tokenizers.sentence_splitter
    :members:
    :undoc-members:
-   :show-inheritance:
+   :show-inheritance:

requirements.txt  (+1)

@@ -61,6 +61,7 @@ word2number>=1.1

 # To use the BERT model
 pytorch-pretrained-bert>=0.6.0
+git+git://github.com/huggingface/pytorch-transformers.git@a7b4cfe9194bf93c7044a42c9f1281260ce6279e

 # For caching processed data
 jsonpickle

setup.py  (+1)

@@ -131,6 +131,7 @@
         'sqlparse>=0.2.4',
         'word2number>=1.1',
         'pytorch-pretrained-bert>=0.6.0',
+        'pytorch-transformers @ https://api.github.com/repos/huggingface/pytorch-transformers/tarball/a7b4cfe9194bf93c7044a42c9f1281260ce6279e',
         'jsonpickle',
     ],
     entry_points={
