
Transformer Toolkit fixes #5303

Merged: 9 commits, Jul 8, 2021
CHANGELOG.md: 3 additions & 0 deletions
@@ -19,6 +19,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `TransformerTextField`, for cases where you don't care about AllenNLP's advanced text handling capabilities.
- Added `TransformerModule._post_load_pretrained_state_dict_hook()` method. Can be used to modify `missing_keys` and `unexpected_keys` after
  loading a pretrained state dictionary. This is useful when tying weights, for example (a sketch follows this list).
- Added an end-to-end test for the Transformer Toolkit.
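
The `_post_load_pretrained_state_dict_hook()` method mentioned above does not appear elsewhere in this diff, so the snippet below is only a hedged sketch: the method name comes from the changelog entry, while the signature, the `TiedDecoderModule` subclass, and the `decoder.weight` key are assumptions made for illustration.

```python
from typing import List

from allennlp.modules.transformer import TransformerModule


class TiedDecoderModule(TransformerModule):  # hypothetical subclass, for illustration only
    def _post_load_pretrained_state_dict_hook(
        self, missing_keys: List[str], unexpected_keys: List[str]  # assumed signature
    ) -> None:
        # The output projection shares its weight with the input embeddings, so the
        # pretrained state dict legitimately has no entry for it; dropping it from
        # `missing_keys` avoids a spurious "missing key" warning.
        if "decoder.weight" in missing_keys:
            missing_keys.remove("decoder.weight")
```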

### Fixed

@@ -32,10 +33,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Implemented slightly faster label smoothing.
- Fixed the docs for `PytorchTransformerWrapper`
- Fixed recovering training jobs with models that expect `get_metrics()` to not be called until they have seen at least one batch.
- Made the Transformer Toolkit compatible with transformers that don't start their positional embeddings at 0.

### Changed

- Changed behavior of `MultiOptimizer` so that while a default optimizer is still required, an error is not thrown if the default optimizer receives no parameters.
- Made the epsilon parameter for the layer normalization in token embeddings configurable (a usage sketch follows this list).
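
As a minimal usage sketch of the configurable epsilon (the sizes and values below are illustrative, not taken from this PR), `TransformerEmbeddings` can now be built with a custom `layer_norm_eps` instead of the hard-coded `1e-12`. The keyword names match the constructor changed later in this diff.

```python
from allennlp.modules.transformer import TransformerEmbeddings

embeddings = TransformerEmbeddings(
    vocab_size=30522,
    embedding_size=768,
    pad_token_id=0,
    max_position_embeddings=512,
    position_pad_token_id=None,  # set to the pad token id for RoBERTa-style position ids
    type_vocab_size=2,
    dropout=0.1,
    layer_norm_eps=1e-5,  # e.g. to match checkpoints trained with eps=1e-5 rather than 1e-12
)
```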

### Removed

allennlp/data/fields/transformer_text_field.py: 1 addition & 1 deletion
@@ -59,7 +59,7 @@ def as_tensor(self, padding_lengths: Dict[str, int]) -> Dict[str, torch.Tensor]:
result[name] = torch.nn.functional.pad(
tensor,
(0, padding_length - len(tensor)),
value=self.padding_token_id if name == "token_ids" else 0,
value=self.padding_token_id if name == "input_ids" else 0,
)
if "attention_mask" not in result:
result["attention_mask"] = torch.tensor(
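The one-line fix above changes which key receives the tokenizer's padding id. The standalone sketch below reproduces that rule outside of AllenNLP (the `pad_field_tensors` helper is made up for illustration): only `input_ids` is padded with `padding_token_id`, and every other tensor is padded with zeros.

```python
from typing import Dict

import torch


def pad_field_tensors(
    tensors: Dict[str, torch.Tensor], padding_length: int, padding_token_id: int
) -> Dict[str, torch.Tensor]:
    # Mirrors the corrected condition: the padding id applies to "input_ids"
    # (not the old, incorrect "token_ids" key); everything else is zero-padded.
    return {
        name: torch.nn.functional.pad(
            tensor,
            (0, padding_length - len(tensor)),
            value=padding_token_id if name == "input_ids" else 0,
        )
        for name, tensor in tensors.items()
    }


# Pad a three-token example out to length five.
padded = pad_field_tensors(
    {
        "input_ids": torch.tensor([101, 7592, 102]),
        "token_type_ids": torch.tensor([0, 0, 0]),
    },
    padding_length=5,
    padding_token_id=0,
)
# padded["input_ids"] -> tensor([101, 7592, 102, 0, 0])
```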
allennlp/modules/transformer/transformer_embeddings.py: 36 additions & 11 deletions
@@ -27,7 +27,13 @@ class Embeddings(TransformerModule, FromParams):
The probability of an element to be zeroed.
"""

def __init__(self, embeddings: torch.nn.ModuleDict, embedding_size: int, dropout: float):
def __init__(
self,
embeddings: torch.nn.ModuleDict,
embedding_size: int,
dropout: float,
layer_norm_eps: float = 1e-12, # different from Huggingface!
):
super().__init__()
for name, embedding_layer in embeddings.named_children():
if isinstance(embedding_layer, torch.nn.Embedding):
@@ -41,7 +47,7 @@ def __init__(self, embeddings: torch.nn.ModuleDict, embedding_size: int, dropout
)
)
self.embeddings = embeddings
self.layer_norm = LayerNorm(embedding_size, eps=1e-12)
self.layer_norm = LayerNorm(embedding_size, eps=layer_norm_eps)
self.dropout = torch.nn.Dropout(dropout)

def forward(self, *inputs) -> torch.Tensor:
@@ -131,8 +137,10 @@ def __init__(
embedding_size: int,
pad_token_id: int = 0,
max_position_embeddings: int = 512,
position_pad_token_id: Optional[int] = None,
type_vocab_size: int = 2,
dropout: float = 0.1,
layer_norm_eps: float = 1e-12, # different from Huggingface!
output_size: Optional[int] = None,
):
embedding_dict = {}
@@ -141,7 +149,9 @@ def __init__(
embedding_dict["word_embeddings"] = word_embeddings

if max_position_embeddings > 0:
position_embeddings = torch.nn.Embedding(max_position_embeddings, embedding_size)
position_embeddings = torch.nn.Embedding(
max_position_embeddings, embedding_size, padding_idx=position_pad_token_id
)
embedding_dict["position_embeddings"] = position_embeddings

if type_vocab_size > 0:
@@ -150,7 +160,7 @@ def __init__(

embeddings = torch.nn.ModuleDict(embedding_dict)

super().__init__(embeddings, embedding_size, dropout)
super().__init__(embeddings, embedding_size, dropout, layer_norm_eps=layer_norm_eps)

# For Albert, the embedding size is different than the hidden size used
# in the model, so a linear transform is applied.
@@ -183,10 +193,21 @@ def forward( # type: ignore

embedding_inputs = [input_ids]

if attention_mask is None:
attention_mask = input_ids != self.embeddings["word_embeddings"].padding_idx

if "position_embeddings" in self.embeddings:
if position_ids is None:
position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
position_ids = position_ids.unsqueeze(0).expand(input_shape)
padding_idx = self.embeddings["position_embeddings"].padding_idx
if padding_idx is None:
position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
position_ids = position_ids.unsqueeze(0).expand(input_shape)
else:
# In the RoBERTa case, position indices start with padding_idx + 1. Also, RoBERTa likes
# to respect padding in its position ids.
position_ids = torch.arange(seq_length, dtype=torch.long, device=device) + 1
position_ids = position_ids.unsqueeze(0).expand(input_shape) * attention_mask
position_ids += padding_idx
embedding_inputs.append(position_ids)

if "token_type_embeddings" in self.embeddings:
@@ -203,17 +224,21 @@

@classmethod
def _from_config(cls, config: "PretrainedConfig", **kwargs):
final_kwargs = {}
final_kwargs["vocab_size"] = config.vocab_size
final_kwargs = {
"vocab_size": config.vocab_size,
"pad_token_id": config.pad_token_id,
"max_position_embeddings": config.max_position_embeddings,
"type_vocab_size": config.type_vocab_size,
"layer_norm_eps": config.layer_norm_eps,
}
# For Albert, the embedding size is different than the hidden size used
# in the model, so a linear transform is applied.
if hasattr(config, "embedding_size"):
final_kwargs["embedding_size"] = config.embedding_size
final_kwargs["output_size"] = config.hidden_size
else:
final_kwargs["embedding_size"] = config.hidden_size
final_kwargs["pad_token_id"] = config.pad_token_id
final_kwargs["max_position_embeddings"] = config.max_position_embeddings
final_kwargs["type_vocab_size"] = config.type_vocab_size
if config.model_type == "roberta":
final_kwargs["position_pad_token_id"] = config.pad_token_id
final_kwargs.update(**kwargs)
return cls(**final_kwargs)
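
To make the RoBERTa branch in the `forward()` change above concrete, here is a small worked example (the numbers are illustrative; `padding_idx = 1` matches RoBERTa's pad token id): real tokens get positions starting at `padding_idx + 1`, while padded slots collapse back to `padding_idx`.

```python
import torch

padding_idx = 1          # RoBERTa reuses its pad token id as the position padding index
seq_length = 5
attention_mask = torch.tensor([[1, 1, 1, 0, 0]])  # last two tokens are padding

position_ids = torch.arange(seq_length, dtype=torch.long) + 1  # tensor([1, 2, 3, 4, 5])
position_ids = position_ids.unsqueeze(0) * attention_mask      # tensor([[1, 2, 3, 0, 0]])
position_ids = position_ids + padding_idx                      # tensor([[2, 3, 4, 1, 1]])

# Real tokens are numbered 2, 3, 4 (padding_idx + 1 onward) and padded slots fall back to
# padding_idx, matching Huggingface's RoBERTa position ids for right-padded batches.
```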
tests/modules/transformer/toolkit_test.py: 32 additions & 1 deletion
@@ -1,3 +1,4 @@
import pytest
import torch
from torch.testing import assert_allclose
from overrides import overrides
@@ -7,7 +8,7 @@
from allennlp.common import cached_transformers
from allennlp.data.vocabulary import Vocabulary
from allennlp.modules.token_embedders import Embedding, TokenEmbedder
from allennlp.modules.transformer import TransformerStack, TransformerEmbeddings
from allennlp.modules.transformer import TransformerStack, TransformerEmbeddings, TransformerPooler
from allennlp.common.testing import AllenNlpTestCase


@@ -168,3 +169,33 @@ def forward(self, token_ids: torch.LongTensor, mask: torch.BoolTensor):
almost = AlmostRegularTransformer()
assert len(almost.transformer.layers) == 12
assert isinstance(almost.embeddings, AlbertEmbeddings)

@pytest.mark.parametrize("model_name", ["bert-base-cased", "roberta-base"])
def test_end_to_end(self, model_name: str):
data = [
("I'm against picketing", "but I don't know how to show it."),
("I saw a human pyramid once.", "It was very unnecessary."),
]
tokenizer = cached_transformers.get_tokenizer(model_name)
batch = tokenizer.batch_encode_plus(data, padding=True, return_tensors="pt")

with torch.no_grad():
huggingface_model = cached_transformers.get(model_name, make_copy=False).eval()
huggingface_output = huggingface_model(**batch)

embeddings = TransformerEmbeddings.from_pretrained_module(model_name).eval()
transformer_stack = TransformerStack.from_pretrained_module(model_name).eval()
pooler = TransformerPooler.from_pretrained_module(model_name).eval()
batch["attention_mask"] = batch["attention_mask"].to(torch.bool)
output = embeddings(**batch)
output = transformer_stack(output, batch["attention_mask"])

assert_allclose(
output.final_hidden_states,
huggingface_output.last_hidden_state,
rtol=0.0001,
atol=1e-4,
)

output = pooler(output.final_hidden_states)
assert_allclose(output, huggingface_output.pooler_output, rtol=0.0001, atol=1e-4)
tests/modules/transformer/transformer_embeddings_test.py: 1 addition & 1 deletion
@@ -109,7 +109,7 @@ def forward(
torch.manual_seed(23)
text = TextEmbeddings(10, 5, 2, 3, 7, 0.0)
torch.manual_seed(23)
transformer = TransformerEmbeddings(10, 5, 2, 3, 7, 0.0)
transformer = TransformerEmbeddings(10, 5, 2, 3, None, 7, 0.0)

input_ids = torch.tensor([[1, 2]])
token_type_ids = torch.tensor([[1, 0]], dtype=torch.long)