import logging
from typing import List, Optional, Tuple

from overrides import overrides
from pytorch_transformers.tokenization_auto import AutoTokenizer

from allennlp.data.tokenizers.token import Token
from allennlp.data.tokenizers.tokenizer import Tokenizer

logger = logging.getLogger(__name__)


@Tokenizer.register("pretrained_transformer")
class PretrainedTransformerTokenizer(Tokenizer):
    """
    A ``PretrainedTransformerTokenizer`` uses a model from HuggingFace's
    ``pytorch_transformers`` library to tokenize some input text.  This often means wordpieces
    (where ``'AllenNLP is awesome'`` might get split into ``['Allen', '##NL', '##P', 'is',
    'awesome']``), but it could also use byte-pair encoding, or some other tokenization, depending
    on the pretrained model that you're using.

    We take a model name as an input parameter, which we will pass to
    ``AutoTokenizer.from_pretrained``.

    Parameters
    ----------
    model_name : ``str``
        The name of the pretrained wordpiece tokenizer to use.
    do_lowercase : ``bool``
        Whether to lowercase the input before tokenizing.  This should match the casing of the
        pretrained model; e.g., pass ``True`` for ``bert-base-uncased`` and ``False`` for
        ``bert-base-cased``.
    start_tokens : ``List[str]``, optional
        If given, these tokens will be added to the beginning of every string we tokenize.  We try
        to be a little bit smart about defaults here - e.g., if your model name contains ``bert``,
        we by default add ``[CLS]`` at the beginning and ``[SEP]`` at the end.
    end_tokens : ``List[str]``, optional
        If given, these tokens will be added to the end of every string we tokenize.
    """
    def __init__(self,
                 model_name: str,
                 do_lowercase: bool,
                 start_tokens: Optional[List[str]] = None,
                 end_tokens: Optional[List[str]] = None) -> None:
        if model_name.endswith("-cased") and do_lowercase:
            logger.warning("Your pretrained model appears to be cased, "
                           "but your tokenizer is lowercasing tokens.")
        elif model_name.endswith("-uncased") and not do_lowercase:
            logger.warning("Your pretrained model appears to be uncased, "
                           "but your tokenizer is not lowercasing tokens.")
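        # ``AutoTokenizer.from_pretrained`` picks the right tokenizer class for the given model
        # name (e.g., ``BertTokenizer`` for BERT models), downloading its vocabulary if necessary.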
        self._tokenizer = AutoTokenizer.from_pretrained(model_name, do_lower_case=do_lowercase)
        default_start_tokens, default_end_tokens = _guess_start_and_end_token_defaults(model_name)
        self._start_tokens = start_tokens if start_tokens is not None else default_start_tokens
        self._end_tokens = end_tokens if end_tokens is not None else default_end_tokens

    @overrides
    def tokenize(self, text: str) -> List[Token]:
        # TODO(mattg): track character offsets.  Might be too challenging to do it here, given that
        # pytorch-transformers is dealing with the whitespace...
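        # The underlying tokenizer does all the subword splitting; we just wrap each piece in an
        # AllenNLP ``Token`` and surround the sequence with any start/end tokens.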
        token_strings = self._start_tokens + self._tokenizer.tokenize(text) + self._end_tokens
        return [Token(t) for t in token_strings]


def _guess_start_and_end_token_defaults(model_name: str) -> Tuple[List[str], List[str]]:
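    # BERT-style models expect [CLS] and [SEP] around their input; for other model families we
    # add nothing by default and leave it to the user to pass start_tokens/end_tokens explicitly.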
    if 'bert' in model_name:
        return (['[CLS]'], ['[SEP]'])
    else:
        return ([], [])
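

# A minimal usage sketch (not part of the library code above), assuming the
# ``bert-base-uncased`` weights can be fetched from HuggingFace:
if __name__ == "__main__":
    tokenizer = PretrainedTransformerTokenizer("bert-base-uncased", do_lowercase=True)
    tokens = tokenizer.tokenize("AllenNLP is awesome")
    # With a BERT model name, [CLS] and [SEP] are added around the wordpieces by default.
    print([token.text for token in tokens])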