
TransformerTextField #5280

Merged · 8 commits · Jun 24, 2021
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -16,6 +16,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- The activation layer in the transformer toolkit now can be queried for its output dimension.
- `TransformerEmbeddings` now takes, but ignores, a parameter for the attention mask. This is needed for compatibility with some other modules that get called the same way and use the mask.
- `TransformerPooler` can now be instantiated from a pretrained transformer module, just like the other modules in the transformer toolkit.
- `TransformerTextField`, for cases where you don't care about AllenNLP's advanced text handling capabilities.
- Added `TransformerModule._post_load_pretrained_state_dict_hook()` method. Can be used to modify `missing_keys` and `unexpected_keys` after
loading a pretrained state dictionary. This is useful when tying weights, for example.

1 change: 1 addition & 0 deletions allennlp/data/fields/__init__.py
@@ -18,3 +18,4 @@
from allennlp.data.fields.span_field import SpanField
from allennlp.data.fields.text_field import TextField
from allennlp.data.fields.array_field import ArrayField
from allennlp.data.fields.transformer_text_field import TransformerTextField
106 changes: 106 additions & 0 deletions allennlp/data/fields/transformer_text_field.py
@@ -0,0 +1,106 @@
from typing import Dict, Optional, List, Any

from overrides import overrides
import torch
import torch.nn.functional

from allennlp.data.fields.field import Field
from allennlp.nn import util


class TransformerTextField(Field[torch.Tensor]):
"""
    A `TransformerTextField` is a collection of several tensors that together represent text,
tokenized and ready to become input to a transformer.

    The tensors are named following the pattern produced by huggingface tokenizers and
    expected by huggingface transformers.
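
    As a minimal, illustrative sketch, a field can be built directly from the
    output of a huggingface tokenizer (the tokenizer name below is just an
    example):

        from transformers import AutoTokenizer

        tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
        encoded = tokenizer("Hello, world!", return_tensors="pt")
        field = TransformerTextField(
            input_ids=encoded["input_ids"][0],
            attention_mask=encoded["attention_mask"][0],
        )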
"""

__slots__ = [
"input_ids",
"token_type_ids",
"attention_mask",
"special_tokens_mask",
"offsets_mapping",
"padding_token_id",
]

def __init__(
self,
input_ids: torch.Tensor,
# I wish input_ids were called `token_ids` for clarity, but we want to be compatible with huggingface.
token_type_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
special_tokens_mask: Optional[torch.Tensor] = None,
offsets_mapping: Optional[torch.Tensor] = None,
padding_token_id: int = 0,
) -> None:
self.input_ids = input_ids
self.token_type_ids = token_type_ids
self.attention_mask = attention_mask
self.special_tokens_mask = special_tokens_mask
self.offsets_mapping = offsets_mapping
self.padding_token_id = padding_token_id

@overrides
def get_padding_lengths(self) -> Dict[str, int]:
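        # Report one padding length per tensor-valued attribute; attributes left as None are skipped.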
return {
name: len(getattr(self, name))
for name in self.__slots__
if isinstance(getattr(self, name), torch.Tensor)
}

@overrides
def as_tensor(self, padding_lengths: Dict[str, int]) -> Dict[str, torch.Tensor]:
result = {}
for name, padding_length in padding_lengths.items():
tensor = getattr(self, name)
result[name] = torch.nn.functional.pad(
tensor,
(0, padding_length - len(tensor)),
                value=self.padding_token_id if name == "input_ids" else 0,
)
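        # If no attention mask was given, build one that marks the real tokens as True
        # and the padding as False.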
if "attention_mask" not in result:
result["attention_mask"] = torch.tensor(
[True] * len(self.input_ids)
+ [False] * (padding_lengths["input_ids"] - len(self.input_ids)),
dtype=torch.bool,
)
return result

@overrides
def empty_field(self):
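        # Empty fields are used as padding when batching; keep the same padding_token_id
        # so this field's padding stays consistent with the rest of the batch.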
return TransformerTextField(torch.LongTensor(), padding_token_id=self.padding_token_id)

@overrides
def batch_tensors(self, tensor_list: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
result: Dict[str, torch.Tensor] = util.batch_tensor_dicts(tensor_list)
# Transformer models need LongTensors for indices, just in case we have more than 2 billion
# different tokens. To save space, we make the switch as late as possible, i.e., here.
result = {
name: t.to(torch.int64) if t.dtype == torch.int32 else t for name, t in result.items()
}
return result

def human_readable_repr(self) -> Dict[str, Any]:
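        # Display each tensor as a short list of ids, eliding the middle of long tensors.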
def format_item(x) -> str:
return str(x.item())

def readable_tensor(t: torch.Tensor) -> str:
if len(t) <= 16:
return "[" + ", ".join(map(format_item, t)) + "]"
else:
return (
"["
+ ", ".join(map(format_item, t[:8]))
+ ", ..., "
+ ", ".join(map(format_item, t[-8:]))
+ "]"
)

return {
name: readable_tensor(getattr(self, name))
for name in self.__slots__
if isinstance(getattr(self, name), torch.Tensor)
}
40 changes: 40 additions & 0 deletions tests/data/fields/transformer_text_field_test.py
@@ -0,0 +1,40 @@
import torch

from allennlp.data import Batch, Instance
from allennlp.data.fields import TransformerTextField


def test_transformer_text_field_init():
field = TransformerTextField(torch.IntTensor([1, 2, 3]))
field_as_tensor = field.as_tensor(field.get_padding_lengths())
assert "input_ids" in field_as_tensor
assert "attention_mask" in field_as_tensor
assert torch.all(field_as_tensor["attention_mask"] == torch.BoolTensor([True, True, True]))
assert torch.all(field_as_tensor["input_ids"] == torch.IntTensor([1, 2, 3]))


def test_empty_transformer_text_field():
field = TransformerTextField(torch.IntTensor([]), padding_token_id=9)
field = field.empty_field()
assert isinstance(field, TransformerTextField) and field.padding_token_id == 9
field_as_tensor = field.as_tensor(field.get_padding_lengths())
assert "input_ids" in field_as_tensor
assert "attention_mask" in field_as_tensor
assert torch.all(field_as_tensor["attention_mask"] == torch.BoolTensor([]))
assert torch.all(field_as_tensor["input_ids"] == torch.IntTensor([]))


def test_transformer_text_field_batching():
batch = Batch(
[
Instance({"text": TransformerTextField(torch.IntTensor([1, 2, 3]))}),
Instance({"text": TransformerTextField(torch.IntTensor([2, 3, 4, 5]))}),
Instance({"text": TransformerTextField(torch.IntTensor())}),
]
)
tensors = batch.as_tensor_dict(batch.get_padding_lengths())
assert tensors["text"]["input_ids"].shape == (3, 4)
assert tensors["text"]["input_ids"][0, -1] == 0
    assert tensors["text"]["attention_mask"][0, -1] == torch.tensor([False])
assert torch.all(tensors["text"]["input_ids"][-1] == 0)
assert torch.all(tensors["text"]["attention_mask"][-1] == torch.tensor([False]))