diff --git a/CHANGELOG.md b/CHANGELOG.md
index e8ff0d1e8f6..a50c08703b6 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -16,6 +16,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - The activation layer in the transformer toolkit now can be queried for its output dimension.
 - `TransformerEmbeddings` now takes, but ignores, a parameter for the attention mask. This is needed for compatibility with some other modules that get called the same way and use the mask.
 - `TransformerPooler` can now be instantiated from a pretrained transformer module, just like the other modules in the transformer toolkit.
+- `TransformerTextField`, for cases where you don't care about AllenNLP's advanced text handling capabilities.
 - Added `TransformerModule._post_load_pretrained_state_dict_hook()` method. Can be used to modify `missing_keys` and `unexpected_keys` after loading a pretrained state dictionary. This is useful when tying weights, for example.
diff --git a/allennlp/data/fields/__init__.py b/allennlp/data/fields/__init__.py
index fa01eac7367..134865afce1 100644
--- a/allennlp/data/fields/__init__.py
+++ b/allennlp/data/fields/__init__.py
@@ -18,3 +18,4 @@
 from allennlp.data.fields.span_field import SpanField
 from allennlp.data.fields.text_field import TextField
 from allennlp.data.fields.array_field import ArrayField
+from allennlp.data.fields.transformer_text_field import TransformerTextField
diff --git a/allennlp/data/fields/transformer_text_field.py b/allennlp/data/fields/transformer_text_field.py
new file mode 100644
index 00000000000..3a64a9425f1
--- /dev/null
+++ b/allennlp/data/fields/transformer_text_field.py
@@ -0,0 +1,106 @@
+from typing import Dict, Optional, List, Any
+
+from overrides import overrides
+import torch
+import torch.nn.functional
+
+from allennlp.data.fields.field import Field
+from allennlp.nn import util
+
+
+class TransformerTextField(Field[torch.Tensor]):
+    """
+    A `TransformerTextField` is a collection of several tensors that are a representation of text,
+    tokenized and ready to become input to a transformer.
+
+    The naming of the tensors follows the pattern produced by the huggingface tokenizers and
+    expected by the huggingface transformers.
+    """
+
+    __slots__ = [
+        "input_ids",
+        "token_type_ids",
+        "attention_mask",
+        "special_tokens_mask",
+        "offsets_mapping",
+        "padding_token_id",
+    ]
+
+    def __init__(
+        self,
+        input_ids: torch.Tensor,
+        # I wish input_ids were called `token_ids` for clarity, but we want to be compatible with huggingface.
+        token_type_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        special_tokens_mask: Optional[torch.Tensor] = None,
+        offsets_mapping: Optional[torch.Tensor] = None,
+        padding_token_id: int = 0,
+    ) -> None:
+        self.input_ids = input_ids
+        self.token_type_ids = token_type_ids
+        self.attention_mask = attention_mask
+        self.special_tokens_mask = special_tokens_mask
+        self.offsets_mapping = offsets_mapping
+        self.padding_token_id = padding_token_id
+
+    @overrides
+    def get_padding_lengths(self) -> Dict[str, int]:
+        return {
+            name: len(getattr(self, name))
+            for name in self.__slots__
+            if isinstance(getattr(self, name), torch.Tensor)
+        }
+
+    @overrides
+    def as_tensor(self, padding_lengths: Dict[str, int]) -> Dict[str, torch.Tensor]:
+        result = {}
+        for name, padding_length in padding_lengths.items():
+            tensor = getattr(self, name)
+            result[name] = torch.nn.functional.pad(
+                tensor,
+                (0, padding_length - len(tensor)),
+                value=self.padding_token_id if name == "input_ids" else 0,
+            )
+        if "attention_mask" not in result:
+            result["attention_mask"] = torch.tensor(
+                [True] * len(self.input_ids)
+                + [False] * (padding_lengths["input_ids"] - len(self.input_ids)),
+                dtype=torch.bool,
+            )
+        return result
+
+    @overrides
+    def empty_field(self):
+        return TransformerTextField(torch.LongTensor(), padding_token_id=self.padding_token_id)
+
+    @overrides
+    def batch_tensors(self, tensor_list: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
+        result: Dict[str, torch.Tensor] = util.batch_tensor_dicts(tensor_list)
+        # Transformer models need LongTensors for indices, just in case we have more than 2 billion
+        # different tokens. To save space, we make the switch as late as possible, i.e., here.
+        result = {
+            name: t.to(torch.int64) if t.dtype == torch.int32 else t for name, t in result.items()
+        }
+        return result
+
+    def human_readable_repr(self) -> Dict[str, Any]:
+        def format_item(x) -> str:
+            return str(x.item())
+
+        def readable_tensor(t: torch.Tensor) -> str:
+            if len(t) <= 16:
+                return "[" + ", ".join(map(format_item, t)) + "]"
+            else:
+                return (
+                    "["
+                    + ", ".join(map(format_item, t[:8]))
+                    + ", ..., "
+                    + ", ".join(map(format_item, t[-8:]))
+                    + "]"
+                )
+
+        return {
+            name: readable_tensor(getattr(self, name))
+            for name in self.__slots__
+            if isinstance(getattr(self, name), torch.Tensor)
+        }
diff --git a/tests/data/fields/transformer_text_field_test.py b/tests/data/fields/transformer_text_field_test.py
new file mode 100644
index 00000000000..cecb60e454a
--- /dev/null
+++ b/tests/data/fields/transformer_text_field_test.py
@@ -0,0 +1,40 @@
+import torch
+
+from allennlp.data import Batch, Instance
+from allennlp.data.fields import TransformerTextField
+
+
+def test_transformer_text_field_init():
+    field = TransformerTextField(torch.IntTensor([1, 2, 3]))
+    field_as_tensor = field.as_tensor(field.get_padding_lengths())
+    assert "input_ids" in field_as_tensor
+    assert "attention_mask" in field_as_tensor
+    assert torch.all(field_as_tensor["attention_mask"] == torch.BoolTensor([True, True, True]))
+    assert torch.all(field_as_tensor["input_ids"] == torch.IntTensor([1, 2, 3]))
+
+
+def test_empty_transformer_text_field():
+    field = TransformerTextField(torch.IntTensor([]), padding_token_id=9)
+    field = field.empty_field()
+    assert isinstance(field, TransformerTextField) and field.padding_token_id == 9
+    field_as_tensor = field.as_tensor(field.get_padding_lengths())
+    assert "input_ids" in field_as_tensor
+    assert "attention_mask" in field_as_tensor
+    assert torch.all(field_as_tensor["attention_mask"] == torch.BoolTensor([]))
+    assert torch.all(field_as_tensor["input_ids"] == torch.IntTensor([]))
+
+
+def test_transformer_text_field_batching():
+    batch = Batch(
+        [
+            Instance({"text": TransformerTextField(torch.IntTensor([1, 2, 3]))}),
+            Instance({"text": TransformerTextField(torch.IntTensor([2, 3, 4, 5]))}),
+            Instance({"text": TransformerTextField(torch.IntTensor())}),
+        ]
+    )
+    tensors = batch.as_tensor_dict(batch.get_padding_lengths())
+    assert tensors["text"]["input_ids"].shape == (3, 4)
+    assert tensors["text"]["input_ids"][0, -1] == 0
+    assert tensors["text"]["attention_mask"][0, -1] == torch.Tensor([False])
+    assert torch.all(tensors["text"]["input_ids"][-1] == 0)
+    assert torch.all(tensors["text"]["attention_mask"][-1] == torch.tensor([False]))
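For context, here is a minimal usage sketch of the new field, not taken from the diff above. The `make_instance` helper, the example sentences, and the `bert-base-uncased` tokenizer are illustrative assumptions; only the `TransformerTextField` API itself comes from this change.

```python
# Minimal usage sketch (not from the PR). `make_instance` and the model name are
# illustrative; the TransformerTextField API is taken from the diff above.
import torch
from transformers import AutoTokenizer

from allennlp.data import Batch, Instance
from allennlp.data.fields import TransformerTextField

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")


def make_instance(text: str) -> Instance:
    # The tokenizer returns plain lists of ints; the field expects 1-D tensors named
    # the way huggingface names them (input_ids, token_type_ids, attention_mask, ...).
    encoded = tokenizer(text)
    return Instance(
        {
            "text": TransformerTextField(
                input_ids=torch.IntTensor(encoded["input_ids"]),
                token_type_ids=torch.IntTensor(encoded["token_type_ids"]),
                padding_token_id=tokenizer.pad_token_id,
            )
        }
    )


batch = Batch(
    [make_instance("The quick brown fox."), make_instance("It jumps over the lazy dog.")]
)
tensors = batch.as_tensor_dict(batch.get_padding_lengths())

# tensors["text"] now holds (batch_size, max_length) tensors: `input_ids` padded with
# the pad token id (and converted to int64 by `batch_tensors`), plus an automatically
# generated boolean `attention_mask`, ready to be passed to a transformer, e.g.
# AutoModel.from_pretrained("bert-base-uncased")(**tensors["text"]).
print({name: tensor.shape for name, tensor in tensors["text"].items()})
```

Note that, unlike `TextField`, no vocabulary or token indexer is involved: the ids come straight from the tokenizer, which is exactly the "don't care about AllenNLP's advanced text handling capabilities" case described in the changelog entry.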