Commit 8e908c8

[AutoTokenizer] Allow creation of tokenizers by tokenizer type (#13668)
* up
* up
1 parent 2608944 commit 8e908c8

File tree: 5 files changed, +81 -1 lines changed

src/transformers/models/auto/tokenization_auto.py

Lines changed: 28 additions & 1 deletion
@@ -402,6 +402,8 @@ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
                 facebook/rag-token-base), specify it here.
             use_fast (:obj:`bool`, `optional`, defaults to :obj:`True`):
                 Whether or not to try to load the fast version of the tokenizer.
+            tokenizer_type (:obj:`str`, `optional`):
+                Tokenizer type to be loaded.
             kwargs (additional keyword arguments, `optional`):
                 Will be passed to the Tokenizer ``__init__()`` method. Can be used to set special tokens like
                 ``bos_token``, ``eos_token``, ``unk_token``, ``sep_token``, ``pad_token``, ``cls_token``,
@@ -425,8 +427,33 @@ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
         kwargs["_from_auto"] = True
 
         use_fast = kwargs.pop("use_fast", True)
+        tokenizer_type = kwargs.pop("tokenizer_type", None)
 
-        # First, let's try to use the tokenizer_config file to get the tokenizer class.
+        # First, let's see whether the tokenizer_type is passed so that we can leverage it
+        if tokenizer_type is not None:
+            tokenizer_class = None
+            tokenizer_class_tuple = TOKENIZER_MAPPING_NAMES.get(tokenizer_type, None)
+
+            if tokenizer_class_tuple is None:
+                raise ValueError(
+                    f"Passed `tokenizer_type` {tokenizer_type} does not exist. `tokenizer_type` should be one of "
+                    f"{', '.join(c for c in TOKENIZER_MAPPING_NAMES.keys())}."
+                )
+
+            tokenizer_class_name, tokenizer_fast_class_name = tokenizer_class_tuple
+
+            if use_fast and tokenizer_fast_class_name is not None:
+                tokenizer_class = tokenizer_class_from_name(tokenizer_fast_class_name)
+
+            if tokenizer_class is None:
+                tokenizer_class = tokenizer_class_from_name(tokenizer_class_name)
+
+            if tokenizer_class is None:
+                raise ValueError(f"Tokenizer class {tokenizer_class_name} is not currently imported.")
+
+            return tokenizer_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
+
+        # Next, let's try to use the tokenizer_config file to get the tokenizer class.
         tokenizer_config = get_tokenizer_config(pretrained_model_name_or_path, **kwargs)
         config_tokenizer_class = tokenizer_config.get("tokenizer_class")
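In practice, the new keyword lets AutoTokenizer build a tokenizer from bare vocabulary files, with no tokenizer_config.json or model config available to infer the class from. A minimal sketch of the call pattern, mirroring the tests added below (the directory path is a placeholder):

from transformers import AutoTokenizer

# Resolve the tokenizer class via the "bert" entry of TOKENIZER_MAPPING_NAMES.
# use_fast defaults to True, so the fast (Rust-backed) class is tried first,
# with a fallback to the slow class if no fast implementation is importable.
fast_tokenizer = AutoTokenizer.from_pretrained("path/to/vocab_dir", tokenizer_type="bert")

# use_fast=False skips the fast lookup and resolves the slow class directly.
slow_tokenizer = AutoTokenizer.from_pretrained("path/to/vocab_dir", tokenizer_type="bert", use_fast=False)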

tests/fixtures/merges.txt

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
+#version: 0.2
+Ġ l
+Ġl o
+Ġlo w
+e r

tests/fixtures/vocab.json

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+{"l": 0, "o": 1, "w": 2, "e": 3, "r": 4, "s": 5, "t": 6, "i": 7, "d": 8, "n": 9, "Ġ": 10, "Ġl": 11, "Ġn": 12, "Ġlo": 13, "Ġlow": 14, "er": 15, "Ġlowest": 16, "Ġnewer": 17, "Ġwider": 18, "<unk>": 19, "<|endoftext|>": 20}

tests/fixtures/vocab.txt

Lines changed: 10 additions & 0 deletions
@@ -0,0 +1,10 @@
+[PAD]
+[SEP]
+[MASK]
+[CLS]
+[unused3]
+[unused4]
+[unused5]
+[unused6]
+[unused7]
+[unused8]
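This minimal WordPiece vocabulary backs the BERT branch of the new tests. A sketch of that loading path (mirroring test_tokenizer_from_type; the fixture deliberately contains just ten entries):

import os
import shutil
import tempfile

from transformers import AutoTokenizer, BertTokenizer

with tempfile.TemporaryDirectory() as tmp_dir:
    shutil.copy("./tests/fixtures/vocab.txt", os.path.join(tmp_dir, "vocab.txt"))
    # Without tokenizer_type there is no config here from which to infer a class.
    tokenizer = AutoTokenizer.from_pretrained(tmp_dir, tokenizer_type="bert", use_fast=False)
    assert isinstance(tokenizer, BertTokenizer)
    print(tokenizer.vocab_size)  # 10, one id per line of the fixture file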

tests/test_tokenization_auto.py

Lines changed: 37 additions & 0 deletions
@@ -13,9 +13,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import os
+import shutil
 import tempfile
 import unittest
 
+import pytest
+
 from transformers import (
     BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
     GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP,
@@ -78,6 +82,39 @@ def test_tokenizer_from_tokenizer_class(self):
         self.assertIsInstance(tokenizer, (BertTokenizer, BertTokenizerFast))
         self.assertEqual(tokenizer.vocab_size, 12)
 
+    def test_tokenizer_from_type(self):
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            shutil.copy("./tests/fixtures/vocab.txt", os.path.join(tmp_dir, "vocab.txt"))
+
+            tokenizer = AutoTokenizer.from_pretrained(tmp_dir, tokenizer_type="bert", use_fast=False)
+            self.assertIsInstance(tokenizer, BertTokenizer)
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            shutil.copy("./tests/fixtures/vocab.json", os.path.join(tmp_dir, "vocab.json"))
+            shutil.copy("./tests/fixtures/merges.txt", os.path.join(tmp_dir, "merges.txt"))
+
+            tokenizer = AutoTokenizer.from_pretrained(tmp_dir, tokenizer_type="gpt2", use_fast=False)
+            self.assertIsInstance(tokenizer, GPT2Tokenizer)
+
+    @require_tokenizers
+    def test_tokenizer_from_type_fast(self):
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            shutil.copy("./tests/fixtures/vocab.txt", os.path.join(tmp_dir, "vocab.txt"))
+
+            tokenizer = AutoTokenizer.from_pretrained(tmp_dir, tokenizer_type="bert")
+            self.assertIsInstance(tokenizer, BertTokenizerFast)
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            shutil.copy("./tests/fixtures/vocab.json", os.path.join(tmp_dir, "vocab.json"))
+            shutil.copy("./tests/fixtures/merges.txt", os.path.join(tmp_dir, "merges.txt"))
+
+            tokenizer = AutoTokenizer.from_pretrained(tmp_dir, tokenizer_type="gpt2")
+            self.assertIsInstance(tokenizer, GPT2TokenizerFast)
+
+    def test_tokenizer_from_type_incorrect_name(self):
+        with pytest.raises(ValueError):
+            AutoTokenizer.from_pretrained("./", tokenizer_type="xxx")
+
     @require_tokenizers
     def test_tokenizer_identifier_with_correct_config(self):
         for tokenizer_class in [BertTokenizer, BertTokenizerFast, AutoTokenizer]:
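As test_tokenizer_from_type_incorrect_name above shows, an unrecognized type fails fast, before any file lookup. A sketch of that behavior in isolation:

import pytest

from transformers import AutoTokenizer

# The rejection happens before get_tokenizer_config is consulted; the error
# message lists every valid key of TOKENIZER_MAPPING_NAMES (e.g. "bert", "gpt2").
with pytest.raises(ValueError):
    AutoTokenizer.from_pretrained("./", tokenizer_type="xxx")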
