From 70f8f92b9b84cec9c73bd6a7a525541954f35eea Mon Sep 17 00:00:00 2001 From: Akshita Bhagia Date: Thu, 29 Apr 2021 01:09:19 -0700 Subject: [PATCH 01/13] initial commit --- .../transformer/general_self_attention.py | 587 ++++++++++++++++++ allennlp/modules/transformer/t5.py | 19 +- .../transformer/self_attention_test.py | 33 +- .../transformer/t5_self_attention_test.py | 92 +++ 4 files changed, 707 insertions(+), 24 deletions(-) create mode 100644 allennlp/modules/transformer/general_self_attention.py create mode 100644 tests/modules/transformer/t5_self_attention_test.py diff --git a/allennlp/modules/transformer/general_self_attention.py b/allennlp/modules/transformer/general_self_attention.py new file mode 100644 index 00000000000..c70f0b41df7 --- /dev/null +++ b/allennlp/modules/transformer/general_self_attention.py @@ -0,0 +1,587 @@ +import math +from typing import Optional, Dict, Union, Tuple +from dataclasses import dataclass +import torch +import torch.nn.functional as F + +from allennlp.common import FromParams +from allennlp.modules.attention import Attention +from allennlp.modules.transformer.transformer_module import TransformerModule +from allennlp.modules.transformer.util import apply_mask + + +class GeneralSelfAttention(TransformerModule, FromParams): + """ + TODO + """ + + def __init__( + self, + hidden_size: int = 512, + attention_head_size: int = 64, + num_attention_heads: int = 8, + # has_relative_attention_bias: bool = False, # t5 + # relative_attention_num_buckets: int = 32, # t5 + # is_decoder: bool = False, # t5 + scoring_func: str = "scaled_dot_product", + output_linear: bool = False, + dropout: float = 0.0, + bias: bool = True, + normalize_weights: bool = False, + ): + + super().__init__() + + if hidden_size % num_attention_heads != 0: + raise ValueError( + "The hidden size (%d) is not a multiple of the number of attention " + "heads (%d)" % (hidden_size, num_attention_heads) + ) + + self.hidden_size = hidden_size + self.num_attention_heads = num_attention_heads + self.attention_head_size = attention_head_size + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = torch.nn.Linear(hidden_size, self.all_head_size, bias=bias) + self.key = torch.nn.Linear(hidden_size, self.all_head_size, bias=bias) + self.value = torch.nn.Linear(hidden_size, self.all_head_size, bias=bias) + + if output_linear: + self.output = torch.nn.Linear(hidden_size, self.all_head_size, bias=bias) + + self.scoring_func = scoring_func + if self.scoring_func in ["additive", "linear", "bilinear"]: + self.attn = Attention.by_name(self.scoring_func)(hidden_size, hidden_size) + elif self.scoring_func == "scaled_dot_product": + self.attn = Attention.by_name(self.scoring_func)(self.attention_head_size, False) + else: + self.attn = Attention.by_name(self.scoring_func)() + + # self.is_decoder = is_decoder + # self.has_relative_attention_bias = has_relative_attention_bias + # self.relative_attention_num_buckets = relative_attention_num_buckets + + # if self.has_relative_attention_bias: + # self.relative_attention_bias = torch.nn.Embedding( + # self.relative_attention_num_buckets, self.num_attention_heads + # ) + + self.dropout = dropout + + if normalize_weights: + self._normalize() + + def _normalize(self): + self.query.weight.data.normal_( + mean=0.0, std=(self.hidden_size * self.attention_head_size) ** -0.5 + ) + self.key.weight.data.normal_(mean=0.0, std=self.hidden_size ** -0.5) + self.value.weight.data.normal_(mean=0.0, std=self.hidden_size ** -0.5) + + if hasattr(self, 
"output"): + self.output.weight.data.normal_( + mean=0.0, std=(self.num_attention_heads * self.attention_head_size) ** -0.5 + ) + + # if self.has_relative_attention_bias: + # self.relative_attention_bias.weight.data.normal_(mean=0.0, std=hidden_size ** -0.5) + + def _transpose_for_scores(self, x: torch.Tensor): + new_x_shape = x.size()[:-1] + ( + self.num_attention_heads, + self.attention_head_size, + ) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def _query_layer(self, query_states: torch.Tensor): + mixed_query_layer = self.query(query_states) + query_layer = self._transpose_for_scores(mixed_query_layer) + return query_layer + + def _key_layer(self, key_states: torch.Tensor, past_key_states: Optional[torch.Tensor] = None): + mixed_key_layer = self.key(key_states) + key_layer = self._transpose_for_scores(mixed_key_layer) + return key_layer + + def _value_layer( + self, value_states: torch.Tensor, past_value_states: Optional[torch.Tensor] = None + ): + mixed_value_layer = self.value(value_states) + value_layer = self._transpose_for_scores(mixed_value_layer) + return value_layer + + def _get_attention_probs( + self, + query_layer: torch.Tensor, + key_layer: torch.Tensor, + attention_mask: torch.Tensor, + head_mask: torch.Tensor, + position_bias: Optional[torch.Tensor] = None, + **kwargs, + ): + attention_scores = self.attn(query_layer, key_layer.transpose(-1, -2)) + + # if position_bias is None: + # if self.has_relative_attention_bias: + # position_bias = self.compute_bias(real_seq_length, key_length) + # else: + # position_bias = torch.zeros( + # (1, self.num_attention_heads, real_seq_length, key_length), + # device=scores.device, + # dtype=scores.dtype, + # ) + + # # if key and values are already calculated + # # we want only the last query position bias + # if past_key_value is not None: + # position_bias = position_bias[:, :, -seq_length:, :] + + # if mask is not None: + # # Shape: (batch_size, num_heads, seq_length, key_length) + # position_bias = apply_mask(position_bias, mask) + + # scores += position_bias + + if attention_mask is not None: + attention_scores = apply_mask(attention_scores, attention_mask) + + attention_probs = torch.nn.Softmax(dim=-1)(attention_scores) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. 
+ + attention_probs = F.dropout(attention_probs, p=self.dropout, training=self.training) + + if head_mask is not None: + attention_probs = attention_probs * head_mask + + return attention_probs + + def _output_layer(self, attention_probs: torch.Tensor, value_layer: torch.Tensor): + context_layer = torch.matmul(attention_probs, value_layer) + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + + if hasattr(self, "output"): + context_layer = self.output(context_layer) + + return context_layer + + def _get_key_value_states( + self, + query_states: torch.Tensor, + key_states: Optional[torch.Tensor] = None, + value_states: Optional[torch.Tensor] = None, + ): + if key_states is None: + key_states = query_states + if value_states is None: + value_states = query_states + return key_states, value_states + + def forward( + self, + query_states: torch.Tensor, + key_states: Optional[torch.Tensor] = None, + value_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.BoolTensor] = None, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ): + """ + query_states : `torch.Tensor` + Shape `batch_size x seq_len x hidden_dim` + key_states : `torch.Tensor`, optional + Shape `batch_size x seq_len x hidden_dim` + value_states : `torch.Tensor`, optional + Shape `batch_size x seq_len x hidden_dim` + attention_mask : `torch.BoolTensor`, optional + Shape `batch_size x seq_len` + head_mask : `torch.BoolTensor`, optional + output_attentions : `bool` + Whether to also return the attention probabilities, default = `False` + """ + # if key_states is None: + # key_states = query_states + # if value_states is None: + # value_states = query_states + + key_states, value_states = self._get_key_value_states( + query_states, key_states, value_states + ) + + query_layer = self._query_layer(query_states) + key_layer = self._key_layer(key_states) + value_layer = self._value_layer(value_states) + + attention_probs = self._get_attention_probs( + query_layer, key_layer, attention_mask, head_mask + ) + + context_layer = self._output_layer(attention_probs, value_layer) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + return outputs + + +# Unfortunately mypy is insane, so we have to wrap these in unions. 
+FloatT = Union[torch.FloatTensor] +IntT = Union[torch.IntTensor] +BoolT = Union[torch.BoolTensor] + + +@dataclass +class T5AttentionOutput: + hidden_states: FloatT + key_value_state: Optional[Tuple[FloatT, FloatT]] + position_bias: FloatT + attn_weights: Optional[FloatT] = None + + +class T5Attention(GeneralSelfAttention): + def __init__( + self, + is_decoder: bool = False, + hidden_size: int = 512, + key_value_proj_dim: int = 64, + num_heads: int = 8, + has_relative_attention_bias: bool = False, + relative_attention_num_buckets: int = 32, + dropout: float = 0.1, + normalize: bool = True, + ): + + super().__init__( + hidden_size=hidden_size, + attention_head_size=key_value_proj_dim, + num_attention_heads=num_heads, + output_linear=True, + dropout=dropout, + bias=False, + normalize_weights=normalize, + ) + + self.is_decoder = is_decoder + self.has_relative_attention_bias = has_relative_attention_bias + self.relative_attention_num_buckets = relative_attention_num_buckets + + if self.has_relative_attention_bias: + self.relative_attention_bias = torch.nn.Embedding( + self.relative_attention_num_buckets, self.num_attention_heads + ) + + @staticmethod + def _relative_position_bucket( + relative_position: IntT, + bidirectional: bool = True, + num_buckets: int = 32, + max_distance: int = 128, + ) -> IntT: + """ + Adapted from Mesh Tensorflow: + https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 + + Translate relative position to a bucket number for relative attention. The relative position is defined as + memory_position - query_position, i.e. the distance in tokens from the attending position to the + attended-to position. If bidirectional=False, then positive relative positions are invalid. We use smaller + buckets for small absolute relative_position and larger buckets for larger absolute relative_positions. All + relative positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the + same bucket. This should allow for more graceful generalization to longer sequences than the model has been + trained on. 
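+        For example, with the defaults here (bidirectional=True, num_buckets=32, max_distance=128),
+        non-positive and positive relative positions each get 16 buckets: distances 0-7 map to
+        their own exact bucket, distances 8-127 are spread logarithmically over the remaining 8
+        buckets, and every distance of 128 or more lands in the last of those buckets.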
+ + Args: + relative_position: an int32 Tensor + bidirectional: a boolean - whether the attention is bidirectional + num_buckets: an integer + max_distance: an integer + + Returns: + a Tensor with the same shape as relative_position, containing int32 values in the range + [0, num_buckets) + """ + relative_buckets = relative_position.new_zeros(relative_position.shape) + if bidirectional: + num_buckets //= 2 + relative_buckets += (relative_position > 0).to(torch.long) * num_buckets + relative_position = torch.abs(relative_position) + else: + relative_position = -torch.min(relative_position, torch.zeros_like(relative_position)) + # now relative_position is in the range [0, inf) + + # half of the buckets are for exact increments in positions + max_exact = num_buckets // 2 + is_small = relative_position < max_exact + + # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance + relative_postion_if_large = max_exact + ( + torch.log(relative_position.float() / max_exact) + / math.log(max_distance / max_exact) + * (num_buckets - max_exact) + ).to(torch.long) + relative_postion_if_large = torch.min( + relative_postion_if_large, torch.full_like(relative_postion_if_large, num_buckets - 1) + ) + + relative_buckets += torch.where(is_small, relative_position, relative_postion_if_large) + return relative_buckets + + def compute_bias(self, query_length: int, key_length: int) -> FloatT: + """ Compute binned relative position bias """ + context_position = torch.arange(query_length, dtype=torch.long)[:, None] + memory_position = torch.arange(key_length, dtype=torch.long)[None, :] + relative_position = memory_position - context_position # shape (query_length, key_length) + relative_position_bucket = self._relative_position_bucket( + relative_position, # shape (query_length, key_length) + bidirectional=(not self.is_decoder), + num_buckets=self.relative_attention_num_buckets, + ) + relative_position_bucket = relative_position_bucket.to( + self.relative_attention_bias.weight.device + ) + values = self.relative_attention_bias( + relative_position_bucket + ) # shape (query_length, key_length, num_heads) + values = values.permute([2, 0, 1]).unsqueeze( + 0 + ) # shape (1, num_heads, query_length, key_length) + return values + + def _get_attention_probs( + self, + query_layer: torch.Tensor, + key_layer: torch.Tensor, + attention_mask: torch.Tensor, + head_mask: torch.Tensor, + position_bias: torch.Tensor, + real_seq_length: int, + key_length: int, + query_length: int, + ): + # compute scores + scores = torch.matmul( + query_layer, key_layer.transpose(3, 2) + ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 + + if position_bias is None: + if self.has_relative_attention_bias: + position_bias = self.compute_bias(real_seq_length, key_length) + else: + position_bias = torch.zeros( + (1, self.num_attention_heads, real_seq_length, key_length), + device=scores.device, + dtype=scores.dtype, + ) + + # if key and values are already calculated + # we want only the last query position bias + # TODO: use past_key_value correctly!! 
+ # if past_key_value is not None: + # position_bias = position_bias[:, :, -seq_length:, :] + + if attention_mask is not None: + # Shape: (batch_size, num_heads, seq_length, key_length) + position_bias = apply_mask(position_bias, attention_mask) + + scores += position_bias + attn_weights = F.softmax(scores.float(), dim=-1).type_as( + scores + ) # (batch_size, num_heads, seq_length, key_length) + attn_weights = F.dropout( + attn_weights, p=self.dropout, training=self.training + ) # (batch_size, num_heads, seq_length, key_length) + + # Mask heads if we want to + if head_mask is not None: + attn_weights = attn_weights * head_mask + + return attn_weights + + def _get_key_value_states( + self, + query_states: torch.Tensor, + key_states: Optional[torch.Tensor] = None, + value_states: Optional[torch.Tensor] = None, + ): + # TODO: simplify + # FIX: past_key_value usage needs to be fixed. + past_key_value = None + if past_key_value is None: + # if key_value_states is None: # unnecessary check? + key_value_states = (query_states, query_states) + else: + if key_value_states is None: + # self-attn + # (batch_size, num_heads, key_length, dim_per_head) + hidden_states = torch.cat([past_key_value, query_states], dim=2) + else: + # cross-attn + hidden_states = past_key_value + + key_value_states = (hidden_states, hidden_states) + + return key_value_states + + def _get_seq_key_length(self, hidden_states, past_key_value, key_value_states, query_length): + batch_size, seq_length = hidden_states.shape[:2] + real_seq_length = seq_length + + if past_key_value is not None: + assert ( + len(past_key_value) == 2 + ), "past_key_value should have 2 past states: keys and values. Got {} past states".format( + len(past_key_value) + ) + real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length + + key_length = real_seq_length if key_value_states is None else key_value_states.shape[1] + + return real_seq_length, key_length + + def forward( + self, + hidden_states: torch.Tensor, + mask: Optional[torch.BoolTensor] = None, + key_value_states: Optional[FloatT] = None, + position_bias: Optional[FloatT] = None, + past_key_value: Optional[Tuple[FloatT, FloatT]] = None, + layer_head_mask: Optional[BoolT] = None, + query_length: Optional[int] = None, + use_cache: bool = False, + output_attentions: bool = False, + ) -> T5AttentionOutput: + """ + Self-attention (if key_value_states is None) or attention over source sentence (provided by + key_value_states). + """ + # Input is (batch_size, seq_length, dim) + # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length) + # past_key_value[0] is (batch_size, num_heads, q_len - 1, dim_per_head) + # batch_size, seq_length = hidden_states.shape[:2] + # real_seq_length = seq_length + + real_seq_length, key_length = self._get_seq_key_length( + hidden_states, past_key_value, key_value_states, query_length + ) + # FIX: use key value states. 
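+        # Passing None here means the keys and values are projected from `hidden_states` itself,
+        # i.e. plain self-attention; cross-attention inputs and the decoding cache are not wired
+        # up yet (hence the FIX note above).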
+ key_value_states = self._get_key_value_states(hidden_states, None, None) + + # get query states + query_states = self._query_layer( + hidden_states + ) # (batch_size, num_heads, seq_length, dim_per_head) + + key_states = self._key_layer(key_value_states[0]) + value_states = self._value_layer(key_value_states[1]) + + attn_weights = self._get_attention_probs( + query_states, + key_states, + mask, + layer_head_mask, + position_bias, + real_seq_length, + key_length, + query_length, + ) + + attn_output = self._output_layer(attn_weights, value_states) + + present_key_value_state = ( + (key_states, value_states) if (self.is_decoder and use_cache) else None + ) + outputs = T5AttentionOutput(attn_output, present_key_value_state, position_bias) + if output_attentions: + outputs.attn_weights = attn_weights + return outputs + + +class SelfAttention(GeneralSelfAttention): + """ + This module computes the self-attention, similar to the architecture in BERT. Additionally, the attention + scoring function can be specified. + Details in the paper: + [BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding, Devlin et al, 2019] + (https://api.semanticscholar.org/CorpusID:52967399) + + # Parameters + + hidden_size: `int` + num_attention_heads: `int` + dropout: `float` (default = `0.0`) + scoring_func: `str` (default = `scaled_dot_product`) + The name of the attention-calculating function to be used. + Eg. `additive`, `linear`, etc. For a complete list, please check :mod:`allennlp.modules.attention`. + """ + + _relevant_module = ["encoder.layers.0.attention.self", "encoder.layers.0.attention"] + _huggingface_mapping = {"layer": "layers"} + + def __init__( + self, + hidden_size: int, + num_attention_heads: int, + dropout: float = 0.0, + scoring_func: str = "scaled_dot_product", + output_linear: bool = False, + ): + + attention_head_size = int(hidden_size / num_attention_heads) + + super().__init__( + hidden_size=hidden_size, + attention_head_size=attention_head_size, + num_attention_heads=num_attention_heads, + scoring_func=scoring_func, + output_linear=output_linear, + dropout=dropout, + bias=True, + ) + + @classmethod + def _get_mapping( + cls, pretrained_module=None, source="huggingface", mapping: Optional[Dict[str, str]] = None + ): + combined_mapping = {} + if "huggingface" in source: + combined_mapping.update(cls._huggingface_mapping) + if mapping is not None: + combined_mapping.update(mapping) + if pretrained_module is not None: + for name, _ in pretrained_module.named_modules(): + if "q_lin" in name: + combined_mapping["q_lin"] = "query" + combined_mapping["k_lin"] = "key" + combined_mapping["v_lin"] = "value" + combined_mapping["out_lin"] = "output" + combined_mapping["transformer"] = "encoder" + break + return combined_mapping + + @classmethod + def _get_input_arguments( + cls, + pretrained_module: torch.nn.Module, + source="huggingface", + mapping: Optional[Dict[str, str]] = None, + **kwargs, + ): + submodules = cls._get_mapped_submodules(pretrained_module, source, mapping) + final_kwargs = {} + + final_kwargs["hidden_size"] = submodules["query"].in_features + if hasattr(submodules[""], "num_attention_heads"): + final_kwargs["num_attention_heads"] = submodules[""].num_attention_heads + elif hasattr(submodules[""], "n_heads"): + final_kwargs["num_attention_heads"] = submodules[""].n_heads + final_kwargs["output_linear"] = True # Since this is the distilbert case. 
+ else: + raise AttributeError("Cannot find a relevant attribute for number of heads.") + + final_kwargs["dropout"] = submodules["dropout"].p + + final_kwargs.update(**kwargs) + + return final_kwargs diff --git a/allennlp/modules/transformer/t5.py b/allennlp/modules/transformer/t5.py index 83305487b76..a7a045f74a3 100644 --- a/allennlp/modules/transformer/t5.py +++ b/allennlp/modules/transformer/t5.py @@ -15,6 +15,7 @@ from allennlp.common import FromParams, Params, Lazy, Registrable from allennlp.common.checks import ConfigurationError from allennlp.modules.transformer import TransformerModule +from allennlp.modules.transformer.general_self_attention import T5Attention as NewT5Attention from allennlp.modules.transformer.util import ( apply_mask, get_extended_attention_mask, @@ -132,6 +133,7 @@ def __init__( has_relative_attention_bias: bool = False, relative_attention_num_buckets: int = 32, dropout: float = 0.1, + normalize: bool = True, ): super().__init__() self.is_decoder = is_decoder @@ -153,12 +155,13 @@ def __init__( self.relative_attention_num_buckets, self.num_heads ) - self.q.weight.data.normal_(mean=0.0, std=(hidden_size * key_value_proj_dim) ** -0.5) - self.k.weight.data.normal_(mean=0.0, std=hidden_size ** -0.5) - self.v.weight.data.normal_(mean=0.0, std=hidden_size ** -0.5) - self.o.weight.data.normal_(mean=0.0, std=(num_heads * key_value_proj_dim) ** -0.5) - if self.has_relative_attention_bias: - self.relative_attention_bias.weight.data.normal_(mean=0.0, std=hidden_size ** -0.5) + if normalize: + self.q.weight.data.normal_(mean=0.0, std=(hidden_size * key_value_proj_dim) ** -0.5) + self.k.weight.data.normal_(mean=0.0, std=hidden_size ** -0.5) + self.v.weight.data.normal_(mean=0.0, std=hidden_size ** -0.5) + self.o.weight.data.normal_(mean=0.0, std=(num_heads * key_value_proj_dim) ** -0.5) + if self.has_relative_attention_bias: + self.relative_attention_bias.weight.data.normal_(mean=0.0, std=hidden_size ** -0.5) @staticmethod def _relative_position_bucket( @@ -385,7 +388,7 @@ def __init__( dropout: float = 0.1, ): super().__init__() - self.self_attention = self_attention or T5Attention() + self.self_attention = self_attention or NewT5Attention() self.layer_norm = layer_norm or T5LayerNorm(hidden_size=self.self_attention.hidden_size) self.dropout = nn.Dropout(dropout) @@ -436,7 +439,7 @@ def __init__( dropout: float = 0.1, ): super().__init__() - self.enc_dec_attention = enc_dec_attention or T5Attention( + self.enc_dec_attention = enc_dec_attention or NewT5Attention( is_decoder=True, has_relative_attention_bias=False ) self.layer_norm = layer_norm or T5LayerNorm(hidden_size=self.enc_dec_attention.hidden_size) diff --git a/tests/modules/transformer/self_attention_test.py b/tests/modules/transformer/self_attention_test.py index e29ae44cf9e..4c576844ae8 100644 --- a/tests/modules/transformer/self_attention_test.py +++ b/tests/modules/transformer/self_attention_test.py @@ -5,7 +5,9 @@ from allennlp.common import Params from allennlp.common import cached_transformers from allennlp.common.testing import assert_equal_parameters, AllenNlpTestCase -from allennlp.modules.transformer import SelfAttention + +# from allennlp.modules.transformer import SelfAttention +from allennlp.modules.transformer.general_self_attention import SelfAttention from allennlp.nn.util import min_value_of_dtype from transformers.models.bert.configuration_bert import BertConfig @@ -55,31 +57,30 @@ def get_modules(params_dict): class TestSelfAttention(AllenNlpTestCase): - def setup_method(self): - 
super().setup_method() + def test_can_construct_from_params(self): - self.params_dict = {key: val for key, val in PARAMS_DICT.items()} + params_dict = {key: val for key, val in PARAMS_DICT.items()} - params = Params(copy.deepcopy(self.params_dict)) + params = Params(copy.deepcopy(params_dict)) - self.self_attention = SelfAttention.from_params(params) + self_attention = SelfAttention.from_params(params) - def test_can_construct_from_params(self): - assert self.self_attention.num_attention_heads == self.params_dict["num_attention_heads"] - assert self.self_attention.attention_head_size == int( - self.params_dict["hidden_size"] / self.params_dict["num_attention_heads"] + assert self_attention.num_attention_heads == params_dict["num_attention_heads"] + assert self_attention.attention_head_size == int( + params_dict["hidden_size"] / params_dict["num_attention_heads"] ) assert ( - self.self_attention.all_head_size - == self.params_dict["num_attention_heads"] * self.self_attention.attention_head_size + self_attention.all_head_size + == params_dict["num_attention_heads"] * self_attention.attention_head_size ) - assert self.self_attention.query.in_features == self.params_dict["hidden_size"] - assert self.self_attention.key.in_features == self.params_dict["hidden_size"] - assert self.self_attention.value.in_features == self.params_dict["hidden_size"] + assert self_attention.query.in_features == params_dict["hidden_size"] + assert self_attention.key.in_features == params_dict["hidden_size"] + assert self_attention.value.in_features == params_dict["hidden_size"] - assert self.self_attention.dropout.p == self.params_dict["dropout"] + # assert self_attention.dropout.p == params_dict["dropout"] + assert self_attention.dropout == params_dict["dropout"] @pytest.mark.skip("Takes up too much memory") @pytest.mark.parametrize("module_name, hf_module", get_modules(PARAMS_DICT).items()) diff --git a/tests/modules/transformer/t5_self_attention_test.py b/tests/modules/transformer/t5_self_attention_test.py new file mode 100644 index 00000000000..cc69239a5a2 --- /dev/null +++ b/tests/modules/transformer/t5_self_attention_test.py @@ -0,0 +1,92 @@ +import copy +import torch + +from allennlp.common import Params +from allennlp.common.testing import AllenNlpTestCase + +# from allennlp.modules.transformer.t5 import T5Attention +from allennlp.modules.transformer.general_self_attention import T5Attention + +from transformers.models.t5.configuration_t5 import T5Config +from transformers.models.t5.modeling_t5 import T5Attention as HFT5Attention + +PARAMS_DICT = { + "hidden_size": 6, + "num_heads": 2, + "key_value_proj_dim": 3, + "dropout": 0.0, + "relative_attention_num_buckets": 2, +} + + +class TestT5Attention(AllenNlpTestCase): + def test_can_construct_from_params(self): + + params_dict = {key: val for key, val in PARAMS_DICT.items()} + + params = Params(copy.deepcopy(params_dict)) + + t5_attention = T5Attention.from_params(params) + + # the old one + # assert t5_attention.num_heads == params_dict["num_heads"] + # assert t5_attention.key_value_proj_dim == params_dict["key_value_proj_dim"] + + # assert ( + # t5_attention.inner_dim + # == params_dict["num_heads"] * params_dict["key_value_proj_dim"] + # ) + + # assert t5_attention.q.in_features == params_dict["hidden_size"] + # assert t5_attention.k.in_features == params_dict["hidden_size"] + # assert t5_attention.v.in_features == params_dict["hidden_size"] + # assert t5_attention.o.in_features == params_dict["hidden_size"] + + # assert t5_attention.dropout == 
params_dict["dropout"] + + # the new one + assert t5_attention.num_attention_heads == params_dict["num_heads"] + assert t5_attention.attention_head_size == params_dict["key_value_proj_dim"] + + assert ( + t5_attention.all_head_size + == params_dict["num_heads"] * params_dict["key_value_proj_dim"] + ) + + assert t5_attention.query.in_features == params_dict["hidden_size"] + assert t5_attention.key.in_features == params_dict["hidden_size"] + assert t5_attention.value.in_features == params_dict["hidden_size"] + assert t5_attention.output.in_features == params_dict["hidden_size"] + + assert t5_attention.dropout == params_dict["dropout"] + + def test_forward_against_huggingface_output(self): + hidden_states = torch.randn(2, 3, 6) + attention_mask = torch.tensor([[0, 1, 0], [1, 1, 0]]) + + hf_kwargs = { + "d_model": PARAMS_DICT["hidden_size"], + "d_kv": PARAMS_DICT["key_value_proj_dim"], + "num_heads": PARAMS_DICT["num_heads"], + "relative_attention_num_buckets": PARAMS_DICT["relative_attention_num_buckets"], + "dropout_rate": PARAMS_DICT["dropout"], + } + + torch.manual_seed(1234) + hf_module = HFT5Attention(T5Config(**hf_kwargs), has_relative_attention_bias=False) + + torch.manual_seed(1234) + + params = copy.deepcopy(PARAMS_DICT) + params["normalize"] = False # only for this test. + t5_attention = T5Attention(**params) + + # setting to eval mode to avoid non-deterministic dropout. + t5_attention = t5_attention.eval() + hf_module = hf_module.eval() + + output = t5_attention.forward(hidden_states, mask=attention_mask) + attention_mask_hf = (attention_mask == 0).view((2, 1, 1, 3)).expand(2, 2, 3, 3) * -10e5 + hf_output = hf_module.forward(hidden_states, mask=attention_mask_hf) + + assert torch.allclose(output.hidden_states, hf_output[0]) From ea93e9edc108dfa61e7d69c5a3e9881768af14dc Mon Sep 17 00:00:00 2001 From: Akshita Bhagia Date: Wed, 5 May 2021 02:14:07 -0700 Subject: [PATCH 02/13] general self attn --- .../transformer/general_self_attention.py | 471 +++++++++--------- allennlp/modules/transformer/t5.py | 12 +- .../transformer/self_attention_test.py | 4 +- .../transformer/t5_self_attention_test.py | 6 +- 4 files changed, 245 insertions(+), 248 deletions(-) diff --git a/allennlp/modules/transformer/general_self_attention.py b/allennlp/modules/transformer/general_self_attention.py index c70f0b41df7..111bc6d073a 100644 --- a/allennlp/modules/transformer/general_self_attention.py +++ b/allennlp/modules/transformer/general_self_attention.py @@ -9,6 +9,29 @@ from allennlp.modules.transformer.transformer_module import TransformerModule from allennlp.modules.transformer.util import apply_mask +# Unfortunately mypy is insane, so we have to wrap these in unions. +FloatT = Union[torch.FloatTensor] +IntT = Union[torch.IntTensor] +BoolT = Union[torch.BoolTensor] + + +@dataclass +class KeyValueState: + key_state: FloatT + value_state: FloatT + + +@dataclass +class GeneralSelfAttentionOutput: + """ + Encapsulates the outputs of the `GeneralSelfAttention` module. 
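+
+    `hidden_states` is the attended output; `key_value_state` carries the projected (key, value)
+    tensors that a decoder can cache and feed back in at the next step; `position_bias` is
+    returned so that it can be reused; and `attention_probs` holds the attention weights.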
+ """ + + hidden_states: FloatT + key_value_state: Optional[Tuple[FloatT, FloatT]] = None + position_bias: Optional[FloatT] = None + attention_probs: Optional[FloatT] = None + class GeneralSelfAttention(TransformerModule, FromParams): """ @@ -20,14 +43,15 @@ def __init__( hidden_size: int = 512, attention_head_size: int = 64, num_attention_heads: int = 8, - # has_relative_attention_bias: bool = False, # t5 - # relative_attention_num_buckets: int = 32, # t5 - # is_decoder: bool = False, # t5 scoring_func: str = "scaled_dot_product", output_linear: bool = False, dropout: float = 0.0, bias: bool = True, normalize_weights: bool = False, + is_decoder: bool = False, + is_cross_attention: bool = False, + has_relative_attention_bias: bool = False, + relative_attention_num_buckets: int = 32, ): super().__init__() @@ -38,6 +62,10 @@ def __init__( "heads (%d)" % (hidden_size, num_attention_heads) ) + if is_cross_attention: + assert is_decoder, "The attention layer can be a cross-attention layer only " + "if it is within a decoder." + self.hidden_size = hidden_size self.num_attention_heads = num_attention_heads self.attention_head_size = attention_head_size @@ -58,17 +86,19 @@ def __init__( else: self.attn = Attention.by_name(self.scoring_func)() - # self.is_decoder = is_decoder - # self.has_relative_attention_bias = has_relative_attention_bias - # self.relative_attention_num_buckets = relative_attention_num_buckets + self.has_relative_attention_bias = has_relative_attention_bias + self.relative_attention_num_buckets = relative_attention_num_buckets - # if self.has_relative_attention_bias: - # self.relative_attention_bias = torch.nn.Embedding( - # self.relative_attention_num_buckets, self.num_attention_heads - # ) + if self.has_relative_attention_bias: + self.relative_attention_bias = torch.nn.Embedding( + self.relative_attention_num_buckets, self.num_attention_heads + ) self.dropout = dropout + self.is_decoder = is_decoder + self.is_cross_attention = is_cross_attention + if normalize_weights: self._normalize() @@ -84,8 +114,8 @@ def _normalize(self): mean=0.0, std=(self.num_attention_heads * self.attention_head_size) ** -0.5 ) - # if self.has_relative_attention_bias: - # self.relative_attention_bias.weight.data.normal_(mean=0.0, std=hidden_size ** -0.5) + if hasattr(self, "has_relative_attention_bias") and self.has_relative_attention_bias: + self.relative_attention_bias.weight.data.normal_(mean=0.0, std=self.hidden_size ** -0.5) def _transpose_for_scores(self, x: torch.Tensor): new_x_shape = x.size()[:-1] + ( @@ -100,17 +130,52 @@ def _query_layer(self, query_states: torch.Tensor): query_layer = self._transpose_for_scores(mixed_query_layer) return query_layer - def _key_layer(self, key_states: torch.Tensor, past_key_states: Optional[torch.Tensor] = None): - mixed_key_layer = self.key(key_states) - key_layer = self._transpose_for_scores(mixed_key_layer) - return key_layer - - def _value_layer( - self, value_states: torch.Tensor, past_value_states: Optional[torch.Tensor] = None + def _project( + self, + query_states: torch.Tensor, + layer: torch.nn.Linear, + source_states: Optional[torch.Tensor] = None, + past_key_or_value_states: Optional[torch.Tensor] = None, ): - mixed_value_layer = self.value(value_states) - value_layer = self._transpose_for_scores(mixed_value_layer) - return value_layer + if self.is_decoder: + if self.is_cross_attention: + if past_key_or_value_states is None: + assert source_states is not None, "Encoder final state needs to be passed." 
+ query_states = source_states + else: + return past_key_or_value_states + + layer_output = layer(query_states) + layer_output = self._transpose_for_scores(layer_output) + if self.is_decoder: + layer_output = torch.cat([past_key_or_value_states, layer_output], dim=2) + + return layer_output + + def _position_bias( + self, + position_bias, + seq_lengths, + past_key_states, + attention_scores, + ): + seq_length, real_seq_length, key_length = seq_lengths + + if position_bias is None: + if self.has_relative_attention_bias: + position_bias = self.compute_bias(real_seq_length, key_length) + else: + position_bias = torch.zeros( + (1, self.num_attention_heads, real_seq_length, key_length), + device=attention_scores.device, + dtype=attention_scores.dtype, + ) + + # if key and values are already calculated + # we want only the last query position bias + if past_key_states is not None: + position_bias = position_bias[:, :, -seq_length:, :] + return position_bias def _get_attention_probs( self, @@ -119,45 +184,37 @@ def _get_attention_probs( attention_mask: torch.Tensor, head_mask: torch.Tensor, position_bias: Optional[torch.Tensor] = None, + seq_lengths: Optional[Tuple[int, int, int]] = None, + past_key_states: Optional[torch.Tensor] = None, **kwargs, ): attention_scores = self.attn(query_layer, key_layer.transpose(-1, -2)) - # if position_bias is None: - # if self.has_relative_attention_bias: - # position_bias = self.compute_bias(real_seq_length, key_length) - # else: - # position_bias = torch.zeros( - # (1, self.num_attention_heads, real_seq_length, key_length), - # device=scores.device, - # dtype=scores.dtype, - # ) - - # # if key and values are already calculated - # # we want only the last query position bias - # if past_key_value is not None: - # position_bias = position_bias[:, :, -seq_length:, :] + # return attention_scores - # if mask is not None: - # # Shape: (batch_size, num_heads, seq_length, key_length) - # position_bias = apply_mask(position_bias, mask) - - # scores += position_bias + position_bias = self._position_bias( + position_bias, seq_lengths, past_key_states, attention_scores + ) - if attention_mask is not None: - attention_scores = apply_mask(attention_scores, attention_mask) + if position_bias is not None: + if attention_mask is not None: + # Shape: (batch_size, num_heads, seq_length, key_length) + position_bias = apply_mask(position_bias, attention_mask) + attention_scores += position_bias + else: + if attention_mask is not None: + attention_scores = apply_mask(attention_scores, attention_mask) attention_probs = torch.nn.Softmax(dim=-1)(attention_scores) # This is actually dropping out entire tokens to attend to, which might # seem a bit unusual, but is taken from the original Transformer paper. 
- attention_probs = F.dropout(attention_probs, p=self.dropout, training=self.training) if head_mask is not None: attention_probs = attention_probs * head_mask - return attention_probs + return attention_probs, position_bias def _output_layer(self, attention_probs: torch.Tensor, value_layer: torch.Tensor): context_layer = torch.matmul(attention_probs, value_layer) @@ -170,108 +227,101 @@ def _output_layer(self, attention_probs: torch.Tensor, value_layer: torch.Tensor return context_layer - def _get_key_value_states( - self, - query_states: torch.Tensor, - key_states: Optional[torch.Tensor] = None, - value_states: Optional[torch.Tensor] = None, - ): - if key_states is None: - key_states = query_states - if value_states is None: - value_states = query_states - return key_states, value_states + def _get_lengths(self, query_states, past_key_states, source_states): + + seq_length = query_states.shape[1] + effective_seq_len = seq_length + + key_length = seq_length + + if past_key_states is not None: + # TODO: query_length from up the stack: move logic here. + # TODO: clarify the logic here. + effective_seq_len += past_key_states.shape[2] + if self.is_cross_attention: + key_length = source_states.shape[1] + + return (seq_length, effective_seq_len, key_length) def forward( self, query_states: torch.Tensor, - key_states: Optional[torch.Tensor] = None, - value_states: Optional[torch.Tensor] = None, + past_key_states: Optional[torch.Tensor] = None, + past_value_states: Optional[torch.Tensor] = None, attention_mask: Optional[torch.BoolTensor] = None, + source_states: Optional[torch.Tensor] = None, + source_attention_mask: Optional[torch.BoolTensor] = None, head_mask: Optional[torch.Tensor] = None, + position_bias: Optional[torch.Tensor] = None, output_attentions: bool = False, + use_cache: bool = False, ): """ query_states : `torch.Tensor` Shape `batch_size x seq_len x hidden_dim` - key_states : `torch.Tensor`, optional + past_key_states : `torch.Tensor`, optional Shape `batch_size x seq_len x hidden_dim` - value_states : `torch.Tensor`, optional + These are the key_states from the previous step of the decoder. + past_value_states : `torch.Tensor`, optional Shape `batch_size x seq_len x hidden_dim` + These are the value_states from the previous step of the decoder. attention_mask : `torch.BoolTensor`, optional Shape `batch_size x seq_len` + source_states : `torch.Tensor`, optional + Shape `batch_size x source_seq_len x hidden_dim` + This is from the final state of attention over the source (encoder); + it is passed when this module is being used for cross-attention. + source_attention_mask : `torch.BoolTensor`, optional + Shape `batch_size x source_seq_len` head_mask : `torch.BoolTensor`, optional + position_bias : `torch.Tensor`, optional output_attentions : `bool` Whether to also return the attention probabilities, default = `False` - """ - # if key_states is None: - # key_states = query_states - # if value_states is None: - # value_states = query_states - key_states, value_states = self._get_key_value_states( - query_states, key_states, value_states - ) + !!! Note + `source_states` needs to be passed in case of cross-attention. 
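+
+        !!! Note
+            When `use_cache` is `True` and this layer is a decoder, the projected key/value
+            tensors are returned in `key_value_state` so that they can be passed back in as
+            `past_key_states` / `past_value_states` at the next decoding step.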
+ """ query_layer = self._query_layer(query_states) - key_layer = self._key_layer(key_states) - value_layer = self._value_layer(value_states) - - attention_probs = self._get_attention_probs( - query_layer, key_layer, attention_mask, head_mask + key_layer = self._project( + query_states, + self.key, + source_states, + past_key_states, ) - context_layer = self._output_layer(attention_probs, value_layer) - - outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) - return outputs - - -# Unfortunately mypy is insane, so we have to wrap these in unions. -FloatT = Union[torch.FloatTensor] -IntT = Union[torch.IntTensor] -BoolT = Union[torch.BoolTensor] + value_layer = self._project( + query_states, + self.value, + source_states, + past_value_states, + ) + if self.is_cross_attention: + attention_mask = source_attention_mask -@dataclass -class T5AttentionOutput: - hidden_states: FloatT - key_value_state: Optional[Tuple[FloatT, FloatT]] - position_bias: FloatT - attn_weights: Optional[FloatT] = None + seq_lengths = self._get_lengths(query_states, past_key_states, source_states) + attention_probs, position_bias = self._get_attention_probs( + query_layer, + key_layer, + attention_mask, + head_mask, + position_bias, + seq_lengths, + past_key_states, + ) -class T5Attention(GeneralSelfAttention): - def __init__( - self, - is_decoder: bool = False, - hidden_size: int = 512, - key_value_proj_dim: int = 64, - num_heads: int = 8, - has_relative_attention_bias: bool = False, - relative_attention_num_buckets: int = 32, - dropout: float = 0.1, - normalize: bool = True, - ): + context_layer = self._output_layer(attention_probs, value_layer) - super().__init__( - hidden_size=hidden_size, - attention_head_size=key_value_proj_dim, - num_attention_heads=num_heads, - output_linear=True, - dropout=dropout, - bias=False, - normalize_weights=normalize, + present_key_value_state = ( + (key_layer, value_layer) if (self.is_decoder and use_cache) else None + ) + outputs = GeneralSelfAttentionOutput( + context_layer, present_key_value_state, position_bias, attention_probs ) - self.is_decoder = is_decoder - self.has_relative_attention_bias = has_relative_attention_bias - self.relative_attention_num_buckets = relative_attention_num_buckets - - if self.has_relative_attention_bias: - self.relative_attention_bias = torch.nn.Embedding( - self.relative_attention_num_buckets, self.num_attention_heads - ) + return outputs @staticmethod def _relative_position_bucket( @@ -349,152 +399,85 @@ def compute_bias(self, query_length: int, key_length: int) -> FloatT: ) # shape (1, num_heads, query_length, key_length) return values - def _get_attention_probs( - self, - query_layer: torch.Tensor, - key_layer: torch.Tensor, - attention_mask: torch.Tensor, - head_mask: torch.Tensor, - position_bias: torch.Tensor, - real_seq_length: int, - key_length: int, - query_length: int, - ): - # compute scores - scores = torch.matmul( - query_layer, key_layer.transpose(3, 2) - ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 - - if position_bias is None: - if self.has_relative_attention_bias: - position_bias = self.compute_bias(real_seq_length, key_length) - else: - position_bias = torch.zeros( - (1, self.num_attention_heads, real_seq_length, key_length), - device=scores.device, - dtype=scores.dtype, - ) - - # if key and values are already calculated - # we want only the last query position bias - # TODO: use past_key_value correctly!! 
- # if past_key_value is not None: - # position_bias = position_bias[:, :, -seq_length:, :] - - if attention_mask is not None: - # Shape: (batch_size, num_heads, seq_length, key_length) - position_bias = apply_mask(position_bias, attention_mask) - scores += position_bias - attn_weights = F.softmax(scores.float(), dim=-1).type_as( - scores - ) # (batch_size, num_heads, seq_length, key_length) - attn_weights = F.dropout( - attn_weights, p=self.dropout, training=self.training - ) # (batch_size, num_heads, seq_length, key_length) - - # Mask heads if we want to - if head_mask is not None: - attn_weights = attn_weights * head_mask +class T5Attention(GeneralSelfAttention): - return attn_weights + # _relevant_module = ["encoder.block.0.layer.0.self_attention"] + _huggingface_mapping = { + "q": "query", + "k": "key", + "v": "value", + "o": "output", + "layers": "layer", + } - def _get_key_value_states( + def __init__( self, - query_states: torch.Tensor, - key_states: Optional[torch.Tensor] = None, - value_states: Optional[torch.Tensor] = None, + is_decoder: bool = False, + hidden_size: int = 512, + key_value_proj_dim: int = 64, + num_heads: int = 8, + has_relative_attention_bias: bool = False, + relative_attention_num_buckets: int = 32, + dropout: float = 0.1, + normalize: bool = True, + is_cross_attention: bool = False, ): - # TODO: simplify - # FIX: past_key_value usage needs to be fixed. - past_key_value = None - if past_key_value is None: - # if key_value_states is None: # unnecessary check? - key_value_states = (query_states, query_states) - else: - if key_value_states is None: - # self-attn - # (batch_size, num_heads, key_length, dim_per_head) - hidden_states = torch.cat([past_key_value, query_states], dim=2) - else: - # cross-attn - hidden_states = past_key_value - - key_value_states = (hidden_states, hidden_states) - - return key_value_states - - def _get_seq_key_length(self, hidden_states, past_key_value, key_value_states, query_length): - batch_size, seq_length = hidden_states.shape[:2] - real_seq_length = seq_length - - if past_key_value is not None: - assert ( - len(past_key_value) == 2 - ), "past_key_value should have 2 past states: keys and values. Got {} past states".format( - len(past_key_value) - ) - real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length - key_length = real_seq_length if key_value_states is None else key_value_states.shape[1] + super().__init__( + hidden_size=hidden_size, + attention_head_size=key_value_proj_dim, + num_attention_heads=num_heads, + output_linear=True, + scoring_func="scaled_dot_product", + dropout=dropout, + bias=False, + normalize_weights=normalize, + is_decoder=is_decoder, + is_cross_attention=is_cross_attention, + has_relative_attention_bias=has_relative_attention_bias, + relative_attention_num_buckets=relative_attention_num_buckets, + ) - return real_seq_length, key_length + self.attn = Attention.by_name(self.scoring_func)(1, False) - def forward( + def forward( # type: ignore self, hidden_states: torch.Tensor, mask: Optional[torch.BoolTensor] = None, key_value_states: Optional[FloatT] = None, position_bias: Optional[FloatT] = None, - past_key_value: Optional[Tuple[FloatT, FloatT]] = None, + past_key_value: Optional[ + Tuple[FloatT, FloatT] + ] = None, # this is used when taking decoding steps. layer_head_mask: Optional[BoolT] = None, - query_length: Optional[int] = None, + query_length: Optional[int] = None, # only relevant in cross-attention. 
use_cache: bool = False, output_attentions: bool = False, - ) -> T5AttentionOutput: + ) -> GeneralSelfAttentionOutput: """ Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). """ - # Input is (batch_size, seq_length, dim) - # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length) - # past_key_value[0] is (batch_size, num_heads, q_len - 1, dim_per_head) - # batch_size, seq_length = hidden_states.shape[:2] - # real_seq_length = seq_length - - real_seq_length, key_length = self._get_seq_key_length( - hidden_states, past_key_value, key_value_states, query_length - ) - # FIX: use key value states. - key_value_states = self._get_key_value_states(hidden_states, None, None) - - # get query states - query_states = self._query_layer( - hidden_states - ) # (batch_size, num_heads, seq_length, dim_per_head) - - key_states = self._key_layer(key_value_states[0]) - value_states = self._value_layer(key_value_states[1]) - - attn_weights = self._get_attention_probs( - query_states, - key_states, - mask, - layer_head_mask, - position_bias, - real_seq_length, - key_length, - query_length, + if past_key_value: + past_key_states = past_key_value[0] + past_value_states = past_key_value[1] + else: + past_key_states = None + past_value_states = None + + outputs = super().forward( + query_states=hidden_states, + past_key_states=past_key_states, + past_value_states=past_value_states, + attention_mask=mask, + source_states=key_value_states, + source_attention_mask=None, # TODO: is this a bug in current T5 code? + head_mask=layer_head_mask, + position_bias=position_bias, + output_attentions=output_attentions, ) - attn_output = self._output_layer(attn_weights, value_states) - - present_key_value_state = ( - (key_states, value_states) if (self.is_decoder and use_cache) else None - ) - outputs = T5AttentionOutput(attn_output, present_key_value_state, position_bias) - if output_attentions: - outputs.attn_weights = attn_weights return outputs @@ -540,6 +523,12 @@ def __init__( bias=True, ) + def forward(self, *args, **kwargs): + outputs = super().forward(*args, **kwargs) + if outputs.attention_probs is not None: + return (outputs.hidden_states, outputs.attention_probs) + return (outputs.hidden_states,) + @classmethod def _get_mapping( cls, pretrained_module=None, source="huggingface", mapping: Optional[Dict[str, str]] = None diff --git a/allennlp/modules/transformer/t5.py b/allennlp/modules/transformer/t5.py index a7a045f74a3..8ad5e28470c 100644 --- a/allennlp/modules/transformer/t5.py +++ b/allennlp/modules/transformer/t5.py @@ -15,7 +15,8 @@ from allennlp.common import FromParams, Params, Lazy, Registrable from allennlp.common.checks import ConfigurationError from allennlp.modules.transformer import TransformerModule -from allennlp.modules.transformer.general_self_attention import T5Attention as NewT5Attention + +# from allennlp.modules.transformer.general_self_attention import T5Attention from allennlp.modules.transformer.util import ( apply_mask, get_extended_attention_mask, @@ -134,6 +135,7 @@ def __init__( relative_attention_num_buckets: int = 32, dropout: float = 0.1, normalize: bool = True, + is_cross_attention: bool = False, ): super().__init__() self.is_decoder = is_decoder @@ -388,7 +390,7 @@ def __init__( dropout: float = 0.1, ): super().__init__() - self.self_attention = self_attention or NewT5Attention() + self.self_attention = self_attention or T5Attention() self.layer_norm = layer_norm or 
T5LayerNorm(hidden_size=self.self_attention.hidden_size) self.dropout = nn.Dropout(dropout) @@ -439,8 +441,10 @@ def __init__( dropout: float = 0.1, ): super().__init__() - self.enc_dec_attention = enc_dec_attention or NewT5Attention( - is_decoder=True, has_relative_attention_bias=False + self.enc_dec_attention = enc_dec_attention or T5Attention( + is_decoder=True, + has_relative_attention_bias=False, + is_cross_attention=True, ) self.layer_norm = layer_norm or T5LayerNorm(hidden_size=self.enc_dec_attention.hidden_size) self.dropout = nn.Dropout(dropout) diff --git a/tests/modules/transformer/self_attention_test.py b/tests/modules/transformer/self_attention_test.py index 4c576844ae8..a34d287acd0 100644 --- a/tests/modules/transformer/self_attention_test.py +++ b/tests/modules/transformer/self_attention_test.py @@ -82,7 +82,7 @@ def test_can_construct_from_params(self): # assert self_attention.dropout.p == params_dict["dropout"] assert self_attention.dropout == params_dict["dropout"] - @pytest.mark.skip("Takes up too much memory") + # @pytest.mark.skip("Takes up too much memory") @pytest.mark.parametrize("module_name, hf_module", get_modules(PARAMS_DICT).items()) def test_forward_against_huggingface_output(self, module_name, hf_module): hidden_states = torch.randn(2, 3, 6) @@ -103,7 +103,7 @@ def test_forward_against_huggingface_output(self, module_name, hf_module): assert torch.allclose(output[0], hf_output[0]) - @pytest.mark.skip("Takes up too much memory") + # @pytest.mark.skip("Takes up too much memory") @pytest.mark.parametrize( "pretrained_name", [ diff --git a/tests/modules/transformer/t5_self_attention_test.py b/tests/modules/transformer/t5_self_attention_test.py index cc69239a5a2..f0e0cb66b9e 100644 --- a/tests/modules/transformer/t5_self_attention_test.py +++ b/tests/modules/transformer/t5_self_attention_test.py @@ -89,4 +89,8 @@ def test_forward_against_huggingface_output(self): attention_mask_hf = (attention_mask == 0).view((2, 1, 1, 3)).expand(2, 2, 3, 3) * -10e5 hf_output = hf_module.forward(hidden_states, mask=attention_mask_hf) - assert torch.allclose(output.hidden_states, hf_output[0]) + hs = output.hidden_states + print(hs) + print(hf_output[0]) + + assert torch.allclose(hs, hf_output[0]) From 3b73050c44b3ce7bb790c11ccb436ed8a6ae949f Mon Sep 17 00:00:00 2001 From: Akshita Bhagia Date: Wed, 5 May 2021 18:59:13 -0700 Subject: [PATCH 03/13] fixing bugs, adding tests, adding docs --- ...self_attention.py => general_attention.py} | 186 ++++++++++++++---- allennlp/modules/transformer/t5.py | 39 +++- .../transformer/self_attention_test.py | 6 +- .../transformer/t5_self_attention_test.py | 48 +++-- 4 files changed, 206 insertions(+), 73 deletions(-) rename allennlp/modules/transformer/{general_self_attention.py => general_attention.py} (76%) diff --git a/allennlp/modules/transformer/general_self_attention.py b/allennlp/modules/transformer/general_attention.py similarity index 76% rename from allennlp/modules/transformer/general_self_attention.py rename to allennlp/modules/transformer/general_attention.py index 111bc6d073a..471f8e400fe 100644 --- a/allennlp/modules/transformer/general_self_attention.py +++ b/allennlp/modules/transformer/general_attention.py @@ -16,15 +16,9 @@ @dataclass -class KeyValueState: - key_state: FloatT - value_state: FloatT - - -@dataclass -class GeneralSelfAttentionOutput: +class GeneralAttentionOutput: """ - Encapsulates the outputs of the `GeneralSelfAttention` module. + Encapsulates the outputs of the `GeneralAttention` module. 
""" hidden_states: FloatT @@ -33,9 +27,47 @@ class GeneralSelfAttentionOutput: attention_probs: Optional[FloatT] = None -class GeneralSelfAttention(TransformerModule, FromParams): +class GeneralAttention(TransformerModule, FromParams): """ - TODO + This module computes self-attention (or cross-attention), similar to the architecture in BERT. + Details in the paper: + [BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding, Devlin et al, 2019] + (https://api.semanticscholar.org/CorpusID:52967399) + + Additionally, it has the following functionality: + + * the attention scoring function can be specified. + * it can be used in encoders as well as decoders. + * `position_bias` can be used, which makes it suitable for + [T5-style attention](https://api.semanticscholar.org/CorpusID:204838007) as well. + + # Parameters + + hidden_size: `int` (default = `512`) + The size of the expected input tensor. + attention_head_size: `int` (default = `64`) + The size of a single attention head. + num_attention_heads: `int` (default = `8`) + The number of attention heads. + scoring_func: `str` (default = `scaled_dot_product`) + The name of the attention-calculating function to be used. + Eg. `additive`, `linear`, etc. For a complete list, please check :mod:`allennlp.modules.attention`. + output_linear: `bool` (default = `False`) + Whether to add an additional output linear layer at the end. + dropout: `float` (default = `0.0`) + The dropout probability. + bias: `bool` (default = `True`) + Whether to include bias weights in query, key, value (and output) linear layers. + normalize_weights: `bool` (default = `False`) + Whether to normalize the initial weights. + is_decoder: `bool` (default = `False`) + Whether this module is being used in a decoder stack or not. + is_cross_attention: `bool` (default = `False`) + Whether this module is being used for cross-attention in a decoder stack or not. + If `is_cross_attention` is `True`, then `is_decoder` must also be `True`. + has_relative_attention_bias: `bool` (default = `False`) + relative_attention_num_buckets: `int` (default = `32`) + This is ignored if `has_relative_attention_bias` is set to `False`. """ def __init__( @@ -75,6 +107,7 @@ def __init__( self.key = torch.nn.Linear(hidden_size, self.all_head_size, bias=bias) self.value = torch.nn.Linear(hidden_size, self.all_head_size, bias=bias) + # out linear layer for distilbert, T5 etc. if output_linear: self.output = torch.nn.Linear(hidden_size, self.all_head_size, bias=bias) @@ -132,25 +165,31 @@ def _query_layer(self, query_states: torch.Tensor): def _project( self, - query_states: torch.Tensor, + hidden_states: torch.Tensor, layer: torch.nn.Linear, source_states: Optional[torch.Tensor] = None, - past_key_or_value_states: Optional[torch.Tensor] = None, + past_key_or_value: Optional[torch.Tensor] = None, ): - if self.is_decoder: - if self.is_cross_attention: - if past_key_or_value_states is None: - assert source_states is not None, "Encoder final state needs to be passed." - query_states = source_states - else: - return past_key_or_value_states - - layer_output = layer(query_states) - layer_output = self._transpose_for_scores(layer_output) - if self.is_decoder: - layer_output = torch.cat([past_key_or_value_states, layer_output], dim=2) - - return layer_output + # TODO: clarify logic in terms of is_decoder and is_cross_attention + # to make it more readable. 
+ if source_states is None: + # self-attn + # (batch_size, num_heads, seq_length, dim_per_head) + hidden_states = self._transpose_for_scores(layer(hidden_states)) + elif past_key_or_value is None: + # cross-attn + # (batch_size, num_heads, seq_length, dim_per_head) + hidden_states = self._transpose_for_scores(layer(source_states)) + + if past_key_or_value is not None: + if source_states is None: + # self-attn + # (batch_size, num_heads, key_length, dim_per_head) + hidden_states = torch.cat([past_key_or_value, hidden_states], dim=2) + else: + # cross-attn + hidden_states = past_key_or_value + return hidden_states def _position_bias( self, @@ -190,8 +229,6 @@ def _get_attention_probs( ): attention_scores = self.attn(query_layer, key_layer.transpose(-1, -2)) - # return attention_scores - position_bias = self._position_bias( position_bias, seq_lengths, past_key_states, attention_scores ) @@ -227,19 +264,17 @@ def _output_layer(self, attention_probs: torch.Tensor, value_layer: torch.Tensor return context_layer - def _get_lengths(self, query_states, past_key_states, source_states): + def _get_lengths(self, query_states, past_key_states, source_states, query_length): seq_length = query_states.shape[1] effective_seq_len = seq_length - key_length = seq_length - if past_key_states is not None: # TODO: query_length from up the stack: move logic here. - # TODO: clarify the logic here. - effective_seq_len += past_key_states.shape[2] - if self.is_cross_attention: - key_length = source_states.shape[1] + # TODO: clarify the logic here in terms of encoder/decoder case. + effective_seq_len += past_key_states.shape[2] if query_length is None else query_length + + key_length = effective_seq_len if source_states is None else source_states.shape[1] return (seq_length, effective_seq_len, key_length) @@ -255,6 +290,7 @@ def forward( position_bias: Optional[torch.Tensor] = None, output_attentions: bool = False, use_cache: bool = False, + query_length: Optional[int] = None, ): """ query_states : `torch.Tensor` @@ -283,6 +319,7 @@ def forward( """ query_layer = self._query_layer(query_states) + key_layer = self._project( query_states, self.key, @@ -300,7 +337,7 @@ def forward( if self.is_cross_attention: attention_mask = source_attention_mask - seq_lengths = self._get_lengths(query_states, past_key_states, source_states) + seq_lengths = self._get_lengths(query_states, past_key_states, source_states, query_length) attention_probs, position_bias = self._get_attention_probs( query_layer, @@ -317,7 +354,7 @@ def forward( present_key_value_state = ( (key_layer, value_layer) if (self.is_decoder and use_cache) else None ) - outputs = GeneralSelfAttentionOutput( + outputs = GeneralAttentionOutput( context_layer, present_key_value_state, position_bias, attention_probs ) @@ -400,15 +437,14 @@ def compute_bias(self, query_length: int, key_length: int) -> FloatT: return values -class T5Attention(GeneralSelfAttention): +class T5Attention(GeneralAttention): - # _relevant_module = ["encoder.block.0.layer.0.self_attention"] + _relevant_module = ["encoder.block.0.layer.0.SelfAttention"] _huggingface_mapping = { "q": "query", "k": "key", "v": "value", "o": "output", - "layers": "layer", } def __init__( @@ -439,7 +475,7 @@ def __init__( relative_attention_num_buckets=relative_attention_num_buckets, ) - self.attn = Attention.by_name(self.scoring_func)(1, False) + self.attn = Attention.by_name(self.scoring_func)(scaling_factor=1, normalize=False) def forward( # type: ignore self, @@ -454,7 +490,7 @@ def forward( # type: ignore 
query_length: Optional[int] = None, # only relevant in cross-attention. use_cache: bool = False, output_attentions: bool = False, - ) -> GeneralSelfAttentionOutput: + ) -> GeneralAttentionOutput: """ Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). @@ -476,12 +512,76 @@ def forward( # type: ignore head_mask=layer_head_mask, position_bias=position_bias, output_attentions=output_attentions, + use_cache=use_cache, + query_length=query_length, ) return outputs + @classmethod + def _get_input_arguments( + cls, + pretrained_module: torch.nn.Module, + source="huggingface", + mapping: Optional[Dict[str, str]] = None, + **kwargs, + ): + submodules = cls._get_mapped_submodules(pretrained_module, source, mapping) + final_kwargs = {} + + final_kwargs["hidden_size"] = submodules["query"].in_features + + if hasattr(submodules[""], "num_attention_heads"): + final_kwargs["num_heads"] = submodules[""].num_attention_heads + elif hasattr(submodules[""], "n_heads"): + final_kwargs["num_heads"] = submodules[""].n_heads + else: + raise AttributeError("Cannot find a relevant attribute for number of heads.") + + final_kwargs["key_value_proj_dim"] = int( + submodules["query"].out_features / final_kwargs["num_heads"] + ) + + final_kwargs["dropout"] = pretrained_module.dropout + final_kwargs["has_relative_attention_bias"] = pretrained_module.has_relative_attention_bias + final_kwargs[ + "relative_attention_num_buckets" + ] = pretrained_module.relative_attention_num_buckets + final_kwargs["is_decoder"] = pretrained_module.is_decoder + + final_kwargs.update(**kwargs) + + return final_kwargs + + @classmethod + def _get_mapped_submodules( + cls, + pretrained_module: torch.nn.Module, + source: str = "huggingface", + mapping: Optional[Dict[str, str]] = None, + ): + """ + Subclasses overload this method, and provide appropriate name mapping based on the source. + """ + submodules = dict(pretrained_module.named_modules()) + # combined_mapping = cls._get_mapping(pretrained_module, source, mapping) + for name, module in pretrained_module.named_modules(): + newname = name + if name in ["q", "q.weight"]: + newname = newname.replace("q", "query") + elif name in ["k", "k.weight"]: + newname = newname.replace("k", "key") + elif name in ["v", "v.weight"]: + newname = newname.replace("v", "value") + elif name in ["o", "o.weight"]: + newname = newname.replace("o", "output") + else: + pass + submodules[newname] = submodules.pop(name) + return submodules + -class SelfAttention(GeneralSelfAttention): +class SelfAttention(GeneralAttention): """ This module computes the self-attention, similar to the architecture in BERT. Additionally, the attention scoring function can be specified. 
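A rough usage sketch of the generalized T5-style attention introduced above (illustrative only, not part of the diff). It assumes the intermediate import path `allennlp.modules.transformer.general_attention` used at this point in the series (a later commit renames the file to `attention_module.py` and the base class to `AttentionModule`), and the tensor sizes are invented for the example:

    import torch
    from allennlp.modules.transformer.general_attention import T5Attention

    # Toy sizes for illustration; hidden_size must be divisible by the number of heads.
    attention = T5Attention(
        is_decoder=False,
        hidden_size=6,
        key_value_proj_dim=3,
        num_heads=2,
        has_relative_attention_bias=True,
    )

    hidden_states = torch.randn(2, 3, 6)                   # (batch, seq_len, hidden_size)
    attention_mask = torch.tensor([[1, 1, 0], [1, 1, 1]])  # 1 = attend, 0 = masked

    output = attention(hidden_states, mask=attention_mask)
    # `output` is a GeneralAttentionOutput dataclass rather than a tuple:
    #   output.hidden_states    shape (2, 3, 6)
    #   output.position_bias    relative position bias, reusable by subsequent layers
    #   output.key_value_state  populated only when is_decoder and use_cache are both True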
diff --git a/allennlp/modules/transformer/t5.py b/allennlp/modules/transformer/t5.py index 8ad5e28470c..795275ddfb5 100644 --- a/allennlp/modules/transformer/t5.py +++ b/allennlp/modules/transformer/t5.py @@ -16,7 +16,10 @@ from allennlp.common.checks import ConfigurationError from allennlp.modules.transformer import TransformerModule -# from allennlp.modules.transformer.general_self_attention import T5Attention +from allennlp.modules.transformer.general_attention import ( + T5Attention, + GeneralAttentionOutput, +) from allennlp.modules.transformer.util import ( apply_mask, get_extended_attention_mask, @@ -124,7 +127,7 @@ class T5AttentionOutput: attn_weights: Optional[FloatT] = None -class T5Attention(TransformerModule, FromParams): +class T5AttentionOld(TransformerModule, FromParams): def __init__( self, is_decoder: bool = False, @@ -405,7 +408,7 @@ def forward( output_attentions: bool = False, ) -> T5LayerSelfAttentionOutput: normed_hidden_states = self.layer_norm(hidden_states) - attention_output: T5AttentionOutput = self.self_attention( + attention_output: GeneralAttentionOutput = self.self_attention( normed_hidden_states, mask=attention_mask, position_bias=position_bias, @@ -415,11 +418,17 @@ def forward( output_attentions=output_attentions, ) hidden_states = hidden_states + self.dropout(attention_output.hidden_states) + # return T5LayerSelfAttentionOutput( + # hidden_states, + # attention_output.key_value_state, + # attention_output.position_bias, + # attention_output.attn_weights, + # ) return T5LayerSelfAttentionOutput( hidden_states, attention_output.key_value_state, attention_output.position_bias, - attention_output.attn_weights, + attention_output.attention_probs, ) @@ -462,7 +471,7 @@ def forward( output_attentions: bool = False, ) -> T5LayerCrossAttentionOutput: normed_hidden_states = self.layer_norm(hidden_states) - attention_output: T5AttentionOutput = self.enc_dec_attention( + attention_output: GeneralAttentionOutput = self.enc_dec_attention( normed_hidden_states, mask=attention_mask, key_value_states=key_value_states, @@ -474,11 +483,18 @@ def forward( output_attentions=output_attentions, ) layer_output = hidden_states + self.dropout(attention_output.hidden_states) + # return T5LayerCrossAttentionOutput( + # layer_output, + # attention_output.key_value_state, + # attention_output.position_bias, + # attention_output.attn_weights, + # ) + return T5LayerCrossAttentionOutput( layer_output, attention_output.key_value_state, attention_output.position_bias, - attention_output.attn_weights, + attention_output.attention_probs, ) @@ -967,6 +983,17 @@ class T5Output: class T5(TransformerModule, Registrable): _huggingface_mapping = {"shared": "token_embeddings"} + _huggingface_mapping.update( + { + "q": "query", + "k": "key", + "v": "value", + "o": "output", + "block": "blocks", + "SelfAttention": "self_attention", + "EncDecAttention": "enc_dec_attention", + } + ) default_implementation = "default" diff --git a/tests/modules/transformer/self_attention_test.py b/tests/modules/transformer/self_attention_test.py index a34d287acd0..d816470bb75 100644 --- a/tests/modules/transformer/self_attention_test.py +++ b/tests/modules/transformer/self_attention_test.py @@ -7,7 +7,7 @@ from allennlp.common.testing import assert_equal_parameters, AllenNlpTestCase # from allennlp.modules.transformer import SelfAttention -from allennlp.modules.transformer.general_self_attention import SelfAttention +from allennlp.modules.transformer.general_attention import SelfAttention from allennlp.nn.util import 
min_value_of_dtype from transformers.models.bert.configuration_bert import BertConfig @@ -82,7 +82,7 @@ def test_can_construct_from_params(self): # assert self_attention.dropout.p == params_dict["dropout"] assert self_attention.dropout == params_dict["dropout"] - # @pytest.mark.skip("Takes up too much memory") + @pytest.mark.skip("Takes up too much memory") @pytest.mark.parametrize("module_name, hf_module", get_modules(PARAMS_DICT).items()) def test_forward_against_huggingface_output(self, module_name, hf_module): hidden_states = torch.randn(2, 3, 6) @@ -103,7 +103,7 @@ def test_forward_against_huggingface_output(self, module_name, hf_module): assert torch.allclose(output[0], hf_output[0]) - # @pytest.mark.skip("Takes up too much memory") + @pytest.mark.skip("Takes up too much memory") @pytest.mark.parametrize( "pretrained_name", [ diff --git a/tests/modules/transformer/t5_self_attention_test.py b/tests/modules/transformer/t5_self_attention_test.py index f0e0cb66b9e..3d6e1cf7804 100644 --- a/tests/modules/transformer/t5_self_attention_test.py +++ b/tests/modules/transformer/t5_self_attention_test.py @@ -1,11 +1,13 @@ import copy import torch +import pytest from allennlp.common import Params -from allennlp.common.testing import AllenNlpTestCase +from allennlp.common import cached_transformers +from allennlp.common.testing import assert_equal_parameters, AllenNlpTestCase # from allennlp.modules.transformer.t5 import T5Attention -from allennlp.modules.transformer.general_self_attention import T5Attention +from allennlp.modules.transformer.general_attention import T5Attention from transformers.models.t5.configuration_t5 import T5Config from transformers.models.t5.modeling_t5 import T5Attention as HFT5Attention @@ -28,23 +30,6 @@ def test_can_construct_from_params(self): t5_attention = T5Attention.from_params(params) - # the old one - # assert t5_attention.num_heads == params_dict["num_heads"] - # assert t5_attention.key_value_proj_dim == params_dict["key_value_proj_dim"] - - # assert ( - # t5_attention.inner_dim - # == params_dict["num_heads"] * params_dict["key_value_proj_dim"] - # ) - - # assert t5_attention.q.in_features == params_dict["hidden_size"] - # assert t5_attention.k.in_features == params_dict["hidden_size"] - # assert t5_attention.v.in_features == params_dict["hidden_size"] - # assert t5_attention.o.in_features == params_dict["hidden_size"] - - # assert t5_attention.dropout == params_dict["dropout"] - - # the new one assert t5_attention.num_attention_heads == params_dict["num_heads"] assert t5_attention.attention_head_size == params_dict["key_value_proj_dim"] @@ -90,7 +75,28 @@ def test_forward_against_huggingface_output(self): hf_output = hf_module.forward(hidden_states, mask=attention_mask_hf) hs = output.hidden_states - print(hs) - print(hf_output[0]) assert torch.allclose(hs, hf_output[0]) + + @pytest.mark.parametrize( + "pretrained_name", + [ + "t5-small", + ], + ) + def test_loading_from_pretrained_weights_using_model_name(self, pretrained_name): + + torch.manual_seed(1234) + pretrained = cached_transformers.get(pretrained_name, False) + + pretrained_module = pretrained.encoder.block[0].layer[0].SelfAttention + + module = T5Attention.from_pretrained_module(pretrained_name) + + mapping = { + val: key + for key, val in module._construct_default_mapping( + pretrained_module, "huggingface", {} + ).items() + } + assert_equal_parameters(pretrained_module, module, mapping=mapping) From 5ce9865648bb23dac94b3557b3dc11530183d3ed Mon Sep 17 00:00:00 2001 From: Akshita Bhagia Date: Mon, 
24 May 2021 22:22:31 -0700 Subject: [PATCH 04/13] updating other modules --- allennlp/modules/transformer/__init__.py | 2 +- ...neral_attention.py => attention_module.py} | 20 +- allennlp/modules/transformer/t5.py | 266 +----------------- .../modules/transformer/transformer_layer.py | 20 +- .../transformer/self_attention_test.py | 3 +- .../transformer/t5_self_attention_test.py | 3 +- .../transformer/transformer_layer_test.py | 6 +- .../transformer/transformer_stack_test.py | 2 +- 8 files changed, 39 insertions(+), 283 deletions(-) rename allennlp/modules/transformer/{general_attention.py => attention_module.py} (97%) diff --git a/allennlp/modules/transformer/__init__.py b/allennlp/modules/transformer/__init__.py index 9b944130c7c..40e99a67918 100644 --- a/allennlp/modules/transformer/__init__.py +++ b/allennlp/modules/transformer/__init__.py @@ -131,7 +131,7 @@ def forward(self, token_ids: torch.LongTensor, mask: torch.BoolTensor): TransformerEmbeddings, ImageFeatureEmbeddings, ) -from allennlp.modules.transformer.self_attention import SelfAttention +from allennlp.modules.transformer.attention_module import SelfAttention, T5Attention from allennlp.modules.transformer.activation_layer import ActivationLayer from allennlp.modules.transformer.transformer_layer import AttentionLayer, TransformerLayer from allennlp.modules.transformer.transformer_stack import TransformerStack diff --git a/allennlp/modules/transformer/general_attention.py b/allennlp/modules/transformer/attention_module.py similarity index 97% rename from allennlp/modules/transformer/general_attention.py rename to allennlp/modules/transformer/attention_module.py index b18d16f62e2..5052483aa97 100644 --- a/allennlp/modules/transformer/general_attention.py +++ b/allennlp/modules/transformer/attention_module.py @@ -19,9 +19,9 @@ @dataclass -class GeneralAttentionOutput: +class AttentionOutput: """ - Encapsulates the outputs of the `GeneralAttention` module. + Encapsulates the outputs of the `Attention` module. """ hidden_states: FloatT @@ -30,7 +30,7 @@ class GeneralAttentionOutput: attention_probs: Optional[FloatT] = None -class GeneralAttention(TransformerModule, FromParams): +class AttentionModule(TransformerModule, FromParams): """ This module computes self-attention (or cross-attention), similar to the architecture in BERT. Details in the paper: @@ -188,6 +188,8 @@ def _project( if source_states is None: # self-attn # (batch_size, num_heads, key_length, dim_per_head) + # if len(past_key_or_value.shape) == 3: + # past_key_or_value = self._transpose_for_scores(past_key_or_value) hidden_states = torch.cat([past_key_or_value, hidden_states], dim=2) else: # cross-attn @@ -357,7 +359,7 @@ def forward( present_key_value_state = ( (key_layer, value_layer) if (self.is_decoder and use_cache) else None ) - outputs = GeneralAttentionOutput( + outputs = AttentionOutput( context_layer, present_key_value_state, position_bias, attention_probs ) @@ -440,7 +442,7 @@ def compute_bias(self, query_length: int, key_length: int) -> FloatT: return values -class T5Attention(GeneralAttention): +class T5Attention(AttentionModule): _pretrained_relevant_module = ["encoder.block.0.layer.0.SelfAttention"] _pretrained_mapping = { @@ -493,7 +495,7 @@ def forward( # type: ignore query_length: Optional[int] = None, # only relevant in cross-attention. 
use_cache: bool = False, output_attentions: bool = False, - ) -> GeneralAttentionOutput: + ) -> AttentionOutput: """ Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). @@ -542,7 +544,7 @@ def _from_config(cls, config: "PretrainedConfig", **kwargs): return cls(**final_kwargs) -class SelfAttention(GeneralAttention): +class SelfAttention(AttentionModule): """ This module computes the self-attention, similar to the architecture in BERT. Additionally, the attention scoring function can be specified. @@ -577,6 +579,8 @@ def __init__( dropout: float = 0.0, scoring_func: str = "scaled_dot_product", output_linear: bool = False, + is_decoder: bool = False, + is_cross_attention: bool = False, ): attention_head_size = int(hidden_size / num_attention_heads) @@ -589,6 +593,8 @@ def __init__( output_linear=output_linear, dropout=dropout, bias=True, + is_decoder=is_decoder, + is_cross_attention=is_cross_attention, ) def forward(self, *args, **kwargs): diff --git a/allennlp/modules/transformer/t5.py b/allennlp/modules/transformer/t5.py index 9edf4f73011..530f901a743 100644 --- a/allennlp/modules/transformer/t5.py +++ b/allennlp/modules/transformer/t5.py @@ -3,7 +3,6 @@ (https://github.com/huggingface/transformers/blob/4c32f9f26e6a84f0d9843fec8757e6ce640bb44e/src/transformers/models/t5/modeling_t5.py). """ # noqa: E401 -import math from dataclasses import dataclass from typing import Optional, Tuple, List, Union, Dict, TYPE_CHECKING @@ -16,12 +15,11 @@ from allennlp.common.checks import ConfigurationError from allennlp.modules.transformer import TransformerModule -from allennlp.modules.transformer.general_attention import ( +from allennlp.modules.transformer.attention_module import ( T5Attention, - GeneralAttentionOutput, + AttentionOutput, ) from allennlp.modules.transformer.util import ( - apply_mask, get_extended_attention_mask, ) from allennlp.nn.beam_search import BeamSearch @@ -122,262 +120,6 @@ def forward(self, hidden_states) -> FloatT: return hidden_states -@dataclass -class T5AttentionOutput: - hidden_states: FloatT - key_value_state: Optional[Tuple[FloatT, FloatT]] - position_bias: FloatT - attn_weights: Optional[FloatT] = None - - -class T5AttentionOld(TransformerModule, FromParams): - def __init__( - self, - is_decoder: bool = False, - hidden_size: int = 512, - key_value_proj_dim: int = 64, - num_heads: int = 8, - has_relative_attention_bias: bool = False, - relative_attention_num_buckets: int = 32, - dropout: float = 0.1, - normalize: bool = True, - is_cross_attention: bool = False, - ): - super().__init__() - self.is_decoder = is_decoder - self.has_relative_attention_bias = has_relative_attention_bias - - self.relative_attention_num_buckets = relative_attention_num_buckets - self.hidden_size = hidden_size - self.key_value_proj_dim = key_value_proj_dim - self.num_heads = num_heads - self.dropout = dropout - self.inner_dim = self.num_heads * self.key_value_proj_dim - - self.q = nn.Linear(self.hidden_size, self.inner_dim, bias=False) - self.k = nn.Linear(self.hidden_size, self.inner_dim, bias=False) - self.v = nn.Linear(self.hidden_size, self.inner_dim, bias=False) - self.o = nn.Linear(self.inner_dim, self.hidden_size, bias=False) - if self.has_relative_attention_bias: - self.relative_attention_bias = nn.Embedding( - self.relative_attention_num_buckets, self.num_heads - ) - - if normalize: - self.q.weight.data.normal_(mean=0.0, std=(hidden_size * key_value_proj_dim) ** -0.5) - self.k.weight.data.normal_(mean=0.0, std=hidden_size ** 
-0.5) - self.v.weight.data.normal_(mean=0.0, std=hidden_size ** -0.5) - self.o.weight.data.normal_(mean=0.0, std=(num_heads * key_value_proj_dim) ** -0.5) - if self.has_relative_attention_bias: - self.relative_attention_bias.weight.data.normal_(mean=0.0, std=hidden_size ** -0.5) - - @staticmethod - def _relative_position_bucket( - relative_position: IntT, - bidirectional: bool = True, - num_buckets: int = 32, - max_distance: int = 128, - ) -> IntT: - """ - Adapted from Mesh Tensorflow: - https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 - - Translate relative position to a bucket number for relative attention. The relative position is defined as - memory_position - query_position, i.e. the distance in tokens from the attending position to the - attended-to position. If bidirectional=False, then positive relative positions are invalid. We use smaller - buckets for small absolute relative_position and larger buckets for larger absolute relative_positions. All - relative positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the - same bucket. This should allow for more graceful generalization to longer sequences than the model has been - trained on. - - Args: - relative_position: an int32 Tensor - bidirectional: a boolean - whether the attention is bidirectional - num_buckets: an integer - max_distance: an integer - - Returns: - a Tensor with the same shape as relative_position, containing int32 values in the range - [0, num_buckets) - """ - relative_buckets = relative_position.new_zeros(relative_position.shape) - if bidirectional: - num_buckets //= 2 - relative_buckets += (relative_position > 0).to(torch.long) * num_buckets - relative_position = torch.abs(relative_position) - else: - relative_position = -torch.min(relative_position, torch.zeros_like(relative_position)) - # now relative_position is in the range [0, inf) - - # half of the buckets are for exact increments in positions - max_exact = num_buckets // 2 - is_small = relative_position < max_exact - - # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance - relative_postion_if_large = max_exact + ( - torch.log(relative_position.float() / max_exact) - / math.log(max_distance / max_exact) - * (num_buckets - max_exact) - ).to(torch.long) - relative_postion_if_large = torch.min( - relative_postion_if_large, torch.full_like(relative_postion_if_large, num_buckets - 1) - ) - - relative_buckets += torch.where(is_small, relative_position, relative_postion_if_large) - return relative_buckets - - def compute_bias(self, query_length: int, key_length: int) -> FloatT: - """ Compute binned relative position bias """ - context_position = torch.arange(query_length, dtype=torch.long)[:, None] - memory_position = torch.arange(key_length, dtype=torch.long)[None, :] - relative_position = memory_position - context_position # shape (query_length, key_length) - relative_position_bucket = self._relative_position_bucket( - relative_position, # shape (query_length, key_length) - bidirectional=(not self.is_decoder), - num_buckets=self.relative_attention_num_buckets, - ) - relative_position_bucket = relative_position_bucket.to( - self.relative_attention_bias.weight.device - ) - values = self.relative_attention_bias( - relative_position_bucket - ) # shape (query_length, key_length, num_heads) - values = values.permute([2, 0, 1]).unsqueeze( - 0 - ) # shape (1, num_heads, query_length, 
key_length) - return values - - def forward( - self, - hidden_states: torch.Tensor, - mask: Optional[torch.BoolTensor] = None, - key_value_states: Optional[FloatT] = None, - position_bias: Optional[FloatT] = None, - past_key_value: Optional[Tuple[FloatT, FloatT]] = None, - layer_head_mask: Optional[BoolT] = None, - query_length: Optional[int] = None, - use_cache: bool = False, - output_attentions: bool = False, - ) -> T5AttentionOutput: - """ - Self-attention (if key_value_states is None) or attention over source sentence (provided by - key_value_states). - """ - # Input is (batch_size, seq_length, dim) - # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length) - # past_key_value[0] is (batch_size, num_heads, q_len - 1, dim_per_head) - batch_size, seq_length = hidden_states.shape[:2] - - real_seq_length = seq_length - - if past_key_value is not None: - assert ( - len(past_key_value) == 2 - ), "past_key_value should have 2 past states: keys and values. Got {} past states".format( - len(past_key_value) - ) - real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length - - key_length = real_seq_length if key_value_states is None else key_value_states.shape[1] - - def shape(states): - return states.view(batch_size, -1, self.num_heads, self.key_value_proj_dim).transpose( - 1, 2 - ) - - def unshape(states): - return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim) - - def project(hidden_states, proj_layer, key_value_states, past_key_value) -> FloatT: - """ projects hidden states correctly to key/query states """ - if key_value_states is None: - # self-attn - # (batch_size, num_heads, seq_length, dim_per_head) - hidden_states = shape(proj_layer(hidden_states)) - elif past_key_value is None: - # cross-attn - # (batch_size, num_heads, seq_length, dim_per_head) - hidden_states = shape(proj_layer(key_value_states)) - - if past_key_value is not None: - if key_value_states is None: - # self-attn - # (batch_size, num_heads, key_length, dim_per_head) - hidden_states = torch.cat([past_key_value, hidden_states], dim=2) - else: - # cross-attn - hidden_states = past_key_value - return hidden_states - - # get query states - query_states = shape( - self.q(hidden_states) - ) # (batch_size, num_heads, seq_length, dim_per_head) - - # get key/value states - key_states = project( - hidden_states, - self.k, - key_value_states, - past_key_value[0] if past_key_value is not None else None, - ) - value_states = project( - hidden_states, - self.v, - key_value_states, - past_key_value[1] if past_key_value is not None else None, - ) - - # compute scores - scores = torch.matmul( - query_states, key_states.transpose(3, 2) - ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 - - if position_bias is None: - if self.has_relative_attention_bias: - position_bias = self.compute_bias(real_seq_length, key_length) - else: - position_bias = torch.zeros( - (1, self.num_heads, real_seq_length, key_length), - device=scores.device, - dtype=scores.dtype, - ) - - # if key and values are already calculated - # we want only the last query position bias - if past_key_value is not None: - position_bias = position_bias[:, :, -seq_length:, :] - - if mask is not None: - # Shape: (batch_size, num_heads, seq_length, key_length) - position_bias = apply_mask(position_bias, mask) - - scores += position_bias - attn_weights = F.softmax(scores.float(), dim=-1).type_as( - scores - ) # (batch_size, num_heads, seq_length, 
key_length) - attn_weights = F.dropout( - attn_weights, p=self.dropout, training=self.training - ) # (batch_size, num_heads, seq_length, key_length) - - # Mask heads if we want to - if layer_head_mask is not None: - attn_weights = attn_weights * layer_head_mask - - attn_output = unshape( - torch.matmul(attn_weights, value_states) - ) # (batch_size, seq_length, dim) - attn_output = self.o(attn_output) - - present_key_value_state = ( - (key_states, value_states) if (self.is_decoder and use_cache) else None - ) - outputs = T5AttentionOutput(attn_output, present_key_value_state, position_bias) - if output_attentions: - outputs.attn_weights = attn_weights - return outputs - - @dataclass class T5LayerSelfAttentionOutput: hidden_states: FloatT @@ -414,7 +156,7 @@ def forward( output_attentions: bool = False, ) -> T5LayerSelfAttentionOutput: normed_hidden_states = self.layer_norm(hidden_states) - attention_output: GeneralAttentionOutput = self.self_attention( + attention_output: AttentionOutput = self.self_attention( normed_hidden_states, mask=attention_mask, position_bias=position_bias, @@ -477,7 +219,7 @@ def forward( output_attentions: bool = False, ) -> T5LayerCrossAttentionOutput: normed_hidden_states = self.layer_norm(hidden_states) - attention_output: GeneralAttentionOutput = self.enc_dec_attention( + attention_output: AttentionOutput = self.enc_dec_attention( normed_hidden_states, mask=attention_mask, key_value_states=key_value_states, diff --git a/allennlp/modules/transformer/transformer_layer.py b/allennlp/modules/transformer/transformer_layer.py index 43a76d33144..914658777a1 100644 --- a/allennlp/modules/transformer/transformer_layer.py +++ b/allennlp/modules/transformer/transformer_layer.py @@ -5,7 +5,7 @@ from allennlp.common import FromParams from allennlp.modules.transformer.transformer_module import TransformerModule from allennlp.modules.transformer.activation_layer import ActivationLayer -from allennlp.modules.transformer.self_attention import SelfAttention +from allennlp.modules.transformer.attention_module import SelfAttention from allennlp.modules.transformer.output_layer import OutputLayer if TYPE_CHECKING: @@ -38,9 +38,17 @@ def __init__( num_attention_heads: int, attention_dropout: float = 0.0, hidden_dropout: float = 0.0, + is_cross_attention: bool = False, + is_decoder: bool = False, ): super().__init__() - self.self = SelfAttention(hidden_size, num_attention_heads, attention_dropout) + self.self = SelfAttention( + hidden_size, + num_attention_heads, + attention_dropout, + is_cross_attention=is_cross_attention, + is_decoder=is_decoder, + ) self.output = OutputLayer(hidden_size, hidden_size, hidden_dropout) def forward( @@ -71,9 +79,9 @@ def forward( input_tensor, encoder_hidden_states, encoder_hidden_states, - attention_mask, - head_mask, - output_attentions, + attention_mask=attention_mask, + head_mask=head_mask, + output_attentions=output_attentions, ) attention_output = self.output(self_output[0], input_tensor) outputs = (attention_output,) + self_output[1:] # add attentions if we output them @@ -149,6 +157,8 @@ def __init__( num_attention_heads=num_attention_heads, attention_dropout=attention_dropout, hidden_dropout=hidden_dropout, + is_cross_attention=True, + is_decoder=True, ) self.intermediate = ActivationLayer( diff --git a/tests/modules/transformer/self_attention_test.py b/tests/modules/transformer/self_attention_test.py index 12593bbcc4a..597f56b6aa2 100644 --- a/tests/modules/transformer/self_attention_test.py +++ 
b/tests/modules/transformer/self_attention_test.py @@ -6,8 +6,7 @@ from allennlp.common import Params -# from allennlp.modules.transformer import SelfAttention -from allennlp.modules.transformer.general_attention import SelfAttention +from allennlp.modules.transformer.attention_module import SelfAttention from allennlp.nn.util import min_value_of_dtype diff --git a/tests/modules/transformer/t5_self_attention_test.py b/tests/modules/transformer/t5_self_attention_test.py index a6e3bfc323b..96eb58bd806 100644 --- a/tests/modules/transformer/t5_self_attention_test.py +++ b/tests/modules/transformer/t5_self_attention_test.py @@ -6,8 +6,7 @@ from allennlp.common import Params -# from allennlp.modules.transformer.t5 import T5Attention -from allennlp.modules.transformer.general_attention import T5Attention +from allennlp.modules.transformer.attention_module import T5Attention from transformers.models.t5.configuration_t5 import T5Config from transformers.models.t5.modeling_t5 import T5Attention as HFT5Attention diff --git a/tests/modules/transformer/transformer_layer_test.py b/tests/modules/transformer/transformer_layer_test.py index 4c1e141a5a8..42b9d0c763a 100644 --- a/tests/modules/transformer/transformer_layer_test.py +++ b/tests/modules/transformer/transformer_layer_test.py @@ -44,7 +44,7 @@ def test_attention(attention_params): assert attention_layer.self.query.in_features == attention_params["hidden_size"] assert attention_layer.self.key.in_features == attention_params["hidden_size"] assert attention_layer.self.value.in_features == attention_params["hidden_size"] - assert attention_layer.self.dropout.p == attention_params["attention_dropout"] + assert attention_layer.self.dropout == attention_params["attention_dropout"] assert attention_layer.output.dense.in_features == attention_params["hidden_size"] assert attention_layer.output.dense.out_features == attention_params["hidden_size"] @@ -166,7 +166,7 @@ def test_layer(layer_params): assert transformer_layer.attention.self.query.in_features == layer_params["hidden_size"] assert transformer_layer.attention.self.key.in_features == layer_params["hidden_size"] assert transformer_layer.attention.self.value.in_features == layer_params["hidden_size"] - assert transformer_layer.attention.self.dropout.p == layer_params["attention_dropout"] + assert transformer_layer.attention.self.dropout == layer_params["attention_dropout"] assert transformer_layer.attention.output.dense.in_features == layer_params["hidden_size"] assert transformer_layer.attention.output.dense.out_features == layer_params["hidden_size"] @@ -207,7 +207,7 @@ def test_layer_with_cross_attention(layer_params): transformer_layer( torch.randn(2, 3, 6), attention_mask=attention_mask, - encoder_hidden_states=torch.randn(2, 3, 6), + encoder_hidden_states=torch.randn(2, 2, 3, 3), ) diff --git a/tests/modules/transformer/transformer_stack_test.py b/tests/modules/transformer/transformer_stack_test.py index cf42f6c0f6d..1ca0522cd59 100644 --- a/tests/modules/transformer/transformer_stack_test.py +++ b/tests/modules/transformer/transformer_stack_test.py @@ -78,7 +78,7 @@ def test_transformer_stack_with_cross_attention(params): transformer_stack.forward( torch.randn(2, 3, 6), attention_mask=attention_mask, - encoder_hidden_states=torch.randn(2, 3, 6), + encoder_hidden_states=torch.randn(2, 2, 3, 3), ) From 6e7243f97b1267e5bdf41132144f35d8f0c32159 Mon Sep 17 00:00:00 2001 From: Akshita Bhagia Date: Tue, 25 May 2021 00:57:20 -0700 Subject: [PATCH 05/13] refactor --- 
.../modules/transformer/attention_module.py | 13 +--------- allennlp/modules/transformer/t5.py | 24 ++++++------------- allennlp/modules/transformer/util.py | 6 ++++- .../transformer/transformer_layer_test.py | 6 ++--- .../transformer/transformer_stack_test.py | 2 +- 5 files changed, 17 insertions(+), 34 deletions(-) diff --git a/allennlp/modules/transformer/attention_module.py b/allennlp/modules/transformer/attention_module.py index 5052483aa97..be9230868ed 100644 --- a/allennlp/modules/transformer/attention_module.py +++ b/allennlp/modules/transformer/attention_module.py @@ -7,16 +7,11 @@ from allennlp.common import FromParams from allennlp.modules.attention import Attention from allennlp.modules.transformer.transformer_module import TransformerModule -from allennlp.modules.transformer.util import apply_mask +from allennlp.modules.transformer.util import apply_mask, FloatT, IntT, BoolT if TYPE_CHECKING: from transformers.configuration_utils import PretrainedConfig -# Unfortunately mypy is insane, so we have to wrap these in unions. -FloatT = Union[torch.FloatTensor] -IntT = Union[torch.IntTensor] -BoolT = Union[torch.BoolTensor] - @dataclass class AttentionOutput: @@ -597,12 +592,6 @@ def __init__( is_cross_attention=is_cross_attention, ) - def forward(self, *args, **kwargs): - outputs = super().forward(*args, **kwargs) - if outputs.attention_probs is not None: - return (outputs.hidden_states, outputs.attention_probs) - return (outputs.hidden_states,) - @classmethod def _from_config(cls, config: "PretrainedConfig", **kwargs): final_kwargs = {} diff --git a/allennlp/modules/transformer/t5.py b/allennlp/modules/transformer/t5.py index 530f901a743..bc05d192939 100644 --- a/allennlp/modules/transformer/t5.py +++ b/allennlp/modules/transformer/t5.py @@ -21,17 +21,15 @@ ) from allennlp.modules.transformer.util import ( get_extended_attention_mask, + FloatT, + IntT, + BoolT, ) from allennlp.nn.beam_search import BeamSearch if TYPE_CHECKING: from transformers.configuration_utils import PretrainedConfig -# Unfortunately mypy is insane, so I have to wrap these in unions. 
-FloatT = Union[torch.FloatTensor] -IntT = Union[torch.IntTensor] -BoolT = Union[torch.BoolTensor] - class T5LayerNorm(TransformerModule, FromParams): """T5-style layer norm does not have bias and does not subtract the mean.""" @@ -155,7 +153,9 @@ def forward( use_cache: bool = False, output_attentions: bool = False, ) -> T5LayerSelfAttentionOutput: + normed_hidden_states = self.layer_norm(hidden_states) + attention_output: AttentionOutput = self.self_attention( normed_hidden_states, mask=attention_mask, @@ -165,13 +165,9 @@ def forward( use_cache=use_cache, output_attentions=output_attentions, ) + hidden_states = hidden_states + self.dropout(attention_output.hidden_states) - # return T5LayerSelfAttentionOutput( - # hidden_states, - # attention_output.key_value_state, - # attention_output.position_bias, - # attention_output.attn_weights, - # ) + return T5LayerSelfAttentionOutput( hidden_states, attention_output.key_value_state, @@ -231,12 +227,6 @@ def forward( output_attentions=output_attentions, ) layer_output = hidden_states + self.dropout(attention_output.hidden_states) - # return T5LayerCrossAttentionOutput( - # layer_output, - # attention_output.key_value_state, - # attention_output.position_bias, - # attention_output.attn_weights, - # ) return T5LayerCrossAttentionOutput( layer_output, diff --git a/allennlp/modules/transformer/util.py b/allennlp/modules/transformer/util.py index e9797bff68c..7d6a43cf198 100644 --- a/allennlp/modules/transformer/util.py +++ b/allennlp/modules/transformer/util.py @@ -1,8 +1,12 @@ from typing import Union, Tuple import torch - from allennlp.nn.util import min_value_of_dtype +# Unfortunately mypy is insane, so we have to wrap these in unions. +FloatT = Union[torch.FloatTensor] +IntT = Union[torch.IntTensor] +BoolT = Union[torch.BoolTensor] + def apply_mask( values: torch.FloatTensor, mask: Union[torch.BoolTensor, torch.IntTensor, torch.FloatTensor] diff --git a/tests/modules/transformer/transformer_layer_test.py b/tests/modules/transformer/transformer_layer_test.py index 42b9d0c763a..641032ee78d 100644 --- a/tests/modules/transformer/transformer_layer_test.py +++ b/tests/modules/transformer/transformer_layer_test.py @@ -87,7 +87,7 @@ def test_attention_matches_huggingface(attention_params, module_name, hf_module) torch.manual_seed(1234) hf_output = hf_module(hidden_states, attention_mask=attention_mask_hf) - assert torch.allclose(output[0], hf_output[0]) + assert torch.allclose(output.hidden_states, hf_output[0]) @pytest.mark.parametrize( @@ -126,7 +126,7 @@ def test_attention_from_pretrained(pretrained_name, relevant_top_level_module): attention_mask_hf = (1.0 - attention_mask_hf) * -10e5 torch.manual_seed(1234) - output = module(hidden_states, attention_mask=attention_mask.squeeze())[0] + output = module(hidden_states, attention_mask=attention_mask.squeeze()).hidden_states torch.manual_seed(1234) hf_output = pretrained_module(hidden_states, attention_mask=attention_mask_hf)[0] @@ -207,7 +207,7 @@ def test_layer_with_cross_attention(layer_params): transformer_layer( torch.randn(2, 3, 6), attention_mask=attention_mask, - encoder_hidden_states=torch.randn(2, 2, 3, 3), + encoder_hidden_states=torch.randn(2, 3, 6), ) diff --git a/tests/modules/transformer/transformer_stack_test.py b/tests/modules/transformer/transformer_stack_test.py index 1ca0522cd59..cf42f6c0f6d 100644 --- a/tests/modules/transformer/transformer_stack_test.py +++ b/tests/modules/transformer/transformer_stack_test.py @@ -78,7 +78,7 @@ def 
test_transformer_stack_with_cross_attention(params): transformer_stack.forward( torch.randn(2, 3, 6), attention_mask=attention_mask, - encoder_hidden_states=torch.randn(2, 2, 3, 3), + encoder_hidden_states=torch.randn(2, 3, 6), ) From d252af0001a06124b4000cb582fe9e6a061bfa61 Mon Sep 17 00:00:00 2001 From: Akshita Bhagia Date: Tue, 25 May 2021 12:01:26 -0700 Subject: [PATCH 06/13] bug fix --- .../modules/transformer/attention_module.py | 2 +- .../modules/transformer/transformer_layer.py | 30 ++++++++++++------- .../transformer/self_attention_test.py | 2 +- 3 files changed, 21 insertions(+), 13 deletions(-) diff --git a/allennlp/modules/transformer/attention_module.py b/allennlp/modules/transformer/attention_module.py index be9230868ed..a27b3296485 100644 --- a/allennlp/modules/transformer/attention_module.py +++ b/allennlp/modules/transformer/attention_module.py @@ -1,5 +1,5 @@ import math -from typing import Optional, Union, Tuple, TYPE_CHECKING +from typing import Optional, Tuple, TYPE_CHECKING from dataclasses import dataclass import torch import torch.nn.functional as F diff --git a/allennlp/modules/transformer/transformer_layer.py b/allennlp/modules/transformer/transformer_layer.py index 914658777a1..e4bff40ad40 100644 --- a/allennlp/modules/transformer/transformer_layer.py +++ b/allennlp/modules/transformer/transformer_layer.py @@ -5,7 +5,7 @@ from allennlp.common import FromParams from allennlp.modules.transformer.transformer_module import TransformerModule from allennlp.modules.transformer.activation_layer import ActivationLayer -from allennlp.modules.transformer.attention_module import SelfAttention +from allennlp.modules.transformer.attention_module import SelfAttention, AttentionOutput from allennlp.modules.transformer.output_layer import OutputLayer if TYPE_CHECKING: @@ -77,14 +77,20 @@ def forward( self_output = self.self( input_tensor, - encoder_hidden_states, - encoder_hidden_states, + source_states=encoder_hidden_states, attention_mask=attention_mask, head_mask=head_mask, output_attentions=output_attentions, ) - attention_output = self.output(self_output[0], input_tensor) - outputs = (attention_output,) + self_output[1:] # add attentions if we output them + + attention_output = self.output(self_output.hidden_states, input_tensor) + # outputs = (attention_output, self_output.attention_probs) # add attentions if we output them + outputs = AttentionOutput( + attention_output, + self_output.key_value_state, + self_output.position_bias, + self_output.attention_probs, + ) return outputs @classmethod @@ -196,8 +202,10 @@ def forward( head_mask, output_attentions=output_attentions, ) - attention_output = attention_outputs[0] - outputs = attention_outputs[1:] # add self attentions if we output attention weights + attention_output = attention_outputs.hidden_states + outputs = ( + attention_outputs.attention_probs, + ) # add self attentions if we output attention weights if encoder_hidden_states is not None: assert hasattr( @@ -213,14 +221,14 @@ def forward( encoder_attention_mask, output_attentions, ) - attention_output = cross_attention_outputs[0] - outputs = ( - outputs + cross_attention_outputs[1:] + attention_output = cross_attention_outputs.hidden_states + outputs = outputs + ( # type: ignore + cross_attention_outputs.attention_probs, ) # add cross attentions if we output attention weights intermediate_output = self.intermediate(attention_output) layer_output = self.output(intermediate_output, attention_output) - outputs = (layer_output,) + outputs + outputs = (layer_output,) + 
outputs # type: ignore return outputs @classmethod diff --git a/tests/modules/transformer/self_attention_test.py b/tests/modules/transformer/self_attention_test.py index 597f56b6aa2..af1b4a4c43a 100644 --- a/tests/modules/transformer/self_attention_test.py +++ b/tests/modules/transformer/self_attention_test.py @@ -79,7 +79,7 @@ def test_loading_from_pretrained_weights_using_model_name(pretrained_name, relev pretrained_module = pretrained_module.eval() torch.manual_seed(1234) - output = module(hidden_states, attention_mask=attention_mask.squeeze())[0] + output = module(hidden_states, attention_mask=attention_mask.squeeze()).hidden_states if "distilbert" in pretrained_name: torch.manual_seed(1234) hf_output = pretrained_module( From 2eecfc4f55f16b6a519795e0469a198c64ac3507 Mon Sep 17 00:00:00 2001 From: Akshita Bhagia Date: Tue, 25 May 2021 12:03:40 -0700 Subject: [PATCH 07/13] update changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 739dff1071e..0f768246687 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -33,6 +33,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added a `min_steps` parameter to `BeamSearch` to set a minimum length for the predicted sequences. - Added the `FinalSequenceScorer` abstraction to calculate the final scores of the generated sequences in `BeamSearch`. - Added `shuffle` argument to `BucketBatchSampler` which allows for disabling shuffling. +- Added `allennlp.modules.transformer.attention_module` which contains a generalized `AttentionModule`. `SelfAttention` and `T5Attention` both inherit from this. ### Fixed From 86af3cbd1da2549e908b9e4c5d793ada66dc8ed8 Mon Sep 17 00:00:00 2001 From: Akshita Bhagia Date: Tue, 25 May 2021 12:29:07 -0700 Subject: [PATCH 08/13] fix shape --- allennlp/modules/transformer/attention_module.py | 2 +- allennlp/modules/transformer/transformer_layer.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/allennlp/modules/transformer/attention_module.py b/allennlp/modules/transformer/attention_module.py index a27b3296485..dd1223f9995 100644 --- a/allennlp/modules/transformer/attention_module.py +++ b/allennlp/modules/transformer/attention_module.py @@ -107,7 +107,7 @@ def __init__( # out linear layer for distilbert, T5 etc. 
if output_linear: - self.output = torch.nn.Linear(hidden_size, self.all_head_size, bias=bias) + self.output = torch.nn.Linear(self.all_head_size, hidden_size, bias=bias) self.scoring_func = scoring_func if self.scoring_func in ["additive", "linear", "bilinear"]: diff --git a/allennlp/modules/transformer/transformer_layer.py b/allennlp/modules/transformer/transformer_layer.py index e4bff40ad40..46f61baa5b7 100644 --- a/allennlp/modules/transformer/transformer_layer.py +++ b/allennlp/modules/transformer/transformer_layer.py @@ -84,7 +84,6 @@ def forward( ) attention_output = self.output(self_output.hidden_states, input_tensor) - # outputs = (attention_output, self_output.attention_probs) # add attentions if we output them outputs = AttentionOutput( attention_output, self_output.key_value_state, From e90574a5cf91db000e91f7ee50751b9a6848affe Mon Sep 17 00:00:00 2001 From: Akshita Bhagia Date: Wed, 26 May 2021 11:52:28 -0700 Subject: [PATCH 09/13] fix format --- allennlp/modules/transformer/attention_module.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/allennlp/modules/transformer/attention_module.py b/allennlp/modules/transformer/attention_module.py index dd1223f9995..61d7e78396e 100644 --- a/allennlp/modules/transformer/attention_module.py +++ b/allennlp/modules/transformer/attention_module.py @@ -416,7 +416,7 @@ def _relative_position_bucket( return relative_buckets def compute_bias(self, query_length: int, key_length: int) -> FloatT: - """ Compute binned relative position bias """ + """Compute binned relative position bias""" context_position = torch.arange(query_length, dtype=torch.long)[:, None] memory_position = torch.arange(key_length, dtype=torch.long)[None, :] relative_position = memory_position - context_position # shape (query_length, key_length) From 885808fc48cd801c951491e5e6d4904346158c03 Mon Sep 17 00:00:00 2001 From: Akshita Bhagia Date: Tue, 1 Jun 2021 17:01:14 -0700 Subject: [PATCH 10/13] address feedback --- .../modules/transformer/attention_module.py | 81 ++++++++++--------- .../modules/transformer/bimodal_encoder.py | 12 +-- allennlp/modules/transformer/t5.py | 9 ++- .../modules/transformer/transformer_layer.py | 27 +++++-- .../modules/transformer/transformer_stack.py | 26 ++++-- .../transformer/transformer_layer_test.py | 4 +- .../transformer/transformer_stack_test.py | 4 +- 7 files changed, 99 insertions(+), 64 deletions(-) diff --git a/allennlp/modules/transformer/attention_module.py b/allennlp/modules/transformer/attention_module.py index 61d7e78396e..acc52f11502 100644 --- a/allennlp/modules/transformer/attention_module.py +++ b/allennlp/modules/transformer/attention_module.py @@ -5,6 +5,7 @@ import torch.nn.functional as F from allennlp.common import FromParams +from allennlp.common.checks import ConfigurationError from allennlp.modules.attention import Attention from allennlp.modules.transformer.transformer_module import TransformerModule from allennlp.modules.transformer.util import apply_mask, FloatT, IntT, BoolT @@ -63,9 +64,9 @@ class AttentionModule(TransformerModule, FromParams): is_cross_attention: `bool` (default = `False`) Whether this module is being used for cross-attention in a decoder stack or not. If `is_cross_attention` is `True`, then `is_decoder` must also be `True`. - has_relative_attention_bias: `bool` (default = `False`) - relative_attention_num_buckets: `int` (default = `32`) - This is ignored if `has_relative_attention_bias` is set to `False`. 
+ relative_attention_num_buckets: `int`, optional (default = `None`) + The number of buckets to use in relative attention; if `None`, relative attention + will not be applied. """ def __init__( @@ -80,21 +81,22 @@ def __init__( normalize_weights: bool = False, is_decoder: bool = False, is_cross_attention: bool = False, - has_relative_attention_bias: bool = False, - relative_attention_num_buckets: int = 32, + relative_attention_num_buckets: Optional[int] = None, ): super().__init__() if hidden_size % num_attention_heads != 0: - raise ValueError( + raise ConfigurationError( "The hidden size (%d) is not a multiple of the number of attention " "heads (%d)" % (hidden_size, num_attention_heads) ) - if is_cross_attention: - assert is_decoder, "The attention layer can be a cross-attention layer only " - "if it is within a decoder." + if is_cross_attention and not is_decoder: + raise ConfigurationError( + "The attention layer can be a cross-attention layer only " + "if it is within a decoder." + ) self.hidden_size = hidden_size self.num_attention_heads = num_attention_heads @@ -117,10 +119,9 @@ def __init__( else: self.attn = Attention.by_name(self.scoring_func)() - self.has_relative_attention_bias = has_relative_attention_bias self.relative_attention_num_buckets = relative_attention_num_buckets - if self.has_relative_attention_bias: + if self.relative_attention_num_buckets is not None: self.relative_attention_bias = torch.nn.Embedding( self.relative_attention_num_buckets, self.num_attention_heads ) @@ -133,7 +134,7 @@ def __init__( if normalize_weights: self._normalize() - def _normalize(self): + def _normalize(self) -> None: self.query.weight.data.normal_( mean=0.0, std=(self.hidden_size * self.attention_head_size) ** -0.5 ) @@ -145,10 +146,10 @@ def _normalize(self): mean=0.0, std=(self.num_attention_heads * self.attention_head_size) ** -0.5 ) - if hasattr(self, "has_relative_attention_bias") and self.has_relative_attention_bias: + if hasattr(self, "relative_attention_bias"): self.relative_attention_bias.weight.data.normal_(mean=0.0, std=self.hidden_size ** -0.5) - def _transpose_for_scores(self, x: torch.Tensor): + def _transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: new_x_shape = x.size()[:-1] + ( self.num_attention_heads, self.attention_head_size, @@ -156,7 +157,7 @@ def _transpose_for_scores(self, x: torch.Tensor): x = x.view(*new_x_shape) return x.permute(0, 2, 1, 3) - def _query_layer(self, query_states: torch.Tensor): + def _query_layer(self, query_states: torch.Tensor) -> torch.Tensor: mixed_query_layer = self.query(query_states) query_layer = self._transpose_for_scores(mixed_query_layer) return query_layer @@ -167,7 +168,7 @@ def _project( layer: torch.nn.Linear, source_states: Optional[torch.Tensor] = None, past_key_or_value: Optional[torch.Tensor] = None, - ): + ) -> torch.Tensor: # TODO: clarify logic in terms of is_decoder and is_cross_attention # to make it more readable. 
if source_states is None: @@ -193,15 +194,15 @@ def _project( def _position_bias( self, - position_bias, - seq_lengths, - past_key_states, - attention_scores, - ): + position_bias: Optional[torch.Tensor], + seq_lengths: Tuple[int, int, int], + past_key_states: Optional[torch.Tensor], + attention_scores: torch.Tensor, + ) -> torch.Tensor: seq_length, real_seq_length, key_length = seq_lengths if position_bias is None: - if self.has_relative_attention_bias: + if self.relative_attention_num_buckets is not None: position_bias = self.compute_bias(real_seq_length, key_length) else: position_bias = torch.zeros( @@ -222,8 +223,8 @@ def _get_attention_probs( key_layer: torch.Tensor, attention_mask: torch.Tensor, head_mask: torch.Tensor, + seq_lengths: Tuple[int, int, int], position_bias: Optional[torch.Tensor] = None, - seq_lengths: Optional[Tuple[int, int, int]] = None, past_key_states: Optional[torch.Tensor] = None, **kwargs, ): @@ -233,14 +234,10 @@ def _get_attention_probs( position_bias, seq_lengths, past_key_states, attention_scores ) - if position_bias is not None: - if attention_mask is not None: - # Shape: (batch_size, num_heads, seq_length, key_length) - position_bias = apply_mask(position_bias, attention_mask) - attention_scores += position_bias - else: - if attention_mask is not None: - attention_scores = apply_mask(attention_scores, attention_mask) + if attention_mask is not None: + # Shape: (batch_size, num_heads, seq_length, key_length) + position_bias = apply_mask(position_bias, attention_mask) + attention_scores += position_bias attention_probs = torch.nn.Softmax(dim=-1)(attention_scores) @@ -264,7 +261,13 @@ def _output_layer(self, attention_probs: torch.Tensor, value_layer: torch.Tensor return context_layer - def _get_lengths(self, query_states, past_key_states, source_states, query_length): + def _get_lengths( + self, + query_states: torch.Tensor, + past_key_states: Optional[torch.Tensor] = None, + source_states: Optional[torch.Tensor] = None, + query_length: Optional[int] = None, + ) -> Tuple[int, int, int]: seq_length = query_states.shape[1] effective_seq_len = seq_length @@ -344,8 +347,8 @@ def forward( key_layer, attention_mask, head_mask, - position_bias, seq_lengths, + position_bias, past_key_states, ) @@ -354,6 +357,10 @@ def forward( present_key_value_state = ( (key_layer, value_layer) if (self.is_decoder and use_cache) else None ) + + if not output_attentions: + attention_probs = None + outputs = AttentionOutput( context_layer, present_key_value_state, position_bias, attention_probs ) @@ -423,7 +430,7 @@ def compute_bias(self, query_length: int, key_length: int) -> FloatT: relative_position_bucket = self._relative_position_bucket( relative_position, # shape (query_length, key_length) bidirectional=(not self.is_decoder), - num_buckets=self.relative_attention_num_buckets, + num_buckets=self.relative_attention_num_buckets, # type: ignore ) relative_position_bucket = relative_position_bucket.to( self.relative_attention_bias.weight.device @@ -460,6 +467,9 @@ def __init__( is_cross_attention: bool = False, ): + if not has_relative_attention_bias: + relative_attention_num_buckets = None # type: ignore + super().__init__( hidden_size=hidden_size, attention_head_size=key_value_proj_dim, @@ -471,7 +481,6 @@ def __init__( normalize_weights=normalize, is_decoder=is_decoder, is_cross_attention=is_cross_attention, - has_relative_attention_bias=has_relative_attention_bias, relative_attention_num_buckets=relative_attention_num_buckets, ) @@ -599,7 +608,7 @@ def _from_config(cls, config: 
"PretrainedConfig", **kwargs): final_kwargs["num_attention_heads"] = config.num_attention_heads final_kwargs["output_linear"] = hasattr( config, "n_heads" - ) # Since this is the distilbert case. + ) # This is the distilbert case; they have a linear layer as the output. if hasattr(config, "attention_dropout"): final_kwargs["dropout"] = config.attention_dropout else: diff --git a/allennlp/modules/transformer/bimodal_encoder.py b/allennlp/modules/transformer/bimodal_encoder.py index acc993194df..27634fc1f30 100644 --- a/allennlp/modules/transformer/bimodal_encoder.py +++ b/allennlp/modules/transformer/bimodal_encoder.py @@ -150,19 +150,19 @@ def forward( for idx in range(start1, self.fixed_layer1): with torch.no_grad(): - embedding1 = self.layers1[idx](embedding1, attention_mask1)[0] + embedding1 = self.layers1[idx](embedding1, attention_mask1).hidden_states start1 = self.fixed_layer1 for idx in range(start1, end1): - embedding1 = self.layers1[idx](embedding1, attention_mask1)[0] + embedding1 = self.layers1[idx](embedding1, attention_mask1).hidden_states for idx in range(start2, self.fixed_layer2): with torch.no_grad(): - embedding2 = self.layers2[idx](embedding2, attention_mask2)[0] + embedding2 = self.layers2[idx](embedding2, attention_mask2).hidden_states start2 = self.fixed_layer2 for idx in range(start2, end2): - embedding2 = self.layers2[idx](embedding2, attention_mask2)[0] + embedding2 = self.layers2[idx](embedding2, attention_mask2).hidden_states if count == 0 and self.in_batch_pairs: # new batch size is the batch_size ^2 @@ -230,10 +230,10 @@ def forward( all_encoder_layers2.append(embedding2) for idx in range(start2, len(self.layers2)): - embedding2 = self.layers2[idx](embedding2, attention_mask2)[0] + embedding2 = self.layers2[idx](embedding2, attention_mask2).hidden_states for idx in range(start1, len(self.layers1)): - embedding1 = self.layers1[idx](embedding1, attention_mask1)[0] + embedding1 = self.layers1[idx](embedding1, attention_mask1).hidden_states # add the end part to finish. 
if not output_all_encoded_layers: diff --git a/allennlp/modules/transformer/t5.py b/allennlp/modules/transformer/t5.py index bc05d192939..7b7cfc21206 100644 --- a/allennlp/modules/transformer/t5.py +++ b/allennlp/modules/transformer/t5.py @@ -13,8 +13,7 @@ from allennlp.common import FromParams, Params, Lazy, Registrable from allennlp.common.checks import ConfigurationError -from allennlp.modules.transformer import TransformerModule - +from allennlp.modules.transformer.transformer_module import TransformerModule from allennlp.modules.transformer.attention_module import ( T5Attention, AttentionOutput, @@ -143,6 +142,10 @@ def __init__( self.layer_norm = layer_norm or T5LayerNorm(hidden_size=self.self_attention.hidden_size) self.dropout = nn.Dropout(dropout) + @property + def hidden_size(self) -> int: + return self.self_attention.hidden_size + def forward( self, hidden_states: FloatT, @@ -271,7 +274,7 @@ def __init__( @property def hidden_size(self) -> int: - return self.layer[0].self_attention.hidden_size + return self.layer[0].hidden_size def forward( self, diff --git a/allennlp/modules/transformer/transformer_layer.py b/allennlp/modules/transformer/transformer_layer.py index 46f61baa5b7..f3814e6201e 100644 --- a/allennlp/modules/transformer/transformer_layer.py +++ b/allennlp/modules/transformer/transformer_layer.py @@ -1,4 +1,5 @@ from typing import Union, Optional, TYPE_CHECKING +from dataclasses import dataclass import torch @@ -7,6 +8,7 @@ from allennlp.modules.transformer.activation_layer import ActivationLayer from allennlp.modules.transformer.attention_module import SelfAttention, AttentionOutput from allennlp.modules.transformer.output_layer import OutputLayer +from allennlp.modules.transformer.util import FloatT if TYPE_CHECKING: from transformers.configuration_utils import PretrainedConfig @@ -105,6 +107,17 @@ def _from_config(cls, config: "PretrainedConfig", **kwargs): return cls(**final_kwargs) +@dataclass +class TransformerLayerOutput: + """ + Encapsulates the outputs of the `TransformerLayer` module. + """ + + hidden_states: FloatT + self_attention_probs: Optional[FloatT] = None + cross_attention_probs: Optional[FloatT] = None + + class TransformerLayer(TransformerModule, FromParams): """ This module is a single transformer layer, mapping to `BertLayer` in the architecture in BERT. 
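A minimal usage sketch of the `TransformerLayerOutput` contract introduced above, showing how a call site reads named fields instead of indexing a tuple. The constructor arguments and tensor sizes here are illustrative assumptions and may not match the module's exact signature; this is not part of the patch itself.

# Sketch only: consuming TransformerLayerOutput instead of `layer_outputs[0]`.
# Constructor arguments below are assumed for illustration; check TransformerLayer
# for the real signature.
import torch
from allennlp.modules.transformer.transformer_layer import TransformerLayer

layer = TransformerLayer(
    hidden_size=6,          # small, divisible by num_attention_heads
    intermediate_size=3,
    num_attention_heads=2,
)

hidden_states = torch.randn(2, 3, 6)      # (batch_size, seq_len, hidden_size)
attention_mask = torch.ones(2, 3).bool()  # (batch_size, seq_len)

out = layer(hidden_states, attention_mask=attention_mask, output_attentions=True)
new_hidden = out.hidden_states            # previously layer_outputs[0]
self_probs = out.self_attention_probs     # None unless output_attentions=True
cross_probs = out.cross_attention_probs   # None when no cross-attention is run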
@@ -181,7 +194,7 @@ def forward( encoder_hidden_states: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, - ): + ) -> TransformerLayerOutput: """ # Parameters @@ -202,9 +215,8 @@ def forward( output_attentions=output_attentions, ) attention_output = attention_outputs.hidden_states - outputs = ( - attention_outputs.attention_probs, - ) # add self attentions if we output attention weights + self_attention_probs = attention_outputs.attention_probs + cross_attention_probs = None if encoder_hidden_states is not None: assert hasattr( @@ -221,13 +233,12 @@ def forward( output_attentions, ) attention_output = cross_attention_outputs.hidden_states - outputs = outputs + ( # type: ignore - cross_attention_outputs.attention_probs, - ) # add cross attentions if we output attention weights + cross_attention_probs = cross_attention_outputs.attention_probs intermediate_output = self.intermediate(attention_output) layer_output = self.output(intermediate_output, attention_output) - outputs = (layer_output,) + outputs # type: ignore + + outputs = TransformerLayerOutput(layer_output, self_attention_probs, cross_attention_probs) return outputs @classmethod diff --git a/allennlp/modules/transformer/transformer_stack.py b/allennlp/modules/transformer/transformer_stack.py index 7bc4a7247d3..052bd008292 100644 --- a/allennlp/modules/transformer/transformer_stack.py +++ b/allennlp/modules/transformer/transformer_stack.py @@ -1,5 +1,6 @@ -from typing import Union, Optional, TYPE_CHECKING +from typing import Union, Optional, Tuple, TYPE_CHECKING import logging +from dataclasses import dataclass import torch @@ -7,6 +8,7 @@ from allennlp.modules.util import replicate_layers from allennlp.modules.transformer.transformer_layer import TransformerLayer from allennlp.modules.transformer.transformer_module import TransformerModule +from allennlp.modules.transformer.util import FloatT if TYPE_CHECKING: from transformers.configuration_utils import PretrainedConfig @@ -15,6 +17,18 @@ logger = logging.getLogger(__name__) +@dataclass +class TransformerStackOutput: + """ + Encapsulates the outputs of the `TransformerLayer` module. + """ + + final_hidden_states: FloatT + all_hidden_states: Optional[Tuple] = None + all_self_attentions: Optional[Tuple] = None + all_cross_attentions: Optional[Tuple] = None + + class TransformerStack(TransformerModule, FromParams): """ This module is the basic transformer stack. 
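Likewise, a minimal sketch of how a caller migrates from the old tuple return of the stack to the `TransformerStackOutput` dataclass defined above. The constructor arguments and sizes are illustrative assumptions, not taken from this patch.

# Sketch only: migrating from `output[0]` to named TransformerStackOutput fields.
# Constructor arguments below are assumed for illustration; check TransformerStack
# for the real signature.
import torch
from allennlp.modules.transformer.transformer_stack import TransformerStack

stack = TransformerStack(
    num_hidden_layers=2,
    hidden_size=6,
    intermediate_size=3,
    num_attention_heads=2,
)

hidden_states = torch.randn(2, 3, 6)      # (batch_size, seq_len, hidden_size)
attention_mask = torch.ones(2, 3).bool()  # (batch_size, seq_len)

output = stack(hidden_states, attention_mask=attention_mask, output_hidden_states=True)

final_states = output.final_hidden_states   # previously output[0]
per_layer = output.all_hidden_states        # tuple of per-layer tensors, or None
self_attn = output.all_self_attentions      # populated only with output_attentions=True
cross_attn = output.all_cross_attentions    # None without cross-attention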
@@ -87,7 +101,7 @@ def forward( encoder_attention_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, output_hidden_states: bool = False, - ): + ) -> TransformerStackOutput: """ # Parameters @@ -118,7 +132,7 @@ def forward( encoder_attention_mask, output_attentions, ) - hidden_states = layer_outputs[0] + hidden_states = layer_outputs.hidden_states if output_attentions: all_attentions = all_attentions + (layer_outputs[1],) # type: ignore if self._add_cross_attention: @@ -127,10 +141,8 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) # type: ignore - return tuple( - v - for v in [hidden_states, all_hidden_states, all_attentions, all_cross_attentions] - if v is not None + return TransformerStackOutput( + hidden_states, all_hidden_states, all_attentions, all_cross_attentions ) @classmethod diff --git a/tests/modules/transformer/transformer_layer_test.py b/tests/modules/transformer/transformer_layer_test.py index 641032ee78d..538df89c9e0 100644 --- a/tests/modules/transformer/transformer_layer_test.py +++ b/tests/modules/transformer/transformer_layer_test.py @@ -243,7 +243,7 @@ def test_layer_matches_huggingface(layer_params, module_name, hf_module): torch.manual_seed(1234) hf_output = hf_module(hidden_states, attention_mask=attention_mask_hf) - assert torch.allclose(output[0], hf_output[0]) + assert torch.allclose(output.hidden_states, hf_output[0]) @pytest.mark.parametrize( @@ -282,7 +282,7 @@ def test_layer_from_pretrained(pretrained_name, relevant_top_level_module): attention_mask_hf = (1.0 - attention_mask_hf) * -10e5 torch.manual_seed(1234) - output = module(hidden_states, attention_mask=attention_mask.squeeze())[0] + output = module(hidden_states, attention_mask=attention_mask.squeeze()).hidden_states torch.manual_seed(1234) hf_output = pretrained_module(hidden_states, attention_mask=attention_mask_hf)[0] diff --git a/tests/modules/transformer/transformer_stack_test.py b/tests/modules/transformer/transformer_stack_test.py index cf42f6c0f6d..4812a74f58b 100644 --- a/tests/modules/transformer/transformer_stack_test.py +++ b/tests/modules/transformer/transformer_stack_test.py @@ -55,7 +55,7 @@ def test_transformer_stack_from_params(params): hidden_states, attention_mask=attention_mask ) - assert torch.allclose(from_layer_output[0], output[0]) + assert torch.allclose(from_layer_output.final_hidden_states, output.final_hidden_states) # Make sure forward pass raises with bad input. with pytest.raises(AssertionError): @@ -102,7 +102,7 @@ def test_loading_from_pretrained(pretrained_model_name): torch.manual_seed(SEED) hf_output = pretrained_module(hidden_states, attention_mask=attention_mask_hf) - assert torch.allclose(output[0], hf_output[0]) + assert torch.allclose(output.final_hidden_states, hf_output[0]) def test_loading_partial_pretrained_weights(): From 7d08c68b01f0f7d3348c00f19a14089e9b396c0a Mon Sep 17 00:00:00 2001 From: epwalsh Date: Wed, 2 Jun 2021 13:45:47 -0700 Subject: [PATCH 11/13] small doc fix --- allennlp/modules/transformer/attention_module.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/allennlp/modules/transformer/attention_module.py b/allennlp/modules/transformer/attention_module.py index acc52f11502..4d98caba4b2 100644 --- a/allennlp/modules/transformer/attention_module.py +++ b/allennlp/modules/transformer/attention_module.py @@ -50,7 +50,8 @@ class AttentionModule(TransformerModule, FromParams): The number of attention heads. 
scoring_func: `str` (default = `scaled_dot_product`) The name of the attention-calculating function to be used. - Eg. `additive`, `linear`, etc. For a complete list, please check :mod:`allennlp.modules.attention`. + Eg. `additive`, `linear`, etc. For a complete list, please check + :mod:`allennlp.modules.attention.attention`. output_linear: `bool` (default = `False`) Whether to add an additional output linear layer at the end. dropout: `float` (default = `0.0`) @@ -296,6 +297,8 @@ def forward( query_length: Optional[int] = None, ): """ + # Parameters + query_states : `torch.Tensor` Shape `batch_size x seq_len x hidden_dim` past_key_states : `torch.Tensor`, optional @@ -563,7 +566,8 @@ class SelfAttention(AttentionModule): dropout: `float` (default = `0.0`) scoring_func: `str` (default = `scaled_dot_product`) The name of the attention-calculating function to be used. - Eg. `additive`, `linear`, etc. For a complete list, please check :mod:`allennlp.modules.attention`. + Eg. `additive`, `linear`, etc. For a complete list, please check + :mod:`allennlp.modules.attention.attention`. """ _pretrained_relevant_module = ["encoder.layers.0.attention.self", "encoder.layers.0.attention"] From e455b6a47459b1343a853eee7db7201beaac3a27 Mon Sep 17 00:00:00 2001 From: Akshita Bhagia Date: Wed, 2 Jun 2021 13:54:34 -0700 Subject: [PATCH 12/13] Update allennlp/modules/transformer/transformer_stack.py Co-authored-by: Pete --- allennlp/modules/transformer/transformer_stack.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/allennlp/modules/transformer/transformer_stack.py b/allennlp/modules/transformer/transformer_stack.py index 052bd008292..3825ccac7cb 100644 --- a/allennlp/modules/transformer/transformer_stack.py +++ b/allennlp/modules/transformer/transformer_stack.py @@ -20,7 +20,7 @@ @dataclass class TransformerStackOutput: """ - Encapsulates the outputs of the `TransformerLayer` module. + Encapsulates the outputs of the `TransformerStack` module. """ final_hidden_states: FloatT From da08ac20e0336ffc41502cf4b6cd7c6bc96ee244 Mon Sep 17 00:00:00 2001 From: Akshita Bhagia Date: Wed, 2 Jun 2021 13:57:48 -0700 Subject: [PATCH 13/13] remove old file --- .../modules/transformer/self_attention.py | 161 ------------------ 1 file changed, 161 deletions(-) delete mode 100644 allennlp/modules/transformer/self_attention.py diff --git a/allennlp/modules/transformer/self_attention.py b/allennlp/modules/transformer/self_attention.py deleted file mode 100644 index d464012de81..00000000000 --- a/allennlp/modules/transformer/self_attention.py +++ /dev/null @@ -1,161 +0,0 @@ -from typing import Optional, TYPE_CHECKING - -import torch - -from allennlp.common import FromParams -from allennlp.modules.attention import Attention -from allennlp.modules.transformer.transformer_module import TransformerModule -from allennlp.modules.transformer.util import apply_mask - -if TYPE_CHECKING: - from transformers.configuration_utils import PretrainedConfig - - -class SelfAttention(TransformerModule, FromParams): - """ - This module computes the self-attention, similar to the architecture in BERT. Additionally, the attention - scoring function can be specified. 
- Details in the paper: - [BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding, Devlin et al, 2019] - (https://api.semanticscholar.org/CorpusID:52967399) - - # Parameters - - hidden_size: `int` - num_attention_heads: `int` - dropout: `float` (default = `0.0`) - scoring_func: `str` (default = `scaled_dot_product`) - The name of the attention-calculating function to be used. - Eg. `additive`, `linear`, etc. For a complete list, please check :mod:`allennlp.modules.attention`. - """ - - _pretrained_relevant_module = ["encoder.layers.0.attention.self", "encoder.layers.0.attention"] - _pretrained_mapping = { - "layer": "layers", - "q_lin": "query", - "k_lin": "key", - "v_lin": "value", - "out_lin": "output", - "transformer": "encoder", - } - - def __init__( - self, - hidden_size: int, - num_attention_heads: int, - dropout: float = 0.0, - scoring_func: str = "scaled_dot_product", - output_linear: bool = False, - ): - super().__init__() - if hidden_size % num_attention_heads != 0: - raise ValueError( - "The hidden size (%d) is not a multiple of the number of attention " - "heads (%d)" % (hidden_size, num_attention_heads) - ) - self.hidden_size = hidden_size - self.num_attention_heads = num_attention_heads - self.attention_head_size = int(hidden_size / num_attention_heads) - self.all_head_size = self.num_attention_heads * self.attention_head_size - - self.query = torch.nn.Linear(hidden_size, self.all_head_size) - self.key = torch.nn.Linear(hidden_size, self.all_head_size) - self.value = torch.nn.Linear(hidden_size, self.all_head_size) - - self.scoring_func = scoring_func - if self.scoring_func in ["additive", "linear", "bilinear"]: - self.attn = Attention.by_name(self.scoring_func)(hidden_size, hidden_size) - elif self.scoring_func == "scaled_dot_product": - self.attn = Attention.by_name(self.scoring_func)(self.attention_head_size, False) - else: - self.attn = Attention.by_name(self.scoring_func)() - - # out linear layer for distilbert. 
- if output_linear: - self.output = torch.nn.Linear(hidden_size, self.all_head_size) - - self.dropout = torch.nn.Dropout(dropout) - - def _transpose_for_scores(self, x): - new_x_shape = x.size()[:-1] + ( - self.num_attention_heads, - self.attention_head_size, - ) - x = x.view(*new_x_shape) - return x.permute(0, 2, 1, 3) - - def forward( - self, - query_states: torch.Tensor, - key_states: Optional[torch.Tensor] = None, - value_states: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.BoolTensor] = None, - head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - ): - """ - # Parameters - - query_states : `torch.Tensor` - Shape `batch_size x seq_len x hidden_dim` - key_states : `torch.Tensor`, optional - Shape `batch_size x seq_len x hidden_dim` - value_states : `torch.Tensor`, optional - Shape `batch_size x seq_len x hidden_dim` - attention_mask : `torch.BoolTensor`, optional - Shape `batch_size x seq_len` - head_mask : `torch.BoolTensor`, optional - output_attentions : `bool` - Whether to also return the attention probabilities, default = `False` - """ - if key_states is None: - key_states = query_states - if value_states is None: - value_states = query_states - - mixed_query_layer = self.query(query_states) - mixed_key_layer = self.key(key_states) - mixed_value_layer = self.value(value_states) - - query_layer = self._transpose_for_scores(mixed_query_layer) - key_layer = self._transpose_for_scores(mixed_key_layer) - value_layer = self._transpose_for_scores(mixed_value_layer) - - attention_scores = self.attn(query_layer, key_layer.transpose(-1, -2)) - - if attention_mask is not None: - attention_scores = apply_mask(attention_scores, attention_mask) - - attention_probs = torch.nn.Softmax(dim=-1)(attention_scores) - # This is actually dropping out entire tokens to attend to, which might - # seem a bit unusual, but is taken from the original Transformer paper. - attention_probs = self.dropout(attention_probs) - - if head_mask is not None: - attention_probs = attention_probs * head_mask - - context_layer = torch.matmul(attention_probs, value_layer) - context_layer = context_layer.permute(0, 2, 1, 3).contiguous() - new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) - context_layer = context_layer.view(*new_context_layer_shape) - - if hasattr(self, "output"): - context_layer = self.output(context_layer) - - outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) - return outputs - - @classmethod - def _from_config(cls, config: "PretrainedConfig", **kwargs): - final_kwargs = {} - final_kwargs["hidden_size"] = config.hidden_size - final_kwargs["num_attention_heads"] = config.num_attention_heads - final_kwargs["output_linear"] = hasattr( - config, "n_heads" - ) # Since this is the distilbert case. - if hasattr(config, "attention_dropout"): - final_kwargs["dropout"] = config.attention_dropout - else: - final_kwargs["dropout"] = config.attention_probs_dropout_prob - final_kwargs.update(**kwargs) - return cls(**final_kwargs)
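For reference, the role of the deleted `self_attention.py` module is now played by `SelfAttention` in `attention_module.py`, which returns an `AttentionOutput` dataclass rather than a tuple. Below is a minimal usage sketch of that replacement; the sizes, dropout value, and `scoring_func` choice are illustrative assumptions, and the forward keyword arguments are assumed to match what the tests in this patch exercise.

# Sketch only: the consolidated SelfAttention replacing the deleted module.
import torch
from allennlp.modules.transformer.attention_module import SelfAttention

attention = SelfAttention(
    hidden_size=6,
    num_attention_heads=2,
    dropout=0.1,
    scoring_func="scaled_dot_product",  # e.g. "additive", "linear"; see allennlp.modules.attention
)

query_states = torch.randn(2, 3, 6)       # (batch_size, seq_len, hidden_size)
attention_mask = torch.ones(2, 3).bool()  # (batch_size, seq_len)

out = attention(query_states, attention_mask=attention_mask, output_attentions=True)
context = out.hidden_states     # previously outputs[0] from the deleted module
probs = out.attention_probs     # None unless output_attentions=True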