This repository was archived by the owner on Dec 16, 2022. It is now read-only.

Changes and improvements to how we initialize transformer modules from pretrained models #5200

Merged · 27 commits · May 17, 2021
2 changes: 2 additions & 0 deletions .github/workflows/ci.yml
@@ -152,6 +152,8 @@ jobs:
run: |
git clone https://github.com/allenai/allennlp-models.git
cd allennlp-models
# TODO: remove this
git checkout transformer-init
pip install --upgrade --upgrade-strategy eager -e . -r dev-requirements.txt

- name: Run models tests
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -11,13 +11,17 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Use `dist_reduce_sum` in distributed metrics.
- Allow Google Cloud Storage paths in `cached_path` ("gs://...").
- Renamed `nn.util.load_state_dict()` to `read_state_dict` to avoid confusion with `torch.nn.Module.load_state_dict()`.
- `TransformerModule.from_pretrained_module` now only accepts a pretrained model ID (e.g. "bert-base-cased") instead of
an actual `torch.nn.Module`. Other parameters to this method have changed as well.
- Print the first batch to the console by default.
- Renamed `sanity_checks` to `confidence_checks` (`sanity_checks` is deprecated and will be removed in AllenNLP 3.0).

### Added

- Added `TaskSuite` base class and command line functionality for running [`checklist`](https://github.com/marcotcr/checklist) test suites, along with implementations for `SentimentAnalysisSuite`, `QuestionAnsweringSuite`, and `TextualEntailmentSuite`. These can be found in the `allennlp.confidence_checks.task_checklists` module.
- Added `allennlp diff` command to compute a diff on model checkpoints, analogous to what `git diff` does on two files.
- Added `nn.util.distributed_device()` helper function.
- Added `allennlp.nn.util.load_state_dict` helper function.
- Added a way to avoid downloading and loading pretrained weights in modules that wrap transformers
such as the `PretrainedTransformerEmbedder` and `PretrainedTransformerMismatchedEmbedder`.
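As a quick orientation for the `from_pretrained_module` change above, here is a minimal sketch of the new-style call, mirroring how this PR uses it in `vilbert_backbone.py` (the model ID is illustrative):

```python
from allennlp.modules.transformer import TransformerEmbeddings

# New API: pass a pretrained model ID (a string) rather than an already
# instantiated torch.nn.Module; the relevant weights are read from that
# named checkpoint.
embeddings = TransformerEmbeddings.from_pretrained_module("bert-base-cased")
```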
6 changes: 3 additions & 3 deletions allennlp/commands/diff.py
@@ -19,7 +19,7 @@

from allennlp.commands.subcommand import Subcommand
from allennlp.common.file_utils import cached_path
from allennlp.nn.util import load_state_dict
from allennlp.nn.util import read_state_dict


logger = logging.getLogger(__name__)
@@ -249,10 +249,10 @@ def _get_checkpoint_path(checkpoint: str) -> str:
def _diff(args: argparse.Namespace):
checkpoint_1_path = _get_checkpoint_path(args.checkpoint1)
checkpoint_2_path = _get_checkpoint_path(args.checkpoint2)
checkpoint_1 = load_state_dict(
checkpoint_1 = read_state_dict(
checkpoint_1_path, strip_prefix=args.strip_prefix_1, strict=False
)
checkpoint_2 = load_state_dict(
checkpoint_2 = read_state_dict(
checkpoint_2_path, strip_prefix=args.strip_prefix_2, strict=False
)
for step in checkpoint_diff(checkpoint_1, checkpoint_2, args.scale, args.threshold):
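For context on the rename, a small usage sketch of the helper as it appears in this diff; the path and prefix are made up:

```python
from allennlp.nn.util import read_state_dict

# Formerly nn.util.load_state_dict(); renamed to avoid confusion with
# torch.nn.Module.load_state_dict(). Reads a checkpoint file into a state
# dict, optionally stripping a key prefix.
state = read_state_dict("/path/to/weights.th", strip_prefix="model.", strict=False)
```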
9 changes: 8 additions & 1 deletion allennlp/common/testing/distributed_test.py
@@ -61,12 +61,19 @@ def run_distributed_test(

func: `Callable`
`func` needs to be global for spawning the processes, so that it can be pickled.

start_method: `Optional[str]`, optional (default = `None`)
The start method to use for starting the workers. Defaults to "spawn" for GPU
processes and "fork" otherwise.
"""
device_ids = device_ids or [-1, -1]
check_for_gpu(device_ids)
# "fork" start method is the default and should be preferred, except when we're
# running the tests on GPU, in which case we need to use "spawn".
start_method = "spawn" if any(x >= 0 for x in device_ids) else "fork"
if "start_method" in kwargs:
start_method = kwargs.pop("start_method")
else:
start_method = "spawn" if any(x >= 0 for x in device_ids) else "fork"
Review comment from the PR author on lines +73 to +76:

This is kind of hacky, but I had an issue on Mac where the only work-around was to use "spawn" as the start method, and I didn't want to change the signature of this function because it's used in a lot of places.

nprocs = world_size = len(device_ids)
mp.start_processes(
init_process,
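A hedged sketch of the new `start_method` keyword in use; the callback signature shown here is assumed from typical `run_distributed_test` usage, not quoted from this diff:

```python
from allennlp.common.testing.distributed_test import run_distributed_test

def _assert_rank_in_range(global_rank: int, world_size: int, gpu_id: int):
    # Hypothetical test body: each worker just checks its own rank.
    assert 0 <= global_rank < world_size

# Force "spawn" explicitly (e.g. as a macOS work-around). Without the keyword,
# the helper picks "spawn" when any device ID is a GPU and "fork" otherwise.
run_distributed_test([-1, -1], func=_assert_rank_in_range, start_method="spawn")
```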
12 changes: 12 additions & 0 deletions allennlp/common/util.py
@@ -509,6 +509,18 @@ def is_distributed() -> bool:
return dist.is_available() and dist.is_initialized()


def is_global_primary() -> bool:
"""
Checks if the distributed process group is the global primary (rank = 0).
If the distributed process group is not available or has not been initialized,
this trivially returns `True`.
"""
if not is_distributed():
return True
else:
return dist.get_rank() == 0


def sanitize_wordpiece(wordpiece: str) -> str:
"""
Sanitizes wordpieces from BERT, RoBERTa or ALBERT tokenizers.
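And a tiny usage sketch of the new helper; the guarded action is illustrative:

```python
from allennlp.common.util import is_global_primary

# In a distributed run only the rank-0 worker should do one-off work such as
# downloading weights; in a non-distributed run this is trivially True.
if is_global_primary():
    print("primary process: safe to download or write shared artifacts here")
```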
2 changes: 1 addition & 1 deletion allennlp/models/model.py
@@ -335,7 +335,7 @@ def _load(

# Load state dict. We pass `strict=False` so PyTorch doesn't raise a RuntimeError
# if the state dict is missing keys because we handle this case below.
model_state = util.load_state_dict(weights_file, cuda_device=cuda_device)
model_state = util.read_state_dict(weights_file, cuda_device=cuda_device)
missing_keys, unexpected_keys = model.load_state_dict(model_state, strict=False)

# Modules might define a class variable called `authorized_missing_keys`,
52 changes: 10 additions & 42 deletions allennlp/modules/backbones/vilbert_backbone.py
@@ -7,7 +7,12 @@
from allennlp.data.fields.text_field import TextFieldTensors
from allennlp.data.vocabulary import Vocabulary
from allennlp.modules.backbones.backbone import Backbone
from allennlp.modules.transformer import BiModalEncoder, ImageFeatureEmbeddings, Embeddings
from allennlp.modules.transformer import (
BiModalEncoder,
ImageFeatureEmbeddings,
TransformerEmbeddings,
TransformerPooler,
)

logger = logging.getLogger(__name__)

@@ -23,7 +28,7 @@ class VilbertBackbone(Backbone):
def __init__(
self,
vocab: Vocabulary,
text_embeddings: Embeddings,
text_embeddings: TransformerEmbeddings,
image_embeddings: ImageFeatureEmbeddings,
encoder: BiModalEncoder,
pooled_output_dim: int,
@@ -36,7 +41,6 @@ def __init__(
self.text_embeddings = text_embeddings
self.image_embeddings = image_embeddings
self.encoder = encoder
from allennlp.modules.transformer import TransformerPooler

self.t_pooler = TransformerPooler(encoder.hidden_size1, pooled_output_dim)
self.v_pooler = TransformerPooler(encoder.hidden_size2, pooled_output_dim)
@@ -66,44 +70,7 @@ def from_huggingface_model_name(
image_fixed_layer: int,
fusion_method: str = "sum",
):
from transformers import AutoModel

transformer = AutoModel.from_pretrained(model_name)

from copy import deepcopy

text_embeddings = deepcopy(transformer.embeddings)

# Albert (and maybe others?) has this "embedding_size", that's different from "hidden_size".
# To get them to the same dimensionality, it uses a linear transform after the embedding
# layer, which we need to pull out and copy here.
if hasattr(transformer.config, "embedding_size"):
config = transformer.config

from transformers.models.albert.modeling_albert import AlbertModel

if isinstance(transformer, AlbertModel):
linear_transform = deepcopy(transformer.encoder.embedding_hidden_mapping_in)
else:
logger.warning(
"Unknown model that uses separate embedding size; weights of the linear "
f"transform will not be initialized. Model type is: {transformer.__class__}"
)
linear_transform = torch.nn.Linear(config.embedding_dim, config.hidden_dim)

# We can't just use torch.nn.Sequential here, even though that's basically all this is,
# because Sequential doesn't accept *inputs, only a single argument.

class EmbeddingsShim(torch.nn.Module):
def __init__(self, embeddings: torch.nn.Module, linear_transform: torch.nn.Module):
super().__init__()
self.linear_transform = linear_transform
self.embeddings = embeddings

def forward(self, *inputs, **kwargs):
return self.linear_transform(self.embeddings(*inputs, **kwargs))

text_embeddings = EmbeddingsShim(text_embeddings, linear_transform)
text_embeddings = TransformerEmbeddings.from_pretrained_module(model_name)

image_embeddings = ImageFeatureEmbeddings(
feature_size=image_feature_dim,
@@ -112,7 +79,7 @@ def forward(self, *inputs, **kwargs):
)

encoder = BiModalEncoder.from_pretrained_module(
pretrained_module=transformer,
model_name,
num_hidden_layers2=image_num_hidden_layers,
hidden_size2=image_hidden_size,
num_attention_heads2=image_num_attention_heads,
@@ -126,6 +93,7 @@ def forward(self, *inputs, **kwargs):
fixed_layer1=text_fixed_layer,
fixed_layer2=image_fixed_layer,
)

return cls(
vocab=vocab,
text_embeddings=text_embeddings,
7 changes: 5 additions & 2 deletions allennlp/modules/transformer/__init__.py
@@ -123,9 +123,12 @@ def forward(self, token_ids: torch.LongTensor, mask: torch.BoolTensor):
```
"""

from allennlp.modules.transformer.layer_norm import LayerNorm
from allennlp.modules.transformer.positional_encoding import SinusoidalPositionalEncoding

from allennlp.modules.transformer.transformer_module import TransformerModule
from allennlp.modules.transformer.transformer_module import (
TransformerModule,
DistributedLoadingStrategy,
)
from allennlp.modules.transformer.transformer_embeddings import (
Embeddings,
TransformerEmbeddings,
5 changes: 3 additions & 2 deletions allennlp/modules/transformer/bimodal_attention.py
@@ -118,10 +118,12 @@ def forward(
input_tensor2,
attention_mask1=None,
attention_mask2=None,
co_attention_mask=None,
co_attention_mask=None, # TODO: is this flag necessary?
use_co_attention_mask=False,
):
"""
# Parameters

input_tensor1 : `torch.Tensor`
Shape `batch_size x seq_len1 x hidden_dim1`
where `seq_len1` can be the sequence length
@@ -143,7 +145,6 @@
if you know which words correspond to which regions in the image,
this mask can be applied to limit the attention given the bias.
use_co_attention_mask : `bool`
# TODO: is this flag necessary?
Whether to use co_attention_mask or not, default = `False`.
"""

2 changes: 1 addition & 1 deletion allennlp/modules/transformer/bimodal_connection_layer.py
@@ -31,7 +31,7 @@ def forward(self, hidden_states1, input_tensor1, hidden_states2, input_tensor2):

class BiModalConnectionLayer(TransformerModule, FromParams):

_huggingface_mapping = {"biAttention": "bimodal_attention", "biOutput": "bimodal_output"}
_pretrained_mapping = {"biAttention": "bimodal_attention", "biOutput": "bimodal_output"}

def __init__(
self,
110 changes: 17 additions & 93 deletions allennlp/modules/transformer/bimodal_encoder.py
@@ -1,14 +1,16 @@
from typing import Optional, Dict, List, Union
from typing import Optional, List, TYPE_CHECKING

import torch

from allennlp.common import FromParams

from allennlp.modules.util import replicate_layers

from allennlp.modules.transformer.transformer_layer import TransformerLayer
from allennlp.modules.transformer.bimodal_connection_layer import BiModalConnectionLayer
from allennlp.modules.transformer.transformer_module import TransformerModule

if TYPE_CHECKING:
from transformers.configuration_utils import PretrainedConfig


class BiModalEncoder(TransformerModule, FromParams):
"""
@@ -46,8 +48,9 @@ class BiModalEncoder(TransformerModule, FromParams):
in_batch_pairs: `bool` (default = `False`)
"""

_huggingface_mapping = {"layer": "layers1"}
_relevant_module = "encoder"
_pretrained_mapping = {"layer": "layers1"}
_pretrained_relevant_module = ["encoder", "bert.encoder"]
_pretrained_allow_missing = [r"^layers2\..*", r"^c_layer\..*"]

def __init__(
self,
@@ -243,93 +246,14 @@ def forward(
)

@classmethod
def _get_input_arguments(
cls,
pretrained_module: torch.nn.Module,
source="huggingface",
mapping: Optional[Dict[str, str]] = None,
**kwargs,
):
"""
The `pretrained_module` only supplies one of the modalities.
"""
submodules = cls._get_mapped_submodules(pretrained_module, source, mapping)

def _from_config(cls, config: "PretrainedConfig", **kwargs):
final_kwargs = {}

final_kwargs["num_hidden_layers1"] = len(submodules["layers1"])

final_kwargs["hidden_size1"] = submodules["layers1.0.attention.self.query"].in_features
final_kwargs["num_attention_heads1"] = submodules[
"layers1.0.attention.self"
].num_attention_heads
final_kwargs["attention_dropout1"] = submodules["layers1.0.attention.self.dropout"].p
final_kwargs["hidden_dropout1"] = submodules["layers1.0.attention.output.dropout"].p
final_kwargs["intermediate_size1"] = submodules["layers1.0.intermediate.dense"].out_features
final_kwargs["activation"] = submodules["layers1.0.intermediate"].intermediate_act_fn

final_kwargs["num_hidden_layers1"] = config.num_hidden_layers
final_kwargs["hidden_size1"] = config.hidden_size
final_kwargs["num_attention_heads1"] = config.num_attention_heads
final_kwargs["attention_dropout1"] = config.attention_probs_dropout_prob
final_kwargs["hidden_dropout1"] = config.hidden_dropout_prob
final_kwargs["intermediate_size1"] = config.intermediate_size
final_kwargs["activation"] = config.hidden_act
final_kwargs.update(**kwargs)

return final_kwargs

def _load_from_pretrained_module(
self,
pretrained_module: torch.nn.Module,
source="huggingface",
mapping: Optional[Dict[str, str]] = None,
ignore_absent_parameters: Optional[List] = None,
):
if source == "huggingface":
ignore_absent_parameters = ["layers2", "c_layer"]
super()._load_from_pretrained_module(
pretrained_module, source, mapping, ignore_absent_parameters
)

@classmethod
def from_pretrained_module( # type: ignore
cls,
pretrained_module: Union[str, torch.nn.Module],
num_hidden_layers2: int,
hidden_size2: int,
combined_hidden_size: int,
intermediate_size2: int,
num_attention_heads2: int,
combined_num_attention_heads: int,
attention_dropout2: float,
hidden_dropout2: float,
biattention_id1: List[int],
biattention_id2: List[int],
fixed_layer1: int,
fixed_layer2: int,
fast_mode: bool = False,
with_coattention: bool = True,
in_batch_pairs: bool = False,
source="huggingface",
mapping: Optional[Dict[str, str]] = None,
# **kwargs,
):
"""
The `pretrained_module` only supplies one of the modalities.
"""
pretrained_module = cls.get_relevant_module(
pretrained_module, source=source, mapping=mapping
)
final_kwargs = {}
final_kwargs.update(cls._get_input_arguments(pretrained_module, source, mapping))
final_kwargs["num_hidden_layers2"] = num_hidden_layers2
final_kwargs["hidden_size2"] = hidden_size2
final_kwargs["combined_hidden_size"] = combined_hidden_size
final_kwargs["intermediate_size2"] = intermediate_size2
final_kwargs["num_attention_heads2"] = num_attention_heads2
final_kwargs["combined_num_attention_heads"] = combined_num_attention_heads
final_kwargs["attention_dropout2"] = attention_dropout2
final_kwargs["hidden_dropout2"] = hidden_dropout2
final_kwargs["biattention_id1"] = biattention_id1
final_kwargs["biattention_id2"] = biattention_id2
final_kwargs["fixed_layer1"] = fixed_layer1
final_kwargs["fixed_layer2"] = fixed_layer2
final_kwargs["fast_mode"] = fast_mode
final_kwargs["with_coattention"] = with_coattention
final_kwargs["in_batch_pairs"] = in_batch_pairs

return super().from_pretrained_module(pretrained_module, source, mapping, **final_kwargs)
return cls(**final_kwargs)
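For readers skimming the new class-level hooks above, here is a hedged sketch of how they read on a `TransformerModule` subclass; the semantics in the comments are inferred from this diff rather than quoted from documentation:

```python
from allennlp.modules.transformer.transformer_module import TransformerModule

class MyBiModalEncoder(TransformerModule):
    # Rename checkpoint keys on load ("layer.*" in the checkpoint -> "layers1.*" here).
    _pretrained_mapping = {"layer": "layers1"}
    # Candidate submodule paths within the pretrained model to pull weights from.
    _pretrained_relevant_module = ["encoder", "bert.encoder"]
    # Regexes for parameters that may be absent from the pretrained checkpoint
    # (the second modality and co-attention layers have no pretrained weights).
    _pretrained_allow_missing = [r"^layers2\..*", r"^c_layer\..*"]
```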
7 changes: 7 additions & 0 deletions allennlp/modules/transformer/layer_norm.py
@@ -0,0 +1,7 @@
import torch

from allennlp.modules.transformer.transformer_module import TransformerModule


class LayerNorm(torch.nn.LayerNorm, TransformerModule):
_pretrained_mapping = {"gamma": "weight", "beta": "bias"}