This repository was archived by the owner on Dec 16, 2022. It is now read-only.

Generalizing self attention #4756

Merged
merged 6 commits on Nov 5, 2020
Changes from 5 commits
122 changes: 103 additions & 19 deletions allennlp/modules/transformer/self_attention.py
@@ -1,4 +1,4 @@
from typing import Optional
from typing import Optional, Dict
import torch

from allennlp.common import FromParams
@@ -13,14 +13,27 @@ class SelfAttention(TransformerModule, FromParams):
Details in the paper:
[BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding, Devlin et al, 2019]
(https://api.semanticscholar.org/CorpusID:52967399)

# Parameters

hidden_size: `int`
num_attention_heads: `int`
dropout: `float` (default = `0.0`)
scoring_func: `str` (default = `scaled_dot_product`)
The name of the attention-calculating function to be used.
E.g. `additive`, `linear`, etc. For a complete list, please check :mod:`allennlp.modules.attention`.
"""

_relevant_module = ["encoder.layers.0.attention.self", "encoder.layers.0.attention"]
_huggingface_mapping = {"layer": "layers"}

def __init__(
self,
hidden_size: int,
num_attention_heads: int,
dropout: float = 0.0,
scoring_func: str = "scaled_dot_product",
output_linear: bool = False,
):
super().__init__()
if hidden_size % num_attention_heads != 0:
@@ -45,6 +58,10 @@ def __init__(
else:
self.attn = Attention.by_name(self.scoring_func)()

# Output linear layer (used by DistilBERT-style attention).
if output_linear:
self.output = torch.nn.Linear(hidden_size, self.all_head_size)

self.dropout = torch.nn.Dropout(dropout)

def _transpose_for_scores(self, x):
@@ -57,32 +74,50 @@ def _transpose_for_scores(self, x):

def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
query_states: torch.Tensor,
key_states: Optional[torch.Tensor] = None,
value_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.BoolTensor] = None,
head_mask: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
):
mixed_query_layer = self.query(hidden_states)

# If this is instantiated as a cross-attention module, the keys
# and values come from an encoder; the attention mask needs to be
# such that the encoder's padding tokens are not attended to.
if encoder_hidden_states is not None:
mixed_key_layer = self.key(encoder_hidden_states)
mixed_value_layer = self.value(encoder_hidden_states)
attention_mask = encoder_attention_mask
else:
mixed_key_layer = self.key(hidden_states)
mixed_value_layer = self.value(hidden_states)
"""
query_states : `torch.Tensor`
Shape `batch_size x seq_len x hidden_dim`
key_states : `torch.Tensor`, optional
Shape `batch_size x seq_len x hidden_dim`
value_states : `torch.Tensor`, optional
Shape `batch_size x seq_len x hidden_dim`
attention_mask : `torch.BoolTensor`, optional
Shape `batch_size x seq_len`
head_mask : `torch.BoolTensor`, optional
output_attentions : `bool`
Whether to also return the attention probabilities, default = `False`
"""
if key_states is None:
key_states = query_states
if value_states is None:
value_states = query_states

batch_size = query_states.size(0)
k_length = key_states.size(1)

mixed_query_layer = self.query(query_states)
mixed_key_layer = self.key(key_states)
mixed_value_layer = self.value(value_states)

query_layer = self._transpose_for_scores(mixed_query_layer)
key_layer = self._transpose_for_scores(mixed_key_layer)
value_layer = self._transpose_for_scores(mixed_value_layer)

attention_scores = self.attn(query_layer, key_layer.transpose(-1, -2))
attention_scores = attention_scores + attention_mask

if attention_mask is not None:
mask_reshp = (batch_size, 1, 1, k_length)
attention_mask = (attention_mask == 0).view(mask_reshp).expand_as(
attention_scores
) * -10e5
attention_scores = attention_scores + attention_mask

# Normalize the attention scores to probabilities.
attention_probs = torch.nn.Softmax(dim=-1)(attention_scores)
@@ -91,14 +126,63 @@ def forward(
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self.dropout(attention_probs)

# Mask heads if we want to
if head_mask is not None:
attention_probs = attention_probs * head_mask

context_layer = torch.matmul(attention_probs, value_layer)

context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(*new_context_layer_shape)

if hasattr(self, "output"):
context_layer = self.output(context_layer)

outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
return outputs

@classmethod
def _get_mapping(
cls, pretrained_module=None, source="huggingface", mapping: Optional[Dict[str, str]] = None
):
combined_mapping = {}
if "huggingface" in source:
combined_mapping.update(cls._huggingface_mapping)
if mapping is not None:
combined_mapping.update(mapping)
if pretrained_module is not None:
for name, _ in pretrained_module.named_modules():
if "q_lin" in name:
combined_mapping["q_lin"] = "query"
combined_mapping["k_lin"] = "key"
combined_mapping["v_lin"] = "value"
combined_mapping["out_lin"] = "output"
combined_mapping["transformer"] = "encoder"
break
return combined_mapping

@classmethod
def _get_input_arguments(
cls,
pretrained_module: torch.nn.Module,
source="huggingface",
mapping: Optional[Dict[str, str]] = None,
**kwargs,
):
submodules = cls._get_mapped_submodules(pretrained_module, source, mapping)
final_kwargs = {}

final_kwargs["hidden_size"] = submodules["query"].in_features
if hasattr(submodules[""], "num_attention_heads"):
final_kwargs["num_attention_heads"] = submodules[""].num_attention_heads
elif hasattr(submodules[""], "n_heads"):
final_kwargs["num_attention_heads"] = submodules[""].n_heads
final_kwargs["output_linear"] = True # Since this is the distilbert case.
else:
raise AttributeError("Cannot find a relevant attribute for number of heads.")

final_kwargs["dropout"] = submodules["dropout"].p

final_kwargs.update(**kwargs)

return final_kwargs
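
Note (illustration, not part of the diff): a minimal usage sketch of the generalized forward pass, assuming only the constructor defaults and shapes documented above; the tensors and sizes are invented.

import torch
from allennlp.modules.transformer.self_attention import SelfAttention

attention = SelfAttention(hidden_size=8, num_attention_heads=2)

query = torch.randn(2, 5, 8)  # (batch_size, seq_len, hidden_size)

# Plain self-attention: key_states and value_states default to query_states.
self_output = attention(query)[0]

# Cross-attention: keys and values come from a second sequence, with a boolean
# padding mask over the key positions (False / 0 marks padding).
memory = torch.randn(2, 7, 8)
memory_mask = torch.ones(2, 7, dtype=torch.bool)
cross_output = attention(query, memory, memory, attention_mask=memory_mask)[0]
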
2 changes: 1 addition & 1 deletion allennlp/modules/transformer/transformer_block.py
@@ -56,7 +56,7 @@ def forward(

layer_outputs = layer_module(
hidden_states,
0.0,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
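
Note (illustration, not part of the diff): the block now forwards the real attention_mask instead of the hard-coded 0.0, so what reaches SelfAttention is an ordinary boolean padding mask. A hedged sketch of how such a mask is typically built from sequence lengths (the helper name is invented):

import torch

def make_padding_mask(lengths, max_len):
    # True for real tokens, False for padding, shape (batch_size, max_len);
    # SelfAttention turns the False positions into a large negative additive bias.
    positions = torch.arange(max_len).unsqueeze(0)
    return positions < lengths.unsqueeze(1)

mask = make_padding_mask(torch.tensor([3, 5]), max_len=5)
# tensor([[ True,  True,  True, False, False],
#         [ True,  True,  True,  True,  True]])
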
29 changes: 27 additions & 2 deletions allennlp/modules/transformer/transformer_layer.py
@@ -12,6 +12,8 @@


class AttentionLayer(TransformerModule, FromParams):
_relevant_module = "encoder.layers.0.attention"

def __init__(
self,
hidden_size: int,
@@ -32,18 +34,41 @@ def forward(
encoder_attention_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
):
if encoder_attention_mask is not None:
attention_mask = encoder_attention_mask
self_output = self.self(
input_tensor,
encoder_hidden_states,
encoder_hidden_states,
attention_mask,
head_mask,
encoder_hidden_states,
encoder_attention_mask,
output_attentions,
)
attention_output = self.output(self_output[0], input_tensor)
outputs = (attention_output,) + self_output[1:] # add attentions if we output them
return outputs

@classmethod
def _get_input_arguments(
cls,
pretrained_module: torch.nn.Module,
source="huggingface",
mapping: Optional[Dict[str, str]] = None,
**kwargs,
):
submodules = cls._get_mapped_submodules(pretrained_module, source, mapping)

final_kwargs = {}

final_kwargs["hidden_size"] = submodules["self.query"].in_features
final_kwargs["num_attention_heads"] = submodules["self"].num_attention_heads
final_kwargs["attention_dropout"] = submodules["self.dropout"].p
final_kwargs["hidden_dropout"] = submodules["output.dropout"].p

final_kwargs.update(**kwargs)

return final_kwargs


class TransformerLayer(TransformerModule, FromParams):
_relevant_module = "encoder.layers.0"
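
Note (illustration, not part of the diff): a rough sketch of the cross-attention path through the updated AttentionLayer. The constructor arguments mirror _get_input_arguments above, and the keyword names are assumed to match the forward signature; sizes are invented.

import torch
from allennlp.modules.transformer.transformer_layer import AttentionLayer

layer = AttentionLayer(
    hidden_size=8, num_attention_heads=2, attention_dropout=0.0, hidden_dropout=0.0
)

decoder_states = torch.randn(2, 4, 8)
encoder_states = torch.randn(2, 6, 8)
encoder_mask = torch.ones(2, 6, dtype=torch.bool)

# With encoder_hidden_states given, keys and values come from the encoder, and
# encoder_attention_mask overrides attention_mask before self.self is called.
outputs = layer(
    decoder_states,
    attention_mask=None,
    head_mask=None,
    encoder_hidden_states=encoder_states,
    encoder_attention_mask=encoder_mask,
)
attention_output = outputs[0]  # shape (2, 4, 8)
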
69 changes: 49 additions & 20 deletions allennlp/modules/transformer/transformer_module.py
@@ -17,54 +17,76 @@ class TransformerModule(torch.nn.Module):
any differences in the module names between the class modules and the huggingface model's
modules.

`_relevant_module` is an optional str which is the expected name of the module in
the huggingface pretrained model.
`_relevant_module` is an optional str or list of str which contains the expected name of the module
in the huggingface pretrained model. It can be a list to account for different names in different
models. The search is carried out in the order of the list.
"""

_huggingface_mapping: Dict[str, str] = {}
_relevant_module: Optional[str] = None
_relevant_module: Optional[Union[str, List[str]]] = None

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

@classmethod
def _get_mapped_submodules(
cls, pretrained_module, source="huggingface", mapping: Optional[Dict[str, str]] = None
def _get_mapping(
cls,
pretrained_module: Optional[torch.nn.Module] = None,
source="huggingface",
mapping: Optional[Dict[str, str]] = None,
):
"""
Subclasses overload this method, and provide appropriate name mapping based on the source.
Returns the mapping to be used, based on the optional `pretrained_module`.
If `pretrained_module` is not given, the default module-level mapping is returned.
"""
submodules = dict(pretrained_module.named_modules())
combined_mapping = {}
if "huggingface" in source:
combined_mapping.update(cls._huggingface_mapping)
if mapping is not None:
combined_mapping.update(mapping)
return combined_mapping

@classmethod
def _get_mapped_submodules(
cls, pretrained_module, source="huggingface", mapping: Optional[Dict[str, str]] = None
):
"""
Subclasses overload this method, and provide appropriate name mapping based on the source.
"""
submodules = dict(pretrained_module.named_modules())
combined_mapping = cls._get_mapping(pretrained_module, source, mapping)
for name, module in pretrained_module.named_modules():
newname = name
for key, val in combined_mapping.items():
newname = newname.replace(key, val)
submodules[newname] = submodules.pop(name)
return submodules

def _construct_default_mapping(self, source):
def _construct_default_mapping(
self,
pretrained_module,
source: str = "huggingface",
mapping: Optional[Dict[str, str]] = None,
):
"""
Recursively constructs the default mapping of parameter names for loading pretrained module weights.
Keys are parameter names from this module, and values are corresponding parameter names in the
expected pretrained module, as per `source`.
"""
mapping = {}
if "huggingface" in source:
mapping = self._huggingface_mapping
combined_mapping = self._get_mapping(pretrained_module, source, mapping)
for name, module in self.named_modules():
if name != "":
if hasattr(module, "_construct_default_mapping"):
# We handle collisions by giving priority to the outer module's mapping.
mapping = dict(
list(module._construct_default_mapping(source).items())
+ list(mapping.items())
combined_mapping = dict(
list(
module._construct_default_mapping(
pretrained_module, source, combined_mapping
).items()
)
+ list(combined_mapping.items())
)
return mapping
return combined_mapping

def _load_from_pretrained_module(
self,
Expand All @@ -79,7 +101,7 @@ def _load_from_pretrained_module(
between `pretrained_module` and the instance.
"""
ignore_absent_parameters = ignore_absent_parameters or []
combined_mapping = self._construct_default_mapping(source)
combined_mapping = self._construct_default_mapping(pretrained_module, source, mapping)
if mapping is not None:
combined_mapping.update(mapping)

@@ -119,7 +141,7 @@ def _get_input_arguments(
def get_relevant_module(
cls,
pretrained_module: Union[str, torch.nn.Module],
relevant_module: Optional[str] = None,
relevant_module: Optional[Union[str, List[str]]] = None,
source="huggingface",
mapping: Optional[Dict[str, str]] = None,
):
@@ -145,9 +167,16 @@ def get_relevant_module(
submodules = cls._get_mapped_submodules(pretrained_module, source, mapping)
# If the relevant_module is not found, we assume that the pretrained_module
# is already the relevant module.
if relevant_module in submodules:
pretrained_module = submodules[relevant_module]
else:
if isinstance(relevant_module, str):
relevant_module = [relevant_module]
found = False
for module in relevant_module:
if module in submodules:
pretrained_module = submodules[module]
found = True
break

if not found:
logger.warning(
"{} was not found! The submodules are: {}".format(
relevant_module, submodules.keys()
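
Note (illustration, not part of the diff): a small self-contained sketch of what _get_mapped_submodules does with the default "layer" -> "layers" huggingface mapping; the toy module is invented.

import torch

class ToyAttention(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.ModuleList([torch.nn.Linear(4, 4)])

toy = ToyAttention()
combined_mapping = {"layer": "layers"}  # as in _huggingface_mapping

# Mirrors the loop above: rename the keys of named_modules() via the mapping.
submodules = dict(toy.named_modules())
for name, module in toy.named_modules():
    newname = name
    for key, val in combined_mapping.items():
        newname = newname.replace(key, val)
    submodules[newname] = submodules.pop(name)

print(sorted(submodules.keys()))  # ['', 'layers', 'layers.0']
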
5 changes: 4 additions & 1 deletion tests/modules/transformer/bimodal_encoder_test.py
@@ -73,7 +73,10 @@ def test_loading_from_pretrained_weights(self):
kwargs = {key: self.params_dict[key] for key in required_kwargs}
module = BiModalEncoder.from_pretrained_module(pretrained_module, **kwargs)
mapping = {
val: key for key, val in module._construct_default_mapping("huggingface").items()
val: key
for key, val in module._construct_default_mapping(
pretrained_module, "huggingface", {}
).items()
}
assert_equal_parameters(
pretrained_module,
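
Note (illustration, not part of the diff): the intended end-to-end flow, assuming from_pretrained_module (used with BiModalEncoder in the test above) is also available on SelfAttention, is that a DistilBERT attention block can now be loaded directly thanks to the q_lin/k_lin/v_lin/out_lin mapping and the list-valued _relevant_module.

from transformers import AutoModel
from allennlp.modules.transformer.self_attention import SelfAttention

distilbert = AutoModel.from_pretrained("distilbert-base-uncased")

# After remapping ("transformer" -> "encoder", "layer" -> "layers"), DistilBERT's
# transformer.layer.0.attention matches the second entry of _relevant_module, and
# _get_input_arguments reads n_heads and sets output_linear=True.
attention = SelfAttention.from_pretrained_module(distilbert)
# attention.output is the out_lin-style projection created by output_linear=True.
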