This repository was archived by the owner on Dec 16, 2022. It is now read-only.

Commit dcec284

dirkgr, epwalsh, jacob-morrison, nelson-liu, and AkshitaB committed
T5 (#4969)
* Formatting
* New activation functions
* Makes position embeddings optional in the transformer embeddings
* Adds T5
* Various fixes to make this start up
* Share weights
* Adds one test that passes, and one test that fails
* use min_value_of_dtype in apply_mask
* fixes, add beam search
* encoder fixes
* fix
* fix beam search
* fix tests
* rename to just 'T5'
* fix initialization from pretrained
* add Model, DatasetReader, and Predictor
* remove useless dataset reader
* move high-level pieces to allennlp-models
* revert predictor changes
* remove unneeded hidden_size
* remove stray comment
* bool masks
* CHANGELOG
* fix test file name
* revert other change
* revert other change
* Distributed training with gradient accumulation (#5100)
* Fixes distributed training with gradient accumulation
* Fix in case we don't do anything in a batch group
* Test for the problematic condition
* Formatting
* More formatting
* Changelog
* Fix another test
* Fix even more tests
* Fixes one more test
* I can fix these tests all day.
* Add link to gallery and demo in README (#5103)
* Add link to gallery in README
* Update README.md
* try emojis Is this overkill?
* Adding a metadata field to the basic classifier (#5104)
* Adding metadata parameter to BasicClassifier
* Fix
* Updating the changelog
* reformatting
* updating parameter type
* fixing import
  Co-authored-by: Dirk Groeneveld <[email protected]>
* additional W&B params (#5114)
* additional W&B params
* add wandb_kwargs
* fix
* fix docs
* Add eval_mode argument to pretrained transformer embedder (#5111)
* Add eval_mode argument to pretrained transformer embedder
* Edit changelog entry
* Lint
* Update allennlp/modules/token_embedders/pretrained_transformer_embedder.py
* Apply suggestions from code review
  Co-authored-by: Evan Pete Walsh <[email protected]>
  Co-authored-by: Evan Pete Walsh <[email protected]>
* specify 'truncation' to avoid transformers warning (#5120)
* specify 'truncation' to avoid transformers warning
* Update docs
* Remove `stride` param
* Update CHANGELOG.md
  Co-authored-by: Dirk Groeneveld <[email protected]>
* Predicting with a dataset reader on a multitask model (#5115)
* Create a way to use allennlp predict with a dataset and a multitask model
* Fix type ignoration
* Changelog
* Fix to the predictor
* fix bug with interleaving dataset reader (#5122)
* fix bug with interleaving dataset reader
* more tests
* Update allennlp/data/dataset_readers/interleaving_dataset_reader.py
* Update allennlp/data/dataset_readers/interleaving_dataset_reader.py
* remove jsonpickle from dependencies (#5121)
  Co-authored-by: Dirk Groeneveld <[email protected]>
* Update docstring for basic_classifier (#5124)
* improve error message from Registrable class (#5125)
  Co-authored-by: Akshita Bhagia <[email protected]>
* Prepare for release v2.3.0
* fix docs CI
* Take the number of runs in the test for distributed metrics (#5127)
* Take the number of runs in the test for distributed metrics
* Changelog
* Add influence functions to interpret module (#4988)
* creating a new functionality to fields and instances to support outputting instances to json files
* creating tests for the new functionality
* fixing docs
* Delete __init__.py
* Delete influence_interpreter.py
* Delete use_if.py
* Delete simple_influence_test.py
* fixing docs
* finishing up SimpleInfluence
* passing lint
* passing format
* making small progress in coding
* Delete fast_influence.py Submit to the wrong branch
* Delete faiss_utils.py wrong branch
* Delete gpt2_bug.py not sure why it's included
* Delete text_class.py not sure why it's included
* adding test file
* adding testing files
* deleted unwanted files
* deleted unwanted files and rearrange test files
* small bug
* adjust function call to save instance in json
* Update allennlp/interpret/influence_interpreters/influence_interpreter.py
  Co-authored-by: Evan Pete Walsh <[email protected]>
* Update allennlp/interpret/influence_interpreters/influence_interpreter.py
  Co-authored-by: Evan Pete Walsh <[email protected]>
* Update allennlp/interpret/influence_interpreters/influence_interpreter.py
  Co-authored-by: Evan Pete Walsh <[email protected]>
* move some documentation of parameters to base class
* delete one comment
* delete one deprecated abstract method
* changing interface
* formatting
* formatting err
* passing mypy
* passing mypy
* passing mypy
* passing mypy
* passing integration test
* passing integration test
* adding a new option to the do-all function
* modifying the callable function to the interface
* update API, fixes
* doc fixes
* add `from_path` and `from_archive` methods
* fix docs, improve logging
* add test
* address @matt-gardner's comments
* fixes to documentation
* update docs
  Co-authored-by: Evan Pete Walsh <[email protected]>
  Co-authored-by: Evan Pete Walsh <[email protected]>
* Update CONTRIBUTING.md (#5133)
* Update CONTRIBUTING.md
* updated changelog
  Co-authored-by: Akshita Bhagia <[email protected]>
  Co-authored-by: Arjun Subramonian <[email protected]>
* fix #5132 (#5134)
* fix
* Prepare for release v2.3.1
* Fairness Metrics (#5093)
* Added three definitions of fairness
* Updated CHANGELOG
* Added DemographicParityWithoutGroundTruth and finished tests
* finished refactoring Independence, Separation, and Sufficiency to accumulate
* added distributed functionality to Independence, Sufficiency, and Separation
* Finished aggregate and distributed functionality for DemographicParityWithoutGroundTruth
* fixed GPU and doc issues
* fixed GPU and doc issues
* fixed GPU and doc issues
* fixed GPU issues
* fixed GPU issues
* added init file
* fixed typo
* minor docstring changes
* minor changes to docstring
* Added simple explanations of fairness metrics to docstrings
* Further vectorized all metric implementations
* Fixed device issue
  Co-authored-by: Arjun Subramonian <[email protected]>
  Co-authored-by: Akshita Bhagia <[email protected]>
  Co-authored-by: Dirk Groeneveld <[email protected]>
* fix cached_path for hub downloads (#5141)
* fix cached_path for hub downloads
* fix test name
* fix type hint
* Update allennlp/common/file_utils.py
  Co-authored-by: Lysandre Debut <[email protected]>
  Co-authored-by: Lysandre Debut <[email protected]>
* fix
* fix

Co-authored-by: epwalsh <[email protected]>
Co-authored-by: Evan Pete Walsh <[email protected]>
Co-authored-by: Jacob Morrison <[email protected]>
Co-authored-by: Nelson Liu <[email protected]>
Co-authored-by: Akshita Bhagia <[email protected]>
Co-authored-by: Leo Liu <[email protected]>
Co-authored-by: ArjunSubramonian <[email protected]>
Co-authored-by: Arjun Subramonian <[email protected]>
Co-authored-by: Arjun Subramonian <[email protected]>
Co-authored-by: Lysandre Debut <[email protected]>
1 parent cd96669 commit dcec284

File tree

9 files changed: +1521 -34 lines changed


CHANGELOG.md

Lines changed: 4 additions & 0 deletions
@@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ## Unreleased
 
+### Added
+
+- Added a T5 implementation to `modules.transformers`.
+
 ### Fixed
 
 - Fixed `cached_path()` for "hf://" files.

allennlp/modules/transformer/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -140,3 +140,4 @@ def forward(self, token_ids: torch.LongTensor, mask: torch.BoolTensor):
 
 from allennlp.modules.transformer.bimodal_attention import BiModalAttention
 from allennlp.modules.transformer.bimodal_encoder import BiModalEncoder
+from allennlp.modules.transformer.t5 import T5
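
With `T5` now exported from `allennlp.modules.transformer`, the stack can be imported directly. Below is a minimal, hedged sketch of loading it from a pretrained checkpoint; it assumes `T5` subclasses `TransformerModule` and therefore inherits the `from_pretrained_module` classmethod shown later in this commit, and the "t5-small" checkpoint name is only illustrative.

    # Sketch only: assumes T5 inherits TransformerModule.from_pretrained_module and
    # that the (illustrative) "t5-small" Hugging Face checkpoint is available.
    from allennlp.modules.transformer import T5

    t5 = T5.from_pretrained_module("t5-small")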

allennlp/modules/transformer/t5.py

Lines changed: 1264 additions & 0 deletions
Large diffs are not rendered by default.

allennlp/modules/transformer/transformer_embeddings.py

Lines changed: 11 additions & 12 deletions
@@ -119,14 +119,14 @@ def __init__(
         dropout: float = 0.1,
         output_size: Optional[int] = None,
     ):
-
         embedding_dict = {}
 
         word_embeddings = torch.nn.Embedding(vocab_size, embedding_size, padding_idx=pad_token_id)
         embedding_dict["word_embeddings"] = word_embeddings
 
-        position_embeddings = torch.nn.Embedding(max_position_embeddings, embedding_size)
-        embedding_dict["position_embeddings"] = position_embeddings
+        if max_position_embeddings > 0:
+            position_embeddings = torch.nn.Embedding(max_position_embeddings, embedding_size)
+            embedding_dict["position_embeddings"] = position_embeddings
 
         if type_vocab_size > 0:
             token_type_embeddings = torch.nn.Embedding(type_vocab_size, embedding_size)
@@ -163,16 +163,15 @@ def forward( # type: ignore
 
         embedding_inputs = [input_ids]
 
-        if position_ids is None:
-            position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
-            position_ids = position_ids.unsqueeze(0).expand(input_shape)
-
-        embedding_inputs.append(position_ids)
-
-        if token_type_ids is None:
-            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
+        if "position_embeddings" in self.embeddings:
+            if position_ids is None:
+                position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
+                position_ids = position_ids.unsqueeze(0).expand(input_shape)
+            embedding_inputs.append(position_ids)
 
-        if len(self.embeddings) == 3:
+        if "token_type_embeddings" in self.embeddings:
+            if token_type_ids is None:
+                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
             embedding_inputs.append(token_type_ids)
 
         embeddings = super().forward(*embedding_inputs)
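
The net effect of this change: when `max_position_embeddings` is zero or negative, no absolute position-embedding table is created and `position_ids` are never consulted in `forward()`, which is what T5 needs since it relies on relative position biases instead. A rough sketch, assuming the class in this file is `TransformerEmbeddings` and that the constructor keywords match the ones visible in the hunk above (the sizes are placeholders):

    import torch
    from allennlp.modules.transformer.transformer_embeddings import TransformerEmbeddings

    # With max_position_embeddings=0 (and type_vocab_size=0), only word embeddings are built.
    embeddings = TransformerEmbeddings(
        vocab_size=32128,
        embedding_size=512,
        pad_token_id=0,
        max_position_embeddings=0,
        type_vocab_size=0,
    )
    out = embeddings(torch.tensor([[3, 25, 7, 0]]))  # position_ids are simply skipped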

allennlp/modules/transformer/transformer_module.py

Lines changed: 12 additions & 9 deletions
@@ -1,4 +1,4 @@
-from typing import Optional, Dict, Union, List
+from typing import Optional, Dict, Union, List, Any
 import logging
 import inspect
 
@@ -32,23 +32,26 @@ def __init__(self, *args, **kwargs):
     def _get_mapping(
         cls,
         pretrained_module: Optional[torch.nn.Module] = None,
-        source="huggingface",
+        source: str = "huggingface",
         mapping: Optional[Dict[str, str]] = None,
     ):
         """
         Returns the mapping to be used, based on the optional `pretrained_module`.
         If `pretrained_module` is not given, the default module-level mapping is returned.
         """
         combined_mapping = {}
-        if "huggingface" in source:
+        if "huggingface" == source:
             combined_mapping.update(cls._huggingface_mapping)
         if mapping is not None:
             combined_mapping.update(mapping)
         return combined_mapping
 
     @classmethod
     def _get_mapped_submodules(
-        cls, pretrained_module, source="huggingface", mapping: Optional[Dict[str, str]] = None
+        cls,
+        pretrained_module: torch.nn.Module,
+        source: str = "huggingface",
+        mapping: Optional[Dict[str, str]] = None,
     ):
         """
         Subclasses overload this method, and provide appropriate name mapping based on the source.
@@ -64,7 +67,7 @@ def _get_mapped_submodules(
 
     def _construct_default_mapping(
         self,
-        pretrained_module,
+        pretrained_module: torch.nn.Module,
         source: str = "huggingface",
         mapping: Optional[Dict[str, str]] = None,
     ):
@@ -127,10 +130,10 @@ def _load_from_pretrained_module(
     def _get_input_arguments(
         cls,
         pretrained_module: torch.nn.Module,
-        source="huggingface",
+        source: str = "huggingface",
         mapping: Optional[Dict[str, str]] = None,
         **kwargs,
-    ):
+    ) -> Dict[str, Any]:
         """
         Constructs the arguments required for instantiating an object of this class, using
         the values from `pretrained_module`.
@@ -142,7 +145,7 @@ def get_relevant_module(
         cls,
         pretrained_module: Union[str, torch.nn.Module],
         relevant_module: Optional[Union[str, List[str]]] = None,
-        source="huggingface",
+        source: str = "huggingface",
         mapping: Optional[Dict[str, str]] = None,
     ):
         """
@@ -187,7 +190,7 @@ def get_relevant_module(
     def from_pretrained_module(
         cls,
        pretrained_module: Union[str, torch.nn.Module],
-        source="huggingface",
+        source: str = "huggingface",
         mapping: Optional[Dict[str, str]] = None,
         **kwargs,
     ):
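
Besides the type annotations, the one behavioral change here is that `source` must now equal "huggingface" exactly rather than merely contain it. A hedged sketch of how the `source` and `mapping` parameters are meant to be used, going through a concrete subclass (`SelfAttention` and the model name come from elsewhere in the library; the extra `mapping` entry is purely illustrative):

    from allennlp.modules.transformer import SelfAttention

    # `source` must now match "huggingface" exactly for the default HF name mapping to apply;
    # `mapping` entries (hypothetical here) are merged on top of the class's _huggingface_mapping.
    module = SelfAttention.from_pretrained_module(
        "bert-base-cased",
        source="huggingface",
        mapping={"layer": "layers"},  # hypothetical override
    )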

allennlp/modules/transformer/util.py

Lines changed: 84 additions & 8 deletions
@@ -1,6 +1,8 @@
-from typing import Union
+from typing import Union, Tuple
 import torch
 
+from allennlp.nn.util import min_value_of_dtype
+
 
 def apply_mask(
     values: torch.FloatTensor, mask: Union[torch.BoolTensor, torch.IntTensor, torch.FloatTensor]
@@ -13,13 +15,87 @@ def apply_mask(
     mask : `torch.BoolTensor`
         Shape `batch_size x target_seq_len` OR `batch_size x 1 x 1 x target_seq_len`
     """
-    if len(mask.shape) == 2:
-        # We create a 4D attention mask from a 2D tensor mask.
+    # We create a 4D attention mask from a 2D or 3D tensor mask.
+    if mask.dim() == 2:
         # The shape is `batch_size x 1 x 1 x target_seq_len` which is broadcast
         # to `batch_size x num_attention_heads x source_seq_len x target_seq_len`
-        mask = mask.unsqueeze(1).unsqueeze(2)
-    # `mask==1` to convert float tensors.
-    mask = (
-        ~(mask == 1)
-    ) * -10e5  # -10e5 to ensure that the model also works in half-precision mode.
+        mask = mask[:, None, None, :]
+    elif mask.dim() == 3:
+        mask = mask[:, None, :, :]
+    mask = mask.to(values.dtype)
+    mask = (1.0 - mask) * min_value_of_dtype(values.dtype)
     return values + mask
+
+
+def get_extended_attention_mask(
+    attention_mask: torch.Tensor,
+    input_shape: Tuple[int, ...],
+    dtype: torch.dtype,
+    is_decoder: bool = False,
+) -> torch.Tensor:
+    """
+    Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
+
+    # Parameters
+
+    attention_mask : `torch.Tensor`
+        Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
+    input_shape : `Tuple[int, ...]`
+        The shape of the input to the model.
+    dtype : `torch.dtype`
+        The datatype of the resulting mask.
+    is_decoder : `bool`, optional (default = `False`)
+        If this is for a decoder stack.
+
+    # Returns
+
+    `torch.Tensor`
+        The extended attention mask, with the same dtype as `attention_mask.dtype`.
+    """
+    # Adapted from https://github.com/huggingface/transformers/blob/
+    # 4c32f9f26e6a84f0d9843fec8757e6ce640bb44e/src/transformers/modeling_utils.py#L221.
+
+    # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+    # ourselves in which case we just need to make it broadcastable to all heads.
+    if attention_mask.dim() == 3:
+        extended_attention_mask = attention_mask[:, None, :, :]
+    elif attention_mask.dim() == 2:
+        # Provided a padding mask of dimensions [batch_size, seq_length]
+        # - if the model is a decoder, apply a causal mask in addition to the padding mask
+        # - if the model is an encoder, make the mask broadcastable to
+        #   `(batch_size, num_heads, seq_length, seq_length)`
+        if is_decoder:
+            batch_size, seq_length = input_shape
+            seq_ids = torch.arange(seq_length, device=attention_mask.device)
+            causal_mask = (
+                seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
+            )
+            # in case past_key_values are used we need to add a prefix ones mask to the causal mask
+            # causal and attention masks must have same type with pytorch version < 1.3
+            causal_mask = causal_mask.to(attention_mask.dtype)
+
+            if causal_mask.shape[1] < attention_mask.shape[1]:
+                prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
+                causal_mask = torch.cat(
+                    [
+                        torch.ones(
+                            (batch_size, seq_length, prefix_seq_len),
+                            device=attention_mask.device,
+                            dtype=causal_mask.dtype,
+                        ),
+                        causal_mask,
+                    ],
+                    axis=-1,
+                )
+
+            extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
+        else:
+            extended_attention_mask = attention_mask[:, None, None, :]
+    else:
+        raise ValueError(
+            "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
+                input_shape, attention_mask.shape
+            )
+        )
+
+    return extended_attention_mask
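
In short, `apply_mask` now accepts 3D masks and builds the additive mask from `min_value_of_dtype` instead of the hard-coded `-10e5`, and `get_extended_attention_mask` ports the Hugging Face helper that combines padding and causal masks. A small usage sketch of the two functions as defined above (shapes chosen arbitrarily):

    import torch
    from allennlp.modules.transformer.util import apply_mask, get_extended_attention_mask

    scores = torch.randn(2, 4, 5, 5)                         # batch x heads x query_len x key_len
    mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]])  # 1 = attend, 0 = ignore
    masked_scores = apply_mask(scores, mask)                 # masked slots get the dtype's minimum value

    # For a decoder, the padding mask is combined with a lower-triangular causal mask.
    extended = get_extended_attention_mask(mask, (2, 5), torch.float32, is_decoder=True)
    print(extended.shape)  # torch.Size([2, 1, 5, 5])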

allennlp/nn/activations.py

Lines changed: 24 additions & 1 deletion
@@ -5,7 +5,7 @@
 [PyTorch activations](https://pytorch.org/docs/master/nn.html#non-linear-activations).
 Here we provide a thin wrapper to allow registering them and instantiating them `from_params`.
 
-The available activation functions are
+The available activation functions include
 
 * "linear"
 * ["mish"](https://arxiv.org/abs/1908.08681)
@@ -27,6 +27,8 @@
 * ["selu"](https://pytorch.org/docs/master/nn.html#torch.nn.SELU)
 """
 
+import math
+
 import torch
 
 from allennlp.common import Registrable
@@ -86,3 +88,24 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 class SwishActivation(Activation):
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         return x * torch.sigmoid(x)
+
+
+@Activation.register("gelu_new")
+class GeluNew(Activation):
+    """
+    Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also
+    see the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
+    """
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return (
+            0.5
+            * x
+            * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
+        )
+
+
+@Activation.register("gelu_fast")
+class GeluFast(Activation):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 * (1.0 + 0.044715 * x * x)))
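
Both additions are tanh-based GELU approximations; "gelu_fast" just folds sqrt(2 / pi) ≈ 0.7978845608 into a constant. A quick sketch of using them through the registry, assuming `Activation.by_name` works here as it does for the other registered activations:

    import torch
    from allennlp.nn.activations import Activation

    gelu_new = Activation.by_name("gelu_new")()
    gelu_fast = Activation.by_name("gelu_fast")()

    x = torch.randn(3)
    print(gelu_new(x), gelu_fast(x))  # the two approximations agree to within rounding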

tests/modules/transformer/self_attention_test.py

Lines changed: 3 additions & 4 deletions
@@ -4,10 +4,9 @@
 
 from allennlp.common import Params
 from allennlp.common import cached_transformers
-from allennlp.common.testing import assert_equal_parameters
-
+from allennlp.common.testing import assert_equal_parameters, AllenNlpTestCase
 from allennlp.modules.transformer import SelfAttention
-from allennlp.common.testing import AllenNlpTestCase
+from allennlp.nn.util import min_value_of_dtype
 
 from transformers.models.bert.configuration_bert import BertConfig
 from transformers.models.bert.modeling_bert import BertSelfAttention
@@ -160,7 +159,7 @@ def test_loading_from_pretrained_weights_using_model_name(self, pretrained_name)
             )[0]
         else:
             # The attn_mask is processed outside the self attention module in HF bert models.
-            attention_mask = (~(attention_mask == 1)) * -10e5
+            attention_mask = (~(attention_mask == 1)) * min_value_of_dtype(hidden_states.dtype)
         torch.manual_seed(1234)
         hf_output = pretrained_module.forward(hidden_states, attention_mask=attention_mask)[0]
 
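
The test mirrors the library change: `-10e5` is not representable in float16, which is precisely why `min_value_of_dtype` was introduced. A tiny illustration:

    import torch
    from allennlp.nn.util import min_value_of_dtype

    print(torch.tensor(-10e5, dtype=torch.float16))  # tensor(-inf, dtype=torch.float16); overflows fp16
    print(min_value_of_dtype(torch.float16))         # about -65504.0, still representable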
