
TransformerTextField #5280

Merged · 8 commits · Jun 24, 2021

Changes from 6 commits
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -16,6 +16,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- The activation layer in the transformer toolkit now can be queried for its output dimension.
- `TransformerEmbeddings` now takes, but ignores, a parameter for the attention mask. This is needed for compatibility with some other modules that get called the same way and use the mask.
- `TransformerPooler` can now be instantiated from a pretrained transformer module, just like the other modules in the transformer toolkit.
- `TransformerTextField`, for cases where you don't care about AllenNLP's advanced text handling capabilities.

### Fixed

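The changelog entry above only names the new field. As a rough usage sketch (not part of this diff, and assuming a standard huggingface tokenizer, with `bert-base-uncased` as an arbitrary example), the constructor arguments are meant to mirror the tokenizer's output names:

```python
import torch
from transformers import AutoTokenizer

from allennlp.data.fields import TransformerTextField

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
encoded = tokenizer("AllenNLP is great!")

# The constructor arguments mirror the names the tokenizer produces.
field = TransformerTextField(
    input_ids=torch.IntTensor(encoded["input_ids"]),
    token_type_ids=torch.IntTensor(encoded["token_type_ids"]),
    attention_mask=torch.BoolTensor(encoded["attention_mask"]),
)
```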
1 change: 1 addition & 0 deletions allennlp/data/fields/__init__.py
@@ -18,3 +18,4 @@
from allennlp.data.fields.span_field import SpanField
from allennlp.data.fields.text_field import TextField
from allennlp.data.fields.array_field import ArrayField
from allennlp.data.fields.transformer_text_field import TransformerTextField
97 changes: 97 additions & 0 deletions allennlp/data/fields/transformer_text_field.py
@@ -0,0 +1,97 @@
from typing import Dict, Optional, List, Any

from overrides import overrides
import torch
import torch.nn.functional

from allennlp.data.fields.field import Field
from allennlp.nn import util


class TransformerTextField(Field[torch.Tensor]):
"""
A `TransformerTextField` is a collection of several tensors that are a representation of text,
tokenized and ready to become input to a transformer.

The naming of the tensors follows the convention produced by the huggingface tokenizers
and expected by the huggingface transformers.
"""

__slots__ = [
"input_ids",
"token_type_ids",
"attention_mask",
"special_tokens_mask",
"offsets_mapping",
"padding_token_id",
]

def __init__(
self,
input_ids: torch.Tensor,
# I wish input_ids were called `token_ids` for clarity, but we want to be compatible with huggingface.
token_type_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
special_tokens_mask: Optional[torch.Tensor] = None,
offsets_mapping: Optional[torch.Tensor] = None,
padding_token_id: int = 0,
) -> None:
self.input_ids = input_ids
self.token_type_ids = token_type_ids
self.attention_mask = attention_mask
self.special_tokens_mask = special_tokens_mask
self.offsets_mapping = offsets_mapping
self.padding_token_id = padding_token_id

@overrides
def get_padding_lengths(self) -> Dict[str, int]:
return {
name: len(getattr(self, name))
for name in self.__slots__
if isinstance(getattr(self, name), torch.Tensor)
}

@overrides
def as_tensor(self, padding_lengths: Dict[str, int]) -> Dict[str, torch.Tensor]:
result = {}
for name, padding_length in padding_lengths.items():
tensor = getattr(self, name)
result[name] = torch.nn.functional.pad(
tensor,
(0, padding_length - len(tensor)),
value=self.padding_token_id if name == "input_ids" else 0,
)
if "attention_mask" not in result:
result["attention_mask"] = torch.tensor(
[True] * len(self.input_ids)
+ [False] * (padding_lengths["input_ids"] - len(self.input_ids)),
dtype=torch.bool,
)
return result

@overrides
def empty_field(self):
return TransformerTextField(torch.LongTensor(), padding_token_id=self.padding_token_id)

@overrides
def batch_tensors(self, tensor_list: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
result: Dict[str, torch.Tensor] = util.batch_tensor_dicts(tensor_list)
# Transformer models need LongTensors for indices, just in case we have more than 2 billion
# different tokens. To save space, we make the switch as late as possible, i.e., here.
result = {
name: t.to(torch.int64) if t.dtype == torch.int32 else t for name, t in result.items()
}
return result

def human_readable_repr(self) -> Dict[str, Any]:
def readable_tensor(t: torch.Tensor) -> str:
if len(t) <= 16:
return "[" + ", ".join(map(str, t)) + "]"
else:
return "[" + ", ".join(map(str, t[:8])) + ", ".join(map(str, t[-8:])) + "]"

return {
name: readable_tensor(getattr(self, name))
for name in self.__slots__
if isinstance(getattr(self, name), torch.Tensor)
}
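Because the tensor names follow the huggingface convention (see the class docstring above), the padded output of `as_tensor` can in principle be unpacked straight into a `transformers` model. A minimal sketch, assuming a recent `transformers` version and `bert-base-uncased` as an illustrative checkpoint (the token ids are arbitrary):

```python
import torch
from transformers import AutoModel

from allennlp.data.fields import TransformerTextField

# Arbitrary token ids, purely for illustration.
field = TransformerTextField(torch.IntTensor([101, 2023, 2003, 1037, 3231, 102]))
tensors = field.as_tensor(field.get_padding_lengths())

model = AutoModel.from_pretrained("bert-base-uncased")
# The names already match what the model expects; add a batch dimension and
# promote the indices to int64 (as batch_tensors would do) before the forward pass.
outputs = model(
    input_ids=tensors["input_ids"].unsqueeze(0).to(torch.int64),
    attention_mask=tensors["attention_mask"].unsqueeze(0),
)
print(outputs.last_hidden_state.shape)  # (1, 6, hidden_size)
```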
40 changes: 40 additions & 0 deletions tests/data/fields/transformer_text_field_test.py
@@ -0,0 +1,40 @@
import torch

from allennlp.data import Batch, Instance
from allennlp.data.fields import TransformerTextField


def test_transformer_text_field_init():
field = TransformerTextField(torch.IntTensor([1, 2, 3]))
field_as_tensor = field.as_tensor(field.get_padding_lengths())
assert "input_ids" in field_as_tensor
assert "attention_mask" in field_as_tensor
assert torch.all(field_as_tensor["attention_mask"] == torch.BoolTensor([True, True, True]))
assert torch.all(field_as_tensor["input_ids"] == torch.IntTensor([1, 2, 3]))


def test_empty_transformer_text_field():
field = TransformerTextField(torch.IntTensor([]), padding_token_id=9)
field = field.empty_field()
assert isinstance(field, TransformerTextField) and field.padding_token_id == 9
field_as_tensor = field.as_tensor(field.get_padding_lengths())
assert "input_ids" in field_as_tensor
assert "attention_mask" in field_as_tensor
assert torch.all(field_as_tensor["attention_mask"] == torch.BoolTensor([]))
assert torch.all(field_as_tensor["input_ids"] == torch.IntTensor([]))


def test_transformer_text_field_batching():
batch = Batch(
[
Instance({"text": TransformerTextField(torch.IntTensor([1, 2, 3]))}),
Instance({"text": TransformerTextField(torch.IntTensor([2, 3, 4, 5]))}),
Instance({"text": TransformerTextField(torch.IntTensor())}),
]
)
tensors = batch.as_tensor_dict(batch.get_padding_lengths())
assert tensors["text"]["input_ids"].shape == (3, 4)
assert tensors["text"]["input_ids"][0, -1] == 0
assert tensors["text"]["attention_mask"][0, -1] == torch.Tensor([False])
assert torch.all(tensors["text"]["input_ids"][-1] == 0)
assert torch.all(tensors["text"]["attention_mask"][-1] == torch.tensor([False]))