From 64b24a1b0b095f1c9fc16f2967cc93f32e043709 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Tue, 11 May 2021 09:06:12 -0700 Subject: [PATCH 01/23] updates --- allennlp/common/util.py | 12 + allennlp/modules/transformer/__init__.py | 5 +- .../modules/transformer/bimodal_encoder.py | 105 +--- .../modules/transformer/self_attention.py | 64 +-- allennlp/modules/transformer/t5.py | 34 +- .../transformer/transformer_embeddings.py | 40 +- .../modules/transformer/transformer_layer.py | 64 +-- .../modules/transformer/transformer_module.py | 529 ++++++++++++------ .../modules/transformer/transformer_stack.py | 77 +-- allennlp/nn/util.py | 13 +- 10 files changed, 483 insertions(+), 460 deletions(-) diff --git a/allennlp/common/util.py b/allennlp/common/util.py index db77d795e8d..4db2ef6b5fe 100644 --- a/allennlp/common/util.py +++ b/allennlp/common/util.py @@ -509,6 +509,18 @@ def is_distributed() -> bool: return dist.is_available() and dist.is_initialized() +def is_global_primary() -> bool: + """ + Checks if the distributed process group is the global primary (rank = 0). + If the distributed process group is not available or has not been initialized, + this trivially returns `True`. + """ + if not is_distributed(): + return True + else: + return dist.get_rank() == 0 + + def sanitize_wordpiece(wordpiece: str) -> str: """ Sanitizes wordpieces from BERT, RoBERTa or ALBERT tokenizers. diff --git a/allennlp/modules/transformer/__init__.py b/allennlp/modules/transformer/__init__.py index b0b56b90d17..a5a64e45b84 100644 --- a/allennlp/modules/transformer/__init__.py +++ b/allennlp/modules/transformer/__init__.py @@ -125,7 +125,10 @@ def forward(self, token_ids: torch.LongTensor, mask: torch.BoolTensor): from allennlp.modules.transformer.positional_encoding import SinusoidalPositionalEncoding -from allennlp.modules.transformer.transformer_module import TransformerModule +from allennlp.modules.transformer.transformer_module import ( + TransformerModule, + DistributedLoadingStrategy, +) from allennlp.modules.transformer.transformer_embeddings import ( Embeddings, TransformerEmbeddings, diff --git a/allennlp/modules/transformer/bimodal_encoder.py b/allennlp/modules/transformer/bimodal_encoder.py index bf5e732e96d..4cff9e1b3a3 100644 --- a/allennlp/modules/transformer/bimodal_encoder.py +++ b/allennlp/modules/transformer/bimodal_encoder.py @@ -1,14 +1,16 @@ -from typing import Optional, Dict, List, Union +from typing import Optional, List, TYPE_CHECKING + import torch from allennlp.common import FromParams - from allennlp.modules.util import replicate_layers - from allennlp.modules.transformer.transformer_layer import TransformerLayer from allennlp.modules.transformer.bimodal_connection_layer import BiModalConnectionLayer from allennlp.modules.transformer.transformer_module import TransformerModule +if TYPE_CHECKING: + from transformers.configuration_utils import PretrainedConfig + class BiModalEncoder(TransformerModule, FromParams): """ @@ -243,93 +245,14 @@ def forward( ) @classmethod - def _get_input_arguments( - cls, - pretrained_module: torch.nn.Module, - source="huggingface", - mapping: Optional[Dict[str, str]] = None, - **kwargs, - ): - """ - The `pretrained_module` only supplies one of the modalities. - """ - submodules = cls._get_mapped_submodules(pretrained_module, source, mapping) - + def _from_config(cls, config: "PretrainedConfig", **kwargs): final_kwargs = {} - - final_kwargs["num_hidden_layers1"] = len(submodules["layers1"]) - - final_kwargs["hidden_size1"] = submodules["layers1.0.attention.self.query"].in_features - final_kwargs["num_attention_heads1"] = submodules[ - "layers1.0.attention.self" - ].num_attention_heads - final_kwargs["attention_dropout1"] = submodules["layers1.0.attention.self.dropout"].p - final_kwargs["hidden_dropout1"] = submodules["layers1.0.attention.output.dropout"].p - final_kwargs["intermediate_size1"] = submodules["layers1.0.intermediate.dense"].out_features - final_kwargs["activation"] = submodules["layers1.0.intermediate"].intermediate_act_fn - + final_kwargs["num_hidden_layers1"] = config.num_hidden_layers + final_kwargs["hidden_size1"] = config.hidden_size + final_kwargs["num_attention_heads1"] = config.num_attention_heads + final_kwargs["attention_dropout1"] = config.attention_probs_dropout_prob + final_kwargs["hidden_dropout1"] = config.hidden_dropout_prob + final_kwargs["intermediate_size1"] = config.intermediate_size + final_kwargs["activation"] = config.hidden_act final_kwargs.update(**kwargs) - - return final_kwargs - - def _load_from_pretrained_module( - self, - pretrained_module: torch.nn.Module, - source="huggingface", - mapping: Optional[Dict[str, str]] = None, - ignore_absent_parameters: Optional[List] = None, - ): - if source == "huggingface": - ignore_absent_parameters = ["layers2", "c_layer"] - super()._load_from_pretrained_module( - pretrained_module, source, mapping, ignore_absent_parameters - ) - - @classmethod - def from_pretrained_module( # type: ignore - cls, - pretrained_module: Union[str, torch.nn.Module], - num_hidden_layers2: int, - hidden_size2: int, - combined_hidden_size: int, - intermediate_size2: int, - num_attention_heads2: int, - combined_num_attention_heads: int, - attention_dropout2: float, - hidden_dropout2: float, - biattention_id1: List[int], - biattention_id2: List[int], - fixed_layer1: int, - fixed_layer2: int, - fast_mode: bool = False, - with_coattention: bool = True, - in_batch_pairs: bool = False, - source="huggingface", - mapping: Optional[Dict[str, str]] = None, - # **kwargs, - ): - """ - The `pretrained_module` only supplies one of the modalities. - """ - pretrained_module = cls.get_relevant_module( - pretrained_module, source=source, mapping=mapping - ) - final_kwargs = {} - final_kwargs.update(cls._get_input_arguments(pretrained_module, source, mapping)) - final_kwargs["num_hidden_layers2"] = num_hidden_layers2 - final_kwargs["hidden_size2"] = hidden_size2 - final_kwargs["combined_hidden_size"] = combined_hidden_size - final_kwargs["intermediate_size2"] = intermediate_size2 - final_kwargs["num_attention_heads2"] = num_attention_heads2 - final_kwargs["combined_num_attention_heads"] = combined_num_attention_heads - final_kwargs["attention_dropout2"] = attention_dropout2 - final_kwargs["hidden_dropout2"] = hidden_dropout2 - final_kwargs["biattention_id1"] = biattention_id1 - final_kwargs["biattention_id2"] = biattention_id2 - final_kwargs["fixed_layer1"] = fixed_layer1 - final_kwargs["fixed_layer2"] = fixed_layer2 - final_kwargs["fast_mode"] = fast_mode - final_kwargs["with_coattention"] = with_coattention - final_kwargs["in_batch_pairs"] = in_batch_pairs - - return super().from_pretrained_module(pretrained_module, source, mapping, **final_kwargs) + return cls(**final_kwargs) diff --git a/allennlp/modules/transformer/self_attention.py b/allennlp/modules/transformer/self_attention.py index 6db6aba1fad..8be8c60c6fa 100644 --- a/allennlp/modules/transformer/self_attention.py +++ b/allennlp/modules/transformer/self_attention.py @@ -1,4 +1,5 @@ -from typing import Optional, Dict +from typing import Optional, TYPE_CHECKING + import torch from allennlp.common import FromParams @@ -6,6 +7,9 @@ from allennlp.modules.transformer.transformer_module import TransformerModule from allennlp.modules.transformer.util import apply_mask +if TYPE_CHECKING: + from transformers.configuration_utils import PretrainedConfig + class SelfAttention(TransformerModule, FromParams): """ @@ -26,7 +30,14 @@ class SelfAttention(TransformerModule, FromParams): """ _relevant_module = ["encoder.layers.0.attention.self", "encoder.layers.0.attention"] - _huggingface_mapping = {"layer": "layers"} + _huggingface_mapping = { + "layer": "layers", + "q_lin": "query", + "k_lin": "key", + "v_lin": "value", + "out_lin": "output", + "transformer": "encoder", + } def __init__( self, @@ -133,47 +144,16 @@ def forward( return outputs @classmethod - def _get_mapping( - cls, pretrained_module=None, source="huggingface", mapping: Optional[Dict[str, str]] = None - ): - combined_mapping = {} - if "huggingface" in source: - combined_mapping.update(cls._huggingface_mapping) - if mapping is not None: - combined_mapping.update(mapping) - if pretrained_module is not None: - for name, _ in pretrained_module.named_modules(): - if "q_lin" in name: - combined_mapping["q_lin"] = "query" - combined_mapping["k_lin"] = "key" - combined_mapping["v_lin"] = "value" - combined_mapping["out_lin"] = "output" - combined_mapping["transformer"] = "encoder" - break - return combined_mapping - - @classmethod - def _get_input_arguments( - cls, - pretrained_module: torch.nn.Module, - source="huggingface", - mapping: Optional[Dict[str, str]] = None, - **kwargs, - ): - submodules = cls._get_mapped_submodules(pretrained_module, source, mapping) + def _from_config(cls, config: "PretrainedConfig", **kwargs): final_kwargs = {} - - final_kwargs["hidden_size"] = submodules["query"].in_features - if hasattr(submodules[""], "num_attention_heads"): - final_kwargs["num_attention_heads"] = submodules[""].num_attention_heads - elif hasattr(submodules[""], "n_heads"): - final_kwargs["num_attention_heads"] = submodules[""].n_heads - final_kwargs["output_linear"] = True # Since this is the distilbert case. + final_kwargs["hidden_size"] = config.hidden_size + final_kwargs["num_attention_heads"] = config.num_attention_heads + final_kwargs["output_linear"] = hasattr( + config, "n_heads" + ) # Since this is the distilbert case. + if hasattr(config, "attention_dropout"): + final_kwargs["dropout"] = config.attention_dropout else: - raise AttributeError("Cannot find a relevant attribute for number of heads.") - - final_kwargs["dropout"] = submodules["dropout"].p - + final_kwargs["dropout"] = config.attention_probs_dropout_prob final_kwargs.update(**kwargs) - return final_kwargs diff --git a/allennlp/modules/transformer/t5.py b/allennlp/modules/transformer/t5.py index 83305487b76..1772fb5b217 100644 --- a/allennlp/modules/transformer/t5.py +++ b/allennlp/modules/transformer/t5.py @@ -5,7 +5,7 @@ import math from dataclasses import dataclass -from typing import Optional, Tuple, List, Union, Dict, Any +from typing import Optional, Tuple, List, Union, Dict, TYPE_CHECKING import torch from torch import nn @@ -21,6 +21,9 @@ ) from allennlp.nn.beam_search import BeamSearch +if TYPE_CHECKING: + from transformers.configuration_utils import PretrainedConfig + # Unfortunately mypy is insane, so I have to wrap these in unions. FloatT = Union[torch.FloatTensor] IntT = Union[torch.IntTensor] @@ -1003,16 +1006,7 @@ def __init__( ) @classmethod - def _get_input_arguments( - cls, - pretrained_module: torch.nn.Module, - source: str = "huggingface", - mapping: Optional[Dict[str, str]] = None, - **kwargs, - ) -> Dict[str, Any]: - from transformers.models.t5 import T5Config - - config: T5Config = pretrained_module.config + def _from_config(cls, config: "PretrainedConfig", **kwargs): attention_kwargs = { "hidden_size": config.d_model, "key_value_proj_dim": config.d_kv, @@ -1039,8 +1033,8 @@ def _get_input_arguments( } ), ) - return { - "encoder": Lazy( + return cls( + encoder=Lazy( T5EncoderStack.basic_encoder, contructor_extras={ "num_blocks": config.num_layers, @@ -1050,7 +1044,7 @@ def _get_input_arguments( "dropout": config.dropout_rate, }, ), - "decoder": Lazy( + decoder=Lazy( T5DecoderStack.basic_decoder, contructor_extras={ "num_blocks": config.num_decoder_layers, @@ -1061,12 +1055,12 @@ def _get_input_arguments( "dropout": config.dropout_rate, }, ), - "decoder_start_token_id": config.decoder_start_token_id, - "pad_token_id": config.pad_token_id, - "eos_token_id": config.eos_token_id, - "vocab_size": config.vocab_size, - "model_dim": config.d_model, - } + decoder_start_token_id=config.decoder_start_token_id, + pad_token_id=config.pad_token_id, + eos_token_id=config.eos_token_id, + vocab_size=config.vocab_size, + model_dim=config.d_model, + ) def _shift_right(self, input_ids, start_value: int): # shift inputs to the right diff --git a/allennlp/modules/transformer/transformer_embeddings.py b/allennlp/modules/transformer/transformer_embeddings.py index 754344d1c0e..29ab1f02b71 100644 --- a/allennlp/modules/transformer/transformer_embeddings.py +++ b/allennlp/modules/transformer/transformer_embeddings.py @@ -1,11 +1,13 @@ -from typing import Optional, Dict +from typing import Optional, TYPE_CHECKING import torch from allennlp.common import FromParams - from allennlp.modules.transformer.transformer_module import TransformerModule +if TYPE_CHECKING: + from transformers.configuration_utils import PretrainedConfig + class Embeddings(TransformerModule, FromParams): """ @@ -182,32 +184,12 @@ def forward( # type: ignore return embeddings @classmethod - def _get_input_arguments( - cls, - pretrained_module: torch.nn.Module, - source="huggingface", - mapping: Optional[Dict[str, str]] = None, - **kwargs, - ): - submodules = cls._get_mapped_submodules(pretrained_module, source, mapping) - + def _from_config(cls, config: "PretrainedConfig", **kwargs): final_kwargs = {} - - final_kwargs["vocab_size"] = submodules["embeddings.word_embeddings"].num_embeddings - final_kwargs["embedding_size"] = submodules["embeddings.word_embeddings"].embedding_dim - final_kwargs["pad_token_id"] = submodules["embeddings.word_embeddings"].padding_idx - final_kwargs["max_position_embeddings"] = submodules[ - "embeddings.position_embeddings" - ].num_embeddings - - if "embeddings.token_type_embeddings" in submodules: - final_kwargs["type_vocab_size"] = submodules[ - "embeddings.token_type_embeddings" - ].num_embeddings - - else: - final_kwargs["type_vocab_size"] = 0 - + final_kwargs["vocab_size"] = config.vocab_size + final_kwargs["embedding_size"] = config.hidden_size + final_kwargs["pad_token_id"] = config.pad_token_id + final_kwargs["max_position_embeddings"] = config.max_position_embeddings + final_kwargs["type_vocab_size"] = config.type_vocab_size final_kwargs.update(**kwargs) - - return final_kwargs + return cls(**final_kwargs) diff --git a/allennlp/modules/transformer/transformer_layer.py b/allennlp/modules/transformer/transformer_layer.py index 3282b2dbf14..ec19cd8f910 100644 --- a/allennlp/modules/transformer/transformer_layer.py +++ b/allennlp/modules/transformer/transformer_layer.py @@ -1,15 +1,16 @@ -from typing import Union, Optional, Dict +from typing import Union, Optional, TYPE_CHECKING import torch from allennlp.common import FromParams - from allennlp.modules.transformer.transformer_module import TransformerModule - from allennlp.modules.transformer.activation_layer import ActivationLayer from allennlp.modules.transformer.self_attention import SelfAttention from allennlp.modules.transformer.output_layer import OutputLayer +if TYPE_CHECKING: + from transformers.configuration_utils import PretrainedConfig + class AttentionLayer(TransformerModule, FromParams): """ @@ -77,25 +78,16 @@ def forward( return outputs @classmethod - def _get_input_arguments( - cls, - pretrained_module: torch.nn.Module, - source="huggingface", - mapping: Optional[Dict[str, str]] = None, - **kwargs, - ): - submodules = cls._get_mapped_submodules(pretrained_module, source, mapping) - + def _from_config(cls, config: "PretrainedConfig", **kwargs): final_kwargs = {} - final_kwargs["hidden_size"] = submodules["self.query"].in_features - final_kwargs["num_attention_heads"] = submodules["self"].num_attention_heads - final_kwargs["attention_dropout"] = submodules["self.dropout"].p - final_kwargs["hidden_dropout"] = submodules["output.dropout"].p + final_kwargs["hidden_size"] = config.hidden_size + final_kwargs["num_attention_heads"] = config.num_attention_heads + final_kwargs["attention_dropout"] = config.attention_probs_dropout_drop + final_kwargs["hidden_dropout"] = config.hidden_dropout_prob final_kwargs.update(**kwargs) - - return final_kwargs + return cls(**final_kwargs) class TransformerLayer(TransformerModule, FromParams): @@ -218,32 +210,14 @@ def forward( return outputs @classmethod - def _get_input_arguments( - cls, - pretrained_module: torch.nn.Module, - source="huggingface", - mapping: Optional[Dict[str, str]] = None, - **kwargs, - ): - submodules = cls._get_mapped_submodules(pretrained_module, source, mapping) - + def _from_config(cls, config: "PretrainedConfig", **kwargs): final_kwargs = {} - - final_kwargs["hidden_size"] = submodules["attention.self.query"].in_features - final_kwargs["num_attention_heads"] = submodules["attention.self"].num_attention_heads - final_kwargs["attention_dropout"] = submodules["attention.self.dropout"].p - final_kwargs["hidden_dropout"] = submodules["attention.output.dropout"].p - final_kwargs["intermediate_size"] = submodules["intermediate.dense"].out_features - - # We require the if block as `act_fn` is a function rather than a module, - # so `_get_mapped_submodules` does not automatically fix this. - if source == "huggingface": - final_kwargs["activation"] = getattr(submodules["intermediate"], "intermediate_act_fn") - else: - final_kwargs["activation"] = getattr(submodules["intermediate"], "act_fn") - - final_kwargs["add_cross_attention"] = "cross_attention" in submodules - + final_kwargs["hidden_size"] = config.hidden_size + final_kwargs["num_attention_heads"] = config.num_attention_heads + final_kwargs["attention_dropout"] = config.attention_probs_dropout_drop + final_kwargs["hidden_dropout"] = config.hidden_dropout_prob + final_kwargs["intermediate_size"] = config.intermediate_size + final_kwargs["activation"] = config.hidden_act + final_kwargs["add_cross_attention"] = config.add_cross_attention final_kwargs.update(**kwargs) - - return final_kwargs + return cls(**final_kwargs) diff --git a/allennlp/modules/transformer/transformer_module.py b/allennlp/modules/transformer/transformer_module.py index 861120deca2..9b42d2026dc 100644 --- a/allennlp/modules/transformer/transformer_module.py +++ b/allennlp/modules/transformer/transformer_module.py @@ -1,229 +1,420 @@ -from typing import Optional, Dict, Union, List, Any +from collections import OrderedDict +from enum import Enum +from itertools import chain import logging -import inspect +import os +from os import PathLike +from typing import TYPE_CHECKING, Optional, Dict, Union, List, Any, TypeVar, Type +import warnings import torch +import torch.distributed as dist + +from allennlp.common.util import is_distributed, is_global_primary +from allennlp.nn.util import distributed_device + +if TYPE_CHECKING: + from transformers.configuration_utils import PretrainedConfig -from allennlp.common import cached_transformers logger = logging.getLogger(__name__) +_T = TypeVar("_T", bound="TransformerModule") +StateDictType = Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"] + + +class DistributedLoadingStrategy(Enum): + """ + Strategy options for loading state dictionaries in distributed process groups. + """ + + FREE_FOR_ALL = "FREE_FOR_ALL" + """ + Each process group loads its own state dict from disk. + """ + + MEMORY_EFFICIENT = "MEMORY_EFFICIENT" + """ + Only the primary process group loads the state dict from disk, then it broadcasts + each state tensor one-by-one to the other process groups. + """ + + @classmethod + def from_str(cls, s: str) -> "DistributedLoadingStrategy": + for option in cls: + if option.value.lower() == s.lower(): + return option + raise ValueError(f"Unknown distributed loading strategy: '{s}'") + + class TransformerModule(torch.nn.Module): """ Base class to help with generalized loading of pretrained weights. - `_huggingface_mapping` is an optional mapping for each class, that determines - any differences in the module names between the class modules and the huggingface model's - modules. - - `_relevant_module` is an optional str or list of str which contains the expected name of the module - in the huggingface pretrained model. It can be a list to account for different names in different - models. The search is carried out in the order of the list. + Subclasses should override `_from_config()` if you want to instantiate them with + `from_pretrained_module()`. """ _huggingface_mapping: Dict[str, str] = {} + """ + An optional mapping for each class that determines any differences in the module + names between the class modules and the HuggingFace model's modules. + Keys correspond to HuggingFace submodule names, values correspond to submodules names of this module. + """ + _relevant_module: Optional[Union[str, List[str]]] = None + """ + An optional string or list of strings which contains the expected name of the module + in the HuggingFace pretrained model. It can be a list to account for different names in different + models. The search is carried out in the order of the list. + """ - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) + _distributed_loading_strategy: DistributedLoadingStrategy = ( + DistributedLoadingStrategy.FREE_FOR_ALL + ) + """ + The default strategy for loading a state dictionary within a distributed process group. + """ @classmethod def _get_mapping( cls, - pretrained_module: Optional[torch.nn.Module] = None, - source: str = "huggingface", mapping: Optional[Dict[str, str]] = None, ): """ - Returns the mapping to be used, based on the optional `pretrained_module`. - If `pretrained_module` is not given, the default module-level mapping is returned. + Returns the mapping to be used, based on the optional `mapping` overrides + and the default module-level mapping. """ combined_mapping = {} - if "huggingface" == source: - combined_mapping.update(cls._huggingface_mapping) + combined_mapping.update(cls._huggingface_mapping) if mapping is not None: combined_mapping.update(mapping) return combined_mapping - @classmethod - def _get_mapped_submodules( - cls, - pretrained_module: torch.nn.Module, - source: str = "huggingface", - mapping: Optional[Dict[str, str]] = None, - ): - """ - Subclasses overload this method, and provide appropriate name mapping based on the source. - """ - submodules = dict(pretrained_module.named_modules()) - combined_mapping = cls._get_mapping(pretrained_module, source, mapping) - for name, module in pretrained_module.named_modules(): - newname = name - for key, val in combined_mapping.items(): - newname = newname.replace(key, val) - submodules[newname] = submodules.pop(name) - return submodules - - def _construct_default_mapping( - self, - pretrained_module: torch.nn.Module, - source: str = "huggingface", + @staticmethod + def _get_mapped_state_dict( + module: torch.nn.Module, + state_dict: StateDictType, mapping: Optional[Dict[str, str]] = None, - ): + ) -> StateDictType: """ - Recursively constructs the default mapping of parameter names for loading pretrained module weights. - Keys are parameter names from this module, and values are corresponding parameter names in the - expected pretrained module, as per `source`. - """ - combined_mapping = self._get_mapping(pretrained_module, source, mapping) - for name, module in self.named_modules(): - if name != "": - if hasattr(module, "_construct_default_mapping"): - # We handle collisions by giving priority to the outer module's mapping. - combined_mapping = dict( - list( - module._construct_default_mapping( - pretrained_module, source, combined_mapping - ).items() - ) - + list(combined_mapping.items()) - ) - return combined_mapping + Recursively map keys in a HuggingFace `state_dict` to the corresponding keys + for this module and all submodules. - def _load_from_pretrained_module( - self, - pretrained_module: torch.nn.Module, - source="huggingface", - mapping: Optional[Dict[str, str]] = None, - ignore_absent_parameters: Optional[List] = None, - ): - """ - Loads the weights of the `pretrained_module` into the instance. - Optionally, a `mapping` is specified for any differences in parameter names - between `pretrained_module` and the instance. + This is a `@staticmethod` instead of an instance method so that we can call + it on modules that do not inherit from `TransformerModule` in case those + modules have submodules that are `TransformerModule` instances. """ - ignore_absent_parameters = ignore_absent_parameters or [] - combined_mapping = self._construct_default_mapping(pretrained_module, source, mapping) - if mapping is not None: - combined_mapping.update(mapping) + # First fix all top-level keys according to `combined_mapping`. + combined_mapping = ( + module._get_mapping(mapping) if isinstance(module, TransformerModule) else {} + ) + for hf_key, cls_key in combined_mapping.items(): + relevant_keys = set([key for key in state_dict.keys() if key.startswith(hf_key)]) + for key in relevant_keys: + new_key = key.replace(hf_key, cls_key, 1) + state_dict[new_key] = state_dict.pop(key) - inverse_mapping = {val: key for key, val in combined_mapping.items()} - pretrained_parameters = dict(pretrained_module.named_parameters()) - for name, parameter in self.named_parameters(): - pretrained_name = name - for key, val in inverse_mapping.items(): - # so that we replace the names of submodules too. - # eg. module.key.anothermodule --> module.val.anothermodule - pretrained_name = pretrained_name.replace(key, val) - if not any( - [pretrained_name.startswith(paraname) for paraname in ignore_absent_parameters] - ): - if pretrained_name not in pretrained_parameters: - raise ValueError( - f"Couldn't find a matching parameter for {name}. Is this module " - "compatible with the pretrained module you're using?" - ) - parameter.data.copy_(pretrained_parameters[pretrained_name].data) + # Now loop through the submodules, calling this function on each submodule. + for name, submodule in module.named_children(): + # Pull-out the part of the state_dict corresponding to just this submodule. + relevant_keys = set([key for key in state_dict.keys() if key.startswith(name + ".")]) + module_state_dict = { + key.replace(name + ".", "", 1): state_dict.pop(key) for key in relevant_keys + } + # Recursively call this function from the submodule to map this part + # of the state_dict. + module_state_dict = TransformerModule._get_mapped_state_dict( + submodule, module_state_dict + ) + # And then update the full state_dict. + for key, value in module_state_dict.items(): + state_dict[name + "." + key] = value + + return state_dict @classmethod - def _get_input_arguments( + def _get_relevant_submodule_state( cls, - pretrained_module: torch.nn.Module, - source: str = "huggingface", - mapping: Optional[Dict[str, str]] = None, - **kwargs, - ) -> Dict[str, Any]: + state_dict: StateDictType, + relevant_module: Optional[Union[str, List[str]]] = None, + ) -> StateDictType: """ - Constructs the arguments required for instantiating an object of this class, using - the values from `pretrained_module`. + Returns the relevant part of the `state_dict`. """ - return kwargs + relevant_modules: Optional[List[str]] = None + if relevant_module: + relevant_modules = ( + [relevant_module] if isinstance(relevant_module, str) else relevant_module + ) + elif isinstance(cls._relevant_module, str): + relevant_modules = [cls._relevant_module] + elif isinstance(cls._relevant_module, list): + relevant_modules = cls._relevant_module + + if relevant_modules: + found = False + for module_name in relevant_modules: + relevant_keys = set( + [key for key in state_dict.keys() if key.startswith(module_name + ".")] + ) + if relevant_keys: + # Only keep elements of state dict that correspond to the relevant module. + state_dict = { + key.replace(module_name + ".", "", 1): value + for key, value in state_dict.items() + if key in relevant_keys + } + found = True + break + + if not found: + warnings.warn( + f"{relevant_modules} was not found at top level of state_dict!", UserWarning + ) + + return state_dict @classmethod - def get_relevant_module( + def _get_pretrained_state_dict( cls, - pretrained_module: Union[str, torch.nn.Module], + model_name: str, + weights_path: Optional[Union[str, PathLike]] = None, relevant_module: Optional[Union[str, List[str]]] = None, - source: str = "huggingface", - mapping: Optional[Dict[str, str]] = None, - load_weights: bool = True, - ): + ) -> StateDictType: + """ + Get a HuggingFace pretrained `state_dict` corresponding to this module. """ - Returns the relevant underlying module given a model name/object. + if weights_path is None: + from transformers.file_utils import WEIGHTS_NAME - # Parameters + # First see if we can find the weights locally. + if os.path.isdir(model_name): + local_weights_path = os.path.join(model_name, WEIGHTS_NAME) + if os.path.isfile(local_weights_path): + logger.info("Found weights at local path %s", local_weights_path) + weights_path = local_weights_path + + # If we haven't found locally, we assume model ID corresponds to a model + # on the HuggingFace Hub. + if weights_path is None: + from allennlp.common.file_utils import cached_path + + weights_path = cached_path(f"hf://{model_name}/{WEIGHTS_NAME}") + + # Now load the state dict. + logger.info("Loading state dict from %s", weights_path) + state_dict = torch.load(weights_path, map_location="cpu") - pretrained_module : `Union[str, torch.nn.Module]` - Name of the transformer model containing the layer, - or the actual layer (not the model object). - relevant_module : `Optional[Union[str, List[str]]]`, optional - Name of the desired module. Defaults to cls._relevant_module. - source : `str`, optional - Where the model came from. Default - huggingface. - mapping : `Dict[str, str]`, optional - Optional mapping that determines any differences in the module names - between the class modules and the input model's modules. - Default - cls._huggingface_mapping - load_weights : `bool`, optional - Whether or not to load the pretrained weights. - Default is `True`. + # Keep just the relevant_module, remove everything else. + state_dict = cls._get_relevant_submodule_state(state_dict) + + return state_dict + + @staticmethod + def _collect_state_dict( + module: torch.nn.Module, state_dict: Optional[StateDictType], recurse: bool = True + ) -> StateDictType: """ - if isinstance(pretrained_module, str): - pretrained_module = cached_transformers.get( - pretrained_module, False, load_weights=load_weights + Collect a module's state dict across distributed processes. + """ + # This is the device we'll use for the broadcast operation. + device = distributed_device() + + # Gather current state dict and prepare to iterator over it. + # We iterate over this state dict instead of `state_dict` so we can be sure + # that the order is consistent across processes. + # We'll also update this state dict as we go and return it at the end. + if recurse: + current_state_dict = module.state_dict() + else: + # Only collect state of direct members, including both parameters and buffers. + current_state_dict = OrderedDict( + chain( + # Paramaters + ((n, p.data) for (n, p) in module.named_parameters(recurse=False)), + # Buffers + module.named_buffers(recurse=False), + ) ) + keys = list(current_state_dict.keys()) - relevant_module = relevant_module or cls._relevant_module + for key in keys: + tensor = current_state_dict[key] + if is_global_primary(): + assert state_dict is not None + if key in state_dict: + tensor = state_dict[key] + else: + logger.warning( + f"Missing key {key} from state_dict (available keys: {list(state_dict.keys())})" + ) + tensor = tensor.to(device) + dist.broadcast(tensor, 0) + current_state_dict[key] = tensor - if relevant_module is not None: - submodules = cls._get_mapped_submodules(pretrained_module, source, mapping) - # If the relevant_module is not found, we assume that the pretrained_module - # is already the relevant module. - if isinstance(relevant_module, str): - relevant_module = [relevant_module] - found = False - for module in relevant_module: - if module in submodules: - pretrained_module = submodules[module] - found = True - break + return current_state_dict - if not found: - logger.warning( - "{} was not found! The submodules are: {}".format( - relevant_module, submodules.keys() - ) + @staticmethod + def _load_state_dict_distributed( + module: torch.nn.Module, state_dict: Optional[StateDictType], strict: bool = True + ) -> None: + """ + Load a `state_dict` within a distributed process group. + + The `state_dict` may be `None` if the current process group is not the global primary, + in which case it will gather the parameters from the global primary one-by-one. + + This is a `@staticmethod` instead of an instance method so that we can call + it on modules that do not inherit from `TransformerModule` in case those + modules have submodules that are `TransformerModule` instances. + """ + submodules = dict(module.named_children()) + + # If we've found a sharded module or there aren't any more submodules of the current module, + # we collect the state_dict and load it now instead of recursing further. + if getattr(module, "_is_sharded", False) or not submodules: + state_dict = TransformerModule._collect_state_dict(module, state_dict) + assert state_dict is not None + module.load_state_dict(state_dict, strict=strict) + else: + # We'll recursively call this function on each submodule, but first we need + # to collect any parameters that are direct members of this module. + direct_member_state_dict = TransformerModule._collect_state_dict( + module, state_dict, recurse=False + ) + missing_keys, unexpected_keys = module.load_state_dict( + direct_member_state_dict, strict=False + ) + if strict and unexpected_keys: + raise ValueError(f"Unexpected keys in state dict: {unexpected_keys}") + + # Okay, now for the recursive part. + for name, submodule in submodules.items(): + submodule_state_dict: Optional[StateDictType] = None + if is_global_primary(): + assert state_dict is not None + submodule_state_dict = { + key.replace(name + ".", "", 1): value + for key, value in state_dict.items() + if key.startswith(name + ".") + } + submodule_state_dict = TransformerModule._collect_state_dict( + submodule, submodule_state_dict ) - return pretrained_module + assert submodule_state_dict is not None + submodule.load_state_dict(submodule_state_dict, strict=strict) + + @classmethod + def _from_config(cls: Type[_T], config: "PretrainedConfig", **kwargs) -> _T: + """ + Instantiate this module from a HuggingFace config. Subclasses should override + this method if you want to be able to instantiate them with `from_pretrained_module()`. + """ + raise NotImplementedError @classmethod def from_pretrained_module( - cls, - pretrained_module: Union[str, torch.nn.Module], - source: str = "huggingface", - mapping: Optional[Dict[str, str]] = None, + cls: Type[_T], + model_name: str, load_weights: bool = True, + weights_path: Optional[Union[str, PathLike]] = None, + auto_config_kwargs: Optional[Dict[str, Any]] = None, + mapping: Optional[Dict[str, str]] = None, + relevant_module: Optional[Union[str, List[str]]] = None, + strict: bool = True, + distributed_loading_strategy: Optional[Union[str, DistributedLoadingStrategy]] = None, **kwargs, - ): - """ - Creates and returns an instance of the class, by using the weights - (and the architecture, by default) of the `pretrained_module`. - Optionally, the architecture can be changed by providing arguments. + ) -> _T: """ - accepted_args = inspect.getfullargspec(cls).args - accepted_args.remove("self") - for key in kwargs: - assert key in accepted_args, ( - "{} is not a valid argument for creating an instance of `{}`. " - "Accepted arguments are {}.".format(key, cls.__name__, accepted_args) - ) + Initialize this module from a corresponding model from HuggingFace. - pretrained_module = cls.get_relevant_module( - pretrained_module, source=source, mapping=mapping, load_weights=load_weights - ) - final_kwargs = cls._get_input_arguments(pretrained_module, source, mapping) - final_kwargs.update(kwargs) - module = cls(**final_kwargs) - module._load_from_pretrained_module(pretrained_module, source, mapping) - return module + + !!! Note + This method is only available for subclasses that implement `from_config()`. + Otherwise a `NotImplementedError` will be raised. + + # Parameters + + model_name : `str` + The model identifier or path. + + load_weights : `bool`, optional (default = `True`) + Whether to download and load the pretrained weights. If `False`, the + weights are left uninitialized. + + weights_path : `Optional[Union[str, PathLike]]`, optional (default = `None`) + When `load_weights` is `True`, this can be set to override the weights file. + Otherwise the default weights from the pretrained model are used. + + auto_config_kwargs : `Optional[Dict[str, Any]]`, optional (default = `None`) + Optional key-word arguments to pass to `transformers.AutoConfig.from_pretrained()` + to load the pretrained model's configuration file. + + mapping : `Optional[Dict[str, str]]`, optional (default = `None`) + Optional mapping that determines any differences in the submodule names + between this module and the pretrained model from HuggingFace. + If not given, the class's default is used: `cls._huggingface_mapping`. + + relevant_module : `Optionall[str]`, optional (default = `None`) + An optional submodule of the HuggingFace module to initialize weights from. + This is only relevant when `load_weights` is `True`. + If not given, the class's default is used: `cls._relevant_module`. + + strict : `bool`, optional (default = `True`) + Whether to load the `state_dict` in "strict" model. This only applies + when `load_weights` is `True`. + + distributed_loading_strategy : `Optional[Union[str, DistributedLoadingStrategy]]`, optional (default = `None`) + The loading strategy to use within a distributed process group. This only applies + when `load_weights` is `True`. If not specified, this class's default is used: + `cls._distributed_loading_strategy`. + + **kwargs : Any + Key word arguments to pass to `cls.from_config()` when instantiating the module. + """ # noqa: E501 + from transformers import AutoConfig + + config = AutoConfig.from_pretrained(model_name, **(auto_config_kwargs or {})) + model = cls._from_config(config, **kwargs) + + if load_weights: + # Resolve the loading strategy to use. + loading_strategy: DistributedLoadingStrategy + if isinstance(distributed_loading_strategy, DistributedLoadingStrategy): + loading_strategy = distributed_loading_strategy + elif isinstance(distributed_loading_strategy, str): + loading_strategy = DistributedLoadingStrategy.from_str(distributed_loading_strategy) + else: + loading_strategy = cls._distributed_loading_strategy + + state_dict: Optional[StateDictType] = None + if is_global_primary() or loading_strategy == DistributedLoadingStrategy.FREE_FOR_ALL: + # Load the pretrained HuggingFace state_dict. + pretrained_state_dict = cls._get_pretrained_state_dict( + model_name, + weights_path=weights_path, + relevant_module=relevant_module, + ) + # Now map keys from the HuggingFace state_dict to the corresponding keys from + # this class. This is called recursively on each submodule of the current module. + state_dict = TransformerModule._get_mapped_state_dict( + model, pretrained_state_dict, mapping=mapping + ) + + if not is_distributed() or loading_strategy == DistributedLoadingStrategy.FREE_FOR_ALL: + assert state_dict is not None + logger.info("Loading state_dict into module") + model.load_state_dict(state_dict, strict=strict) + else: + # We're in distributed training. `state_dict` is `None` for all process groups + # except the global primary. + # Syncronize here since non-primary process groups will have to wait for the primary + # to load the state_dict into memory. + dist.barrier() + # Now load the state dict into the model. + logger.info("Loading state_dict into module (MEMORY_EFFICIENT strategy)") + TransformerModule._load_state_dict_distributed(model, state_dict, strict=strict) + + return model diff --git a/allennlp/modules/transformer/transformer_stack.py b/allennlp/modules/transformer/transformer_stack.py index 09fb1d2bc40..79164f54b0b 100644 --- a/allennlp/modules/transformer/transformer_stack.py +++ b/allennlp/modules/transformer/transformer_stack.py @@ -1,14 +1,17 @@ -from typing import Union, Optional, Dict +from typing import Union, Optional, TYPE_CHECKING import logging import torch from allennlp.common import FromParams - from allennlp.modules.util import replicate_layers from allennlp.modules.transformer.transformer_layer import TransformerLayer from allennlp.modules.transformer.transformer_module import TransformerModule +if TYPE_CHECKING: + from transformers.configuration_utils import PretrainedConfig + + logger = logging.getLogger(__name__) @@ -129,67 +132,17 @@ def forward( ) @classmethod - def _get_input_arguments( - cls, - pretrained_module: torch.nn.Module, - source="huggingface", - mapping: Optional[Dict[str, str]] = None, - **kwargs, - ): - submodules = cls._get_mapped_submodules(pretrained_module, source, mapping) - + def _from_config(cls, config: "PretrainedConfig", **kwargs): final_kwargs = {} - final_kwargs["num_hidden_layers"] = len(submodules["layers"]) - - final_kwargs["hidden_size"] = submodules["layers.0.attention.self.query"].in_features - final_kwargs["num_attention_heads"] = submodules[ - "layers.0.attention.self" - ].num_attention_heads - final_kwargs["attention_dropout"] = submodules["layers.0.attention.self.dropout"].p - final_kwargs["hidden_dropout"] = submodules["layers.0.attention.output.dropout"].p - final_kwargs["intermediate_size"] = submodules["layers.0.intermediate.dense"].out_features - - # We require the if block as `act_fn` is a function rather than a module, - # so `_get_mapped_submodules` does not automatically fix this. - if source == "huggingface": - final_kwargs["activation"] = getattr( - submodules["layers.0.intermediate"], "intermediate_act_fn" - ) - else: - final_kwargs["activation"] = getattr(submodules["layers.0.intermediate"], "act_fn") - - final_kwargs["add_cross_attention"] = "layers.0.cross_attention" in submodules + final_kwargs["num_hidden_layers"] = config.num_hidden_layers + final_kwargs["hidden_size"] = config.hidden_size + final_kwargs["num_attention_heads"] = config.num_attention_heads + final_kwargs["attention_dropout"] = config.attention_probs_dropout_prob + final_kwargs["hidden_dropout"] = config.hidden_dropout_prob + final_kwargs["intermediate_size"] = config.intermediate_size + final_kwargs["activation"] = config.hidden_act + final_kwargs["add_cross_attention"] = config.add_cross_attention final_kwargs.update(**kwargs) - - return final_kwargs - - @classmethod - def from_pretrained_module( # type: ignore - cls, - pretrained_module: Union[str, torch.nn.Module], - num_hidden_layers: Optional[Union[int, range]] = None, - source="huggingface", - mapping: Optional[Dict[str, str]] = None, - load_weights: bool = True, - **kwargs, - ): - final_kwargs = {} - if num_hidden_layers is not None: - if isinstance(num_hidden_layers, range): - if mapping is None: - mapping = {} - for num_layer, mapped in enumerate(num_hidden_layers): - mapping[str(mapped)] = str(num_layer) - final_kwargs["num_hidden_layers"] = len(num_hidden_layers) - else: - final_kwargs["num_hidden_layers"] = num_hidden_layers - - return super().from_pretrained_module( - pretrained_module, - source=source, - mapping=mapping, - load_weights=load_weights, - **final_kwargs, - ) + return cls(**final_kwargs) diff --git a/allennlp/nn/util.py b/allennlp/nn/util.py index b1c391bd52e..9a558e019fc 100644 --- a/allennlp/nn/util.py +++ b/allennlp/nn/util.py @@ -2019,6 +2019,17 @@ def tiny_value_of_dtype(dtype: torch.dtype): _V = TypeVar("_V", int, float, torch.Tensor) +def distributed_device() -> torch.device: + """ + Get the correct `torch.device` of the current process to use for distributed point-to-point communication. + """ + if not is_distributed(): + raise RuntimeError( + "'distributed_device()' can only be called within a distributed process group" + ) + return int_to_device(-1 if dist.get_backend() != "nccl" else torch.cuda.current_device()) + + def dist_reduce(value: _V, reduce_op, **kwargs) -> _V: """ Reduces the given `value` across all distributed worker nodes according the given @@ -2043,7 +2054,7 @@ def dist_reduce(value: _V, reduce_op, **kwargs) -> _V: """ if not is_distributed(): return value - device = int_to_device(-1 if dist.get_backend() != "nccl" else torch.cuda.current_device()) + device = distributed_device() value_tensor = torch.tensor(value, device=device, **kwargs) dist.all_reduce(value_tensor, op=reduce_op) From ba418dfa7c67e27f8e4cc177d3561b9583f4f10b Mon Sep 17 00:00:00 2001 From: epwalsh Date: Tue, 11 May 2021 09:12:00 -0700 Subject: [PATCH 02/23] rename 'load_state_dict' -> 'read_state_dict' --- allennlp/commands/diff.py | 6 +++--- allennlp/models/model.py | 2 +- allennlp/modules/transformer/transformer_module.py | 6 +++--- allennlp/nn/util.py | 4 ++-- 4 files changed, 9 insertions(+), 9 deletions(-) diff --git a/allennlp/commands/diff.py b/allennlp/commands/diff.py index 6d86f7db76f..35738ca2237 100644 --- a/allennlp/commands/diff.py +++ b/allennlp/commands/diff.py @@ -19,7 +19,7 @@ from allennlp.commands.subcommand import Subcommand from allennlp.common.file_utils import cached_path -from allennlp.nn.util import load_state_dict +from allennlp.nn.util import read_state_dict logger = logging.getLogger(__name__) @@ -249,10 +249,10 @@ def _get_checkpoint_path(checkpoint: str) -> str: def _diff(args: argparse.Namespace): checkpoint_1_path = _get_checkpoint_path(args.checkpoint1) checkpoint_2_path = _get_checkpoint_path(args.checkpoint2) - checkpoint_1 = load_state_dict( + checkpoint_1 = read_state_dict( checkpoint_1_path, strip_prefix=args.strip_prefix_1, strict=False ) - checkpoint_2 = load_state_dict( + checkpoint_2 = read_state_dict( checkpoint_2_path, strip_prefix=args.strip_prefix_2, strict=False ) for step in checkpoint_diff(checkpoint_1, checkpoint_2, args.scale, args.threshold): diff --git a/allennlp/models/model.py b/allennlp/models/model.py index 5ff7c967e8e..2800243a6a1 100644 --- a/allennlp/models/model.py +++ b/allennlp/models/model.py @@ -335,7 +335,7 @@ def _load( # Load state dict. We pass `strict=False` so PyTorch doesn't raise a RuntimeError # if the state dict is missing keys because we handle this case below. - model_state = util.load_state_dict(weights_file, cuda_device=cuda_device) + model_state = util.read_state_dict(weights_file, cuda_device=cuda_device) missing_keys, unexpected_keys = model.load_state_dict(model_state, strict=False) # Modules might define a class variable called `authorized_missing_keys`, diff --git a/allennlp/modules/transformer/transformer_module.py b/allennlp/modules/transformer/transformer_module.py index 9b42d2026dc..ac113f2ad42 100644 --- a/allennlp/modules/transformer/transformer_module.py +++ b/allennlp/modules/transformer/transformer_module.py @@ -11,7 +11,7 @@ import torch.distributed as dist from allennlp.common.util import is_distributed, is_global_primary -from allennlp.nn.util import distributed_device +from allennlp.nn.util import distributed_device, read_state_dict if TYPE_CHECKING: from transformers.configuration_utils import PretrainedConfig @@ -204,8 +204,8 @@ def _get_pretrained_state_dict( weights_path = cached_path(f"hf://{model_name}/{WEIGHTS_NAME}") # Now load the state dict. - logger.info("Loading state dict from %s", weights_path) - state_dict = torch.load(weights_path, map_location="cpu") + logger.info("Reading state dict from %s", weights_path) + state_dict = read_state_dict(weights_path) # Keep just the relevant_module, remove everything else. state_dict = cls._get_relevant_submodule_state(state_dict) diff --git a/allennlp/nn/util.py b/allennlp/nn/util.py index a458bdd45a7..459190b6b28 100644 --- a/allennlp/nn/util.py +++ b/allennlp/nn/util.py @@ -926,7 +926,7 @@ def inner_device_mapping(storage: torch.Storage, location) -> torch.Storage: return inner_device_mapping -def load_state_dict( +def read_state_dict( path: Union[PathLike, str], strip_prefix: Optional[str] = None, ignore: Optional[List[str]] = None, @@ -934,7 +934,7 @@ def load_state_dict( cuda_device: int = -1, ) -> Dict[str, torch.Tensor]: """ - Load a PyTorch model state dictionary from a checkpoint at the given `path`. + Read a PyTorch model state dictionary from a checkpoint at the given `path`. # Parameters From ed0b8fa85f4dae79e13040b08b3b2c5e17a46143 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Tue, 11 May 2021 13:30:40 -0700 Subject: [PATCH 03/23] fix TransformerStack --- allennlp/modules/transformer/output_layer.py | 6 +- .../modules/transformer/transformer_module.py | 178 +++--------- .../modules/transformer/transformer_stack.py | 6 +- allennlp/nn/util.py | 98 ++++++- .../transformer/transformer_module_test.py | 81 +++--- .../transformer/transformer_stack_test.py | 253 +++++------------- 6 files changed, 259 insertions(+), 363 deletions(-) diff --git a/allennlp/modules/transformer/output_layer.py b/allennlp/modules/transformer/output_layer.py index 03dd1f9d5df..df79b1779ec 100644 --- a/allennlp/modules/transformer/output_layer.py +++ b/allennlp/modules/transformer/output_layer.py @@ -5,6 +5,10 @@ from allennlp.modules.transformer.transformer_module import TransformerModule +class LayerNorm(torch.nn.LayerNorm, TransformerModule): + _huggingface_mapping = {"gamma": "weight", "beta": "bias"} + + class OutputLayer(TransformerModule, FromParams): _huggingface_mapping = {"LayerNorm": "layer_norm"} @@ -12,7 +16,7 @@ class OutputLayer(TransformerModule, FromParams): def __init__(self, input_size: int, hidden_size: int, dropout: float): super().__init__() self.dense = torch.nn.Linear(input_size, hidden_size) - self.layer_norm = torch.nn.LayerNorm(hidden_size, eps=1e-12) + self.layer_norm = LayerNorm(hidden_size, eps=1e-12) self.dropout = torch.nn.Dropout(dropout) def forward(self, hidden_states, input_tensor): diff --git a/allennlp/modules/transformer/transformer_module.py b/allennlp/modules/transformer/transformer_module.py index ac113f2ad42..65e7a507a32 100644 --- a/allennlp/modules/transformer/transformer_module.py +++ b/allennlp/modules/transformer/transformer_module.py @@ -1,6 +1,4 @@ -from collections import OrderedDict from enum import Enum -from itertools import chain import logging import os from os import PathLike @@ -11,7 +9,7 @@ import torch.distributed as dist from allennlp.common.util import is_distributed, is_global_primary -from allennlp.nn.util import distributed_device, read_state_dict +from allennlp.nn.util import StateDictType, read_state_dict, load_state_dict_distributed if TYPE_CHECKING: from transformers.configuration_utils import PretrainedConfig @@ -21,23 +19,26 @@ _T = TypeVar("_T", bound="TransformerModule") -StateDictType = Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"] class DistributedLoadingStrategy(Enum): """ - Strategy options for loading state dictionaries in distributed process groups. + Strategy options for loading state dictionaries across distributed processes. """ FREE_FOR_ALL = "FREE_FOR_ALL" """ - Each process group loads its own state dict from disk. + Each process loads its own state dict from disk. """ MEMORY_EFFICIENT = "MEMORY_EFFICIENT" """ - Only the primary process group loads the state dict from disk, then it broadcasts + Only the primary process loads the state dict from disk, then it broadcasts each state tensor one-by-one to the other process groups. + + This is particularly useful when you have multiple distributed workers on the same + machine (shared CPU memory), and don't have enough memory for each process to load + its own copy of the state dict at the same time. """ @classmethod @@ -92,47 +93,16 @@ def _get_mapping( combined_mapping.update(mapping) return combined_mapping - @staticmethod def _get_mapped_state_dict( - module: torch.nn.Module, + self, state_dict: StateDictType, mapping: Optional[Dict[str, str]] = None, ) -> StateDictType: """ Recursively map keys in a HuggingFace `state_dict` to the corresponding keys for this module and all submodules. - - This is a `@staticmethod` instead of an instance method so that we can call - it on modules that do not inherit from `TransformerModule` in case those - modules have submodules that are `TransformerModule` instances. """ - # First fix all top-level keys according to `combined_mapping`. - combined_mapping = ( - module._get_mapping(mapping) if isinstance(module, TransformerModule) else {} - ) - for hf_key, cls_key in combined_mapping.items(): - relevant_keys = set([key for key in state_dict.keys() if key.startswith(hf_key)]) - for key in relevant_keys: - new_key = key.replace(hf_key, cls_key, 1) - state_dict[new_key] = state_dict.pop(key) - - # Now loop through the submodules, calling this function on each submodule. - for name, submodule in module.named_children(): - # Pull-out the part of the state_dict corresponding to just this submodule. - relevant_keys = set([key for key in state_dict.keys() if key.startswith(name + ".")]) - module_state_dict = { - key.replace(name + ".", "", 1): state_dict.pop(key) for key in relevant_keys - } - # Recursively call this function from the submodule to map this part - # of the state_dict. - module_state_dict = TransformerModule._get_mapped_state_dict( - submodule, module_state_dict - ) - # And then update the full state_dict. - for key, value in module_state_dict.items(): - state_dict[name + "." + key] = value - - return state_dict + return _get_mapped_state_dict(self, state_dict, mapping=mapping) @classmethod def _get_relevant_submodule_state( @@ -212,100 +182,6 @@ def _get_pretrained_state_dict( return state_dict - @staticmethod - def _collect_state_dict( - module: torch.nn.Module, state_dict: Optional[StateDictType], recurse: bool = True - ) -> StateDictType: - """ - Collect a module's state dict across distributed processes. - """ - # This is the device we'll use for the broadcast operation. - device = distributed_device() - - # Gather current state dict and prepare to iterator over it. - # We iterate over this state dict instead of `state_dict` so we can be sure - # that the order is consistent across processes. - # We'll also update this state dict as we go and return it at the end. - if recurse: - current_state_dict = module.state_dict() - else: - # Only collect state of direct members, including both parameters and buffers. - current_state_dict = OrderedDict( - chain( - # Paramaters - ((n, p.data) for (n, p) in module.named_parameters(recurse=False)), - # Buffers - module.named_buffers(recurse=False), - ) - ) - keys = list(current_state_dict.keys()) - - for key in keys: - tensor = current_state_dict[key] - if is_global_primary(): - assert state_dict is not None - if key in state_dict: - tensor = state_dict[key] - else: - logger.warning( - f"Missing key {key} from state_dict (available keys: {list(state_dict.keys())})" - ) - tensor = tensor.to(device) - dist.broadcast(tensor, 0) - current_state_dict[key] = tensor - - return current_state_dict - - @staticmethod - def _load_state_dict_distributed( - module: torch.nn.Module, state_dict: Optional[StateDictType], strict: bool = True - ) -> None: - """ - Load a `state_dict` within a distributed process group. - - The `state_dict` may be `None` if the current process group is not the global primary, - in which case it will gather the parameters from the global primary one-by-one. - - This is a `@staticmethod` instead of an instance method so that we can call - it on modules that do not inherit from `TransformerModule` in case those - modules have submodules that are `TransformerModule` instances. - """ - submodules = dict(module.named_children()) - - # If we've found a sharded module or there aren't any more submodules of the current module, - # we collect the state_dict and load it now instead of recursing further. - if getattr(module, "_is_sharded", False) or not submodules: - state_dict = TransformerModule._collect_state_dict(module, state_dict) - assert state_dict is not None - module.load_state_dict(state_dict, strict=strict) - else: - # We'll recursively call this function on each submodule, but first we need - # to collect any parameters that are direct members of this module. - direct_member_state_dict = TransformerModule._collect_state_dict( - module, state_dict, recurse=False - ) - missing_keys, unexpected_keys = module.load_state_dict( - direct_member_state_dict, strict=False - ) - if strict and unexpected_keys: - raise ValueError(f"Unexpected keys in state dict: {unexpected_keys}") - - # Okay, now for the recursive part. - for name, submodule in submodules.items(): - submodule_state_dict: Optional[StateDictType] = None - if is_global_primary(): - assert state_dict is not None - submodule_state_dict = { - key.replace(name + ".", "", 1): value - for key, value in state_dict.items() - if key.startswith(name + ".") - } - submodule_state_dict = TransformerModule._collect_state_dict( - submodule, submodule_state_dict - ) - assert submodule_state_dict is not None - submodule.load_state_dict(submodule_state_dict, strict=strict) - @classmethod def _from_config(cls: Type[_T], config: "PretrainedConfig", **kwargs) -> _T: """ @@ -328,7 +204,7 @@ def from_pretrained_module( **kwargs, ) -> _T: """ - Initialize this module from a corresponding model from HuggingFace. + Initialize this module from a corresponding model on HuggingFace. !!! Note @@ -415,6 +291,36 @@ def from_pretrained_module( dist.barrier() # Now load the state dict into the model. logger.info("Loading state_dict into module (MEMORY_EFFICIENT strategy)") - TransformerModule._load_state_dict_distributed(model, state_dict, strict=strict) + load_state_dict_distributed(model, state_dict, strict=strict) return model + + +def _get_mapped_state_dict( + module: torch.nn.Module, + state_dict: StateDictType, + mapping: Optional[Dict[str, str]] = None, +) -> StateDictType: + # First fix all top-level keys according to `combined_mapping`. + combined_mapping = module._get_mapping(mapping) if isinstance(module, TransformerModule) else {} + for hf_key, cls_key in combined_mapping.items(): + relevant_keys = set([key for key in state_dict.keys() if key.startswith(hf_key)]) + for key in relevant_keys: + new_key = key.replace(hf_key, cls_key, 1) + state_dict[new_key] = state_dict.pop(key) + + # Now loop through the submodules, calling this function on each submodule. + for name, submodule in module.named_children(): + # Pull-out the part of the state_dict corresponding to just this submodule. + relevant_keys = set([key for key in state_dict.keys() if key.startswith(name + ".")]) + module_state_dict = { + key.replace(name + ".", "", 1): state_dict.pop(key) for key in relevant_keys + } + # Recursively call this function from the submodule to map this part + # of the state_dict. + module_state_dict = TransformerModule._get_mapped_state_dict(submodule, module_state_dict) + # And then update the full state_dict. + for key, value in module_state_dict.items(): + state_dict[name + "." + key] = value + + return state_dict diff --git a/allennlp/modules/transformer/transformer_stack.py b/allennlp/modules/transformer/transformer_stack.py index 79164f54b0b..33c69703040 100644 --- a/allennlp/modules/transformer/transformer_stack.py +++ b/allennlp/modules/transformer/transformer_stack.py @@ -42,7 +42,7 @@ class TransformerStack(TransformerModule, FromParams): """ _huggingface_mapping = {"layer": "layers"} - _relevant_module = "encoder" + _relevant_module = ["encoder", "bert.encoder"] def __init__( self, @@ -134,15 +134,13 @@ def forward( @classmethod def _from_config(cls, config: "PretrainedConfig", **kwargs): final_kwargs = {} - final_kwargs["num_hidden_layers"] = config.num_hidden_layers final_kwargs["hidden_size"] = config.hidden_size final_kwargs["num_attention_heads"] = config.num_attention_heads + final_kwargs["add_cross_attention"] = config.add_cross_attention final_kwargs["attention_dropout"] = config.attention_probs_dropout_prob final_kwargs["hidden_dropout"] = config.hidden_dropout_prob final_kwargs["intermediate_size"] = config.intermediate_size final_kwargs["activation"] = config.hidden_act - final_kwargs["add_cross_attention"] = config.add_cross_attention - final_kwargs.update(**kwargs) return cls(**final_kwargs) diff --git a/allennlp/nn/util.py b/allennlp/nn/util.py index 459190b6b28..0d8210323c9 100644 --- a/allennlp/nn/util.py +++ b/allennlp/nn/util.py @@ -4,6 +4,7 @@ import copy from collections import defaultdict, OrderedDict +from itertools import chain import json import logging from os import PathLike @@ -16,11 +17,18 @@ import torch.distributed as dist from allennlp.common.checks import ConfigurationError -from allennlp.common.util import int_to_device, is_distributed +from allennlp.common.util import int_to_device, is_distributed, is_global_primary logger = logging.getLogger(__name__) T = TypeVar("T") +StateDictType = Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"] + +_MODULE_SHARDED_FLAG = "_is_sharded_allennlp" +""" +This flag is used to indicate when a module's parameters have been sharded across +distributed workers. +""" def move_to_device(obj, device: Union[torch.device, int]): @@ -2168,3 +2176,91 @@ def dist_reduce_sum(value: _V, **kwargs) -> _V: if not is_distributed(): return value return dist_reduce(value, dist.ReduceOp.SUM, **kwargs) + + +def _collect_state_dict( + module: torch.nn.Module, state_dict: Optional[StateDictType], recurse: bool = True +) -> StateDictType: + """ + Collect a module's state dict across distributed processes. + """ + # This is the device we'll use for the broadcast operation. + device = distributed_device() + + # Gather current state dict and prepare to iterator over it. + # We iterate over this state dict instead of `state_dict` so we can be sure + # that the order is consistent across processes. + # We'll also update this state dict as we go and return it at the end. + if recurse: + current_state_dict = module.state_dict() + else: + # Only collect state of direct members, including both parameters and buffers. + current_state_dict = OrderedDict( + chain( + # Paramaters + ((n, p.data) for (n, p) in module.named_parameters(recurse=False)), + # Buffers + module.named_buffers(recurse=False), + ) + ) + keys = list(current_state_dict.keys()) + + for key in keys: + tensor = current_state_dict[key] + if is_global_primary(): + assert state_dict is not None + if key in state_dict: + tensor = state_dict[key] + else: + logger.warning( + f"Missing key {key} from state_dict (available keys: {list(state_dict.keys())})" + ) + tensor = tensor.to(device) + dist.broadcast(tensor, 0) + current_state_dict[key] = tensor + + return current_state_dict + + +def load_state_dict_distributed( + module: torch.nn.Module, state_dict: Optional[StateDictType], strict: bool = True +) -> None: + """ + Load a `state_dict` to the `module` within a distributed process. Only the global + primary process requires the `state_dict` to not be `None`. All other processes + will have the state tensors broadcasted to them one-by-one. + """ + if is_global_primary(): + assert state_dict is not None + else: + assert state_dict is None + + submodules = dict(module.named_children()) + + # If we've found a sharded module or there aren't any more submodules of the current module, + # we collect the state_dict and load it now instead of recursing further. + if getattr(module, _MODULE_SHARDED_FLAG, False) or not submodules: + state_dict = _collect_state_dict(module, state_dict) + assert state_dict is not None + module.load_state_dict(state_dict, strict=strict) + else: + # We'll recursively call this function on each submodule, but first we need + # to collect any parameters that are direct members of this module. + direct_member_state_dict = _collect_state_dict(module, state_dict, recurse=False) + missing_keys, unexpected_keys = module.load_state_dict( + direct_member_state_dict, strict=False + ) + if strict and unexpected_keys: + raise ValueError(f"Unexpected keys in state dict: {unexpected_keys}") + + # Okay, now for the recursive part. + for name, submodule in submodules.items(): + submodule_state_dict: Optional[StateDictType] = None + if is_global_primary(): + assert state_dict is not None + submodule_state_dict = { + key.replace(name + ".", "", 1): value + for key, value in state_dict.items() + if key.startswith(name + ".") + } + load_state_dict_distributed(submodule, submodule_state_dict, strict=strict) diff --git a/tests/modules/transformer/transformer_module_test.py b/tests/modules/transformer/transformer_module_test.py index d5002f215ea..307c8295ad8 100644 --- a/tests/modules/transformer/transformer_module_test.py +++ b/tests/modules/transformer/transformer_module_test.py @@ -1,74 +1,89 @@ import torch +from torch.nn import Parameter -from allennlp.common.testing import assert_equal_parameters +from allennlp.common.testing import assert_equal_parameters, assert_allclose from allennlp.modules.transformer import TransformerModule from allennlp.common.testing import AllenNlpTestCase class TestTransformerModule(AllenNlpTestCase): - def test_can_load_pretrained_weights(self): + def test_get_mapped_state_dict(self): class InternalOld(torch.nn.Module): def __init__(self, inp, out): super().__init__() self.ff = torch.nn.Linear(inp, out) + self.p = Parameter(torch.randn(out, out)) + self.register_buffer("b", torch.randn(inp, inp)) def forward(self, x): - x = self.ff(x) + x = self.ff(x).matmul(self.p) return x class InternalNew(TransformerModule): + _huggingface_mapping = {"ff": "linear", "p": "param", "b": "buffer"} + def __init__(self, inp, out): super().__init__() self.linear = torch.nn.Linear(inp, out) - - def _construct_default_mapping(self, pretrained_module, source, mapping): - # return {"linear": "ff"} - return {"ff": "linear"} + self.param = Parameter(torch.randn(out, out)) + self.register_buffer("buffer", torch.randn(inp, inp)) def forward(self, x): - x = self.linear(x) + x = self.linear(x).matmul(self.param) return x class ExternalOld(torch.nn.Module): def __init__(self, inp, out): super().__init__() self.internal = InternalOld(inp, out) + self.p = Parameter(torch.randn(out, out)) def forward(self, x): - x = self.internal(x) + x = self.internal(x).matmul(self.p) return x - class External(TransformerModule): - # _huggingface_mapping = {"internal_layer": "internal"} - _huggingface_mapping = {"internal": "internal_layer"} + class ExternalNew(TransformerModule): + _huggingface_mapping = {"internal": "internal_layer", "p": "param"} def __init__(self, inp, out): super().__init__() self.internal_layer = InternalNew(inp, out) + self.param = Parameter(torch.randn(out, out)) def forward(self, x): - x = self.internal_layer(x) + x = self.internal_layer(x).matmul(self.param) return x - iold = InternalOld(3, 5) - x = torch.randn(4, 3) - iold.forward(x) - inew = InternalNew(3, 5) - inew._load_from_pretrained_module(iold) - mapping = { - val: key - for key, val in inew._construct_default_mapping(iold, "huggingface", {}).items() - } - assert_equal_parameters(iold, inew, mapping=mapping) - eold = ExternalOld(3, 5) + state_dict_old = eold.state_dict() + + enew = ExternalNew(3, 5) + state_dict_new = enew._get_mapped_state_dict(state_dict_old) + assert set(state_dict_new.keys()) == set( + [ + "internal_layer.linear.weight", + "internal_layer.linear.bias", + "internal_layer.param", + "internal_layer.buffer", + "param", + ] + ) + + enew.load_state_dict(state_dict_new) + x = torch.randn(4, 3) - eold.forward(x) - - enew = External(3, 5) - enew._load_from_pretrained_module(eold) - mapping = { - val: key - for key, val in enew._construct_default_mapping(eold, "huggingface", {}).items() - } - assert_equal_parameters(eold, enew, mapping=mapping) + out_old = eold(x) + out_new = enew(x) + assert_allclose(out_old, out_new) + + assert_equal_parameters( + eold, + enew, + mapping={ + "internal_layer.linear.weight": "internal.ff.weight", + "internal_layer.linear.bias": "internal.ff.bias", + "internal_layer.param": "internal.p", + "internal_layer.buffer": "internal.b", + "param": "p", + }, + ) diff --git a/tests/modules/transformer/transformer_stack_test.py b/tests/modules/transformer/transformer_stack_test.py index 0481a407937..cf42f6c0f6d 100644 --- a/tests/modules/transformer/transformer_stack_test.py +++ b/tests/modules/transformer/transformer_stack_test.py @@ -1,20 +1,12 @@ import copy + import torch import pytest from allennlp.common import Params from allennlp.common import cached_transformers - -from allennlp.common.testing import assert_equal_parameters from allennlp.modules.transformer import TransformerStack, TransformerLayer -from allennlp.common.testing import AllenNlpTestCase -from transformers.models.bert.configuration_bert import BertConfig -from transformers.models.bert.modeling_bert import BertEncoder -from transformers.models.roberta.configuration_roberta import RobertaConfig -from transformers.models.roberta.modeling_roberta import RobertaEncoder -from transformers.models.electra.configuration_electra import ElectraConfig -from transformers.models.electra.modeling_electra import ElectraEncoder PARAMS_DICT = { "num_hidden_layers": 3, @@ -26,208 +18,93 @@ "activation": "relu", } - -def get_modules(params_dict): - modules = {} - params = copy.deepcopy(params_dict) - params["attention_probs_dropout_prob"] = params.pop("attention_dropout") - params["hidden_dropout_prob"] = params.pop("hidden_dropout") - - torch.manual_seed(1234) - hf_module = BertEncoder(BertConfig(**params)) - modules["bert"] = hf_module - - torch.manual_seed(1234) - hf_module = RobertaEncoder(RobertaConfig(**params)) - modules["roberta"] = hf_module - - torch.manual_seed(1234) - hf_module = ElectraEncoder(ElectraConfig(**params)) - modules["electra"] = hf_module - - return modules +SEED = 1234 -class TestTransformerStack(AllenNlpTestCase): - def setup_method(self): - super().setup_method() +@pytest.fixture +def params(): + return Params(copy.deepcopy(PARAMS_DICT)) - self.params_dict = { - "num_hidden_layers": 3, - "hidden_size": 6, - "intermediate_size": 3, - "num_attention_heads": 2, - "attention_dropout": 0.1, - "hidden_dropout": 0.2, - "activation": "relu", - } - params = Params(copy.deepcopy(self.params_dict)) +def test_transformer_stack_from_params(params): + torch.manual_seed(SEED) + transformer_stack = TransformerStack.from_params(params) - self.transformer_stack = TransformerStack.from_params(params) + # Make sure we have the right number of modules. + modules = dict(transformer_stack.named_modules()) + assert len(modules["layers"]) == PARAMS_DICT["num_hidden_layers"] - self.pretrained_name = "bert-base-uncased" + hidden_states = torch.randn(2, 3, 6) + attention_mask = torch.tensor([[0, 1, 0], [1, 1, 0]]) - self.pretrained = cached_transformers.get(self.pretrained_name, False) + # Make sure forward pass can run. + torch.manual_seed(SEED) + output = transformer_stack.forward(hidden_states, attention_mask=attention_mask) - def test_can_construct_from_params(self): - - modules = dict(self.transformer_stack.named_modules()) - assert len(modules["layers"]) == self.params_dict["num_hidden_layers"] - - def test_forward_runs(self): - self.transformer_stack.forward(torch.randn(2, 3, 6), attention_mask=torch.randn(2, 3)) - - with pytest.raises(AssertionError): - self.transformer_stack.forward( - torch.randn(2, 3, 6), - attention_mask=torch.randn(2, 3), - encoder_hidden_states=torch.randn(2, 3, 6), - ) - - def test_layer_same_as_params(self): - params = copy.deepcopy(self.params_dict) - num_hidden_layers = params.pop("num_hidden_layers") - # params = Params(params) - - torch.manual_seed(1234) - transformer_layer = TransformerLayer(**params) - transformer_stack_from_layer = TransformerStack(num_hidden_layers, transformer_layer) - torch.manual_seed(1234) - transformer_stack_from_params = TransformerStack(num_hidden_layers, **params) + # Make sure we get the same results when instantiating from a single layer. + torch.manual_seed(SEED) + layer_params = copy.deepcopy(PARAMS_DICT) + num_hidden_layers = layer_params.pop("num_hidden_layers") + transformer_layer = TransformerLayer(**layer_params) # type: ignore[arg-type] + transformer_stack_from_layer = TransformerStack( + num_hidden_layers, transformer_layer # type: ignore[arg-type] + ) - hidden_states = torch.randn(2, 3, 6) - attention_mask = torch.tensor([[0, 1, 0], [1, 1, 0]]) + torch.manual_seed(SEED) + from_layer_output = transformer_stack_from_layer.forward( + hidden_states, attention_mask=attention_mask + ) - transformer_stack_from_layer.eval() - transformer_stack_from_params.eval() + assert torch.allclose(from_layer_output[0], output[0]) - torch.manual_seed(1234) - layer_output = transformer_stack_from_layer.forward( - hidden_states, attention_mask=attention_mask + # Make sure forward pass raises with bad input. + with pytest.raises(AssertionError): + transformer_stack.forward( + torch.randn(2, 3, 6), + attention_mask=torch.randn(2, 3), + encoder_hidden_states=torch.randn(2, 3, 6), ) - torch.manual_seed(1234) - params_output = transformer_stack_from_params.forward( - hidden_states, attention_mask=attention_mask - ) - assert torch.allclose(layer_output[0], params_output[0]) +def test_transformer_stack_with_cross_attention(params): + params["add_cross_attention"] = True - def test_cross_attention(self): - params = copy.deepcopy(self.params_dict) - params["add_cross_attention"] = True + transformer_stack = TransformerStack.from_params(params).eval() + modules = dict(transformer_stack.named_modules()) - params = Params(params) + assert hasattr(modules["layers.0"], "cross_attention") - transformer_stack = TransformerStack.from_params(params) - modules = dict(transformer_stack.named_modules()) + attention_mask = torch.tensor([[0, 1, 0], [1, 1, 0]]) + transformer_stack.forward( + torch.randn(2, 3, 6), + attention_mask=attention_mask, + encoder_hidden_states=torch.randn(2, 3, 6), + ) - assert hasattr(modules["layers.0"], "cross_attention") - attention_mask = torch.tensor([[0, 1, 0], [1, 1, 0]]) - transformer_stack.forward( - torch.randn(2, 3, 6), - attention_mask=attention_mask, - encoder_hidden_states=torch.randn(2, 3, 6), - ) +@pytest.mark.parametrize("pretrained_model_name", ["epwalsh/bert-xsmall-dummy", "bert-base-cased"]) +def test_loading_from_pretrained(pretrained_model_name): + transformer_stack = TransformerStack.from_pretrained_module(pretrained_model_name).eval() + pretrained_module = cached_transformers.get(pretrained_model_name, True).encoder.eval() - transformer_stack_new = TransformerStack.from_pretrained_module( - transformer_stack, source="allennlp" - ) + batch_size = 2 + seq_length = 15 + hidden_size = transformer_stack.layers[0]._hidden_size - new_modules = dict(transformer_stack_new.named_modules()) - assert hasattr(new_modules["layers.0"], "cross_attention") - - def test_loading_from_pretrained_weights(self): - pretrained_module = self.pretrained.encoder - module = TransformerStack.from_pretrained_module(pretrained_module) - mapping = { - val: key - for key, val in module._construct_default_mapping( - pretrained_module, "huggingface", {} - ).items() - } - assert_equal_parameters(pretrained_module, module, mapping) - - def test_loading_partial_pretrained_weights(self): - - kwargs = TransformerStack._get_input_arguments(self.pretrained.encoder) - # The pretrained module has 12 bert layers, while the instance will have only 3. - kwargs["num_hidden_layers"] = 3 - transformer_stack = TransformerStack(**kwargs) - transformer_stack._load_from_pretrained_module(self.pretrained.encoder) - mapping = { - val: key - for key, val in transformer_stack._construct_default_mapping( - self.pretrained.encoder, "huggingface", {} - ).items() - } - assert_equal_parameters( - self.pretrained.encoder, - transformer_stack, - mapping, - ) + hidden_states = torch.randn(batch_size, seq_length, hidden_size) + attention_mask = torch.randint(0, 2, (batch_size, seq_length)) + attention_mask_hf = attention_mask[:, None, None, :] + attention_mask_hf = (1.0 - attention_mask_hf) * -10e5 - @pytest.mark.skip("Takes up too much memory") - @pytest.mark.parametrize("module_name, hf_module", get_modules(PARAMS_DICT).items()) - def test_forward_against_huggingface_outputs(self, module_name, hf_module): - hidden_states = torch.randn(2, 3, 6) - attention_mask = torch.tensor([[0, 1, 0], [1, 1, 0]]) + torch.manual_seed(SEED) + output = transformer_stack(hidden_states, attention_mask=attention_mask) - stack = TransformerStack.from_pretrained_module(hf_module) + torch.manual_seed(SEED) + hf_output = pretrained_module(hidden_states, attention_mask=attention_mask_hf) - torch.manual_seed(1234) - output = stack.forward(hidden_states, attention_mask=attention_mask) - # We do this because bert, roberta, electra process the attention_mask at the model level. - attention_mask_hf = (attention_mask == 0).view((2, 1, 1, 3)).expand(2, 2, 3, 3) * -10e5 - torch.manual_seed(1234) - hf_output = hf_module.forward(hidden_states, attention_mask=attention_mask_hf) + assert torch.allclose(output[0], hf_output[0]) - assert torch.allclose(output[0], hf_output[0]) - @pytest.mark.parametrize( - "pretrained_name", - [ - "bert-base-uncased", - ], - ) - def test_loading_from_pretrained_weights_using_model_name(self, pretrained_name): - - torch.manual_seed(1234) - pretrained = cached_transformers.get(pretrained_name, False) - - if "distilbert" in pretrained_name: - pretrained_module = pretrained.transformer - else: - pretrained_module = pretrained.encoder - - torch.manual_seed(1234) - module = TransformerStack.from_pretrained_module(pretrained_name) - mapping = { - val: key - for key, val in module._construct_default_mapping( - pretrained_module, "huggingface", {} - ).items() - } - assert_equal_parameters(pretrained_module, module, mapping=mapping) - - batch_size = 1 - seq_len = 768 - dim = dict(module.named_modules())["layers.0.attention.self.query"].in_features - hidden_states = torch.randn(batch_size, seq_len, dim) - attention_mask = torch.randint(0, 2, (batch_size, seq_len)) - mask_reshp = (batch_size, 1, 1, dim) - attention_mask_hf = (attention_mask == 0).view(mask_reshp) - attention_mask_hf = attention_mask_hf.expand(batch_size, 12, seq_len, seq_len) * -10e5 - - # setting to eval mode to avoid non-deterministic dropout. - module = module.eval() - pretrained_module = pretrained_module.eval() - - torch.manual_seed(1234) - output = module.forward(hidden_states, attention_mask=attention_mask.squeeze())[0] - torch.manual_seed(1234) - hf_output = pretrained_module.forward(hidden_states, attention_mask=attention_mask_hf)[0] - - assert torch.allclose(output, hf_output) +def test_loading_partial_pretrained_weights(): + # The pretrained module has 12 bert layers, while the instance will have only 3. + TransformerStack.from_pretrained_module("bert-base-cased", num_hidden_layers=3, strict=False) From eefbef07c1d9e495ed85a6d05c1849ef90980c27 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Tue, 11 May 2021 15:19:48 -0700 Subject: [PATCH 04/23] more fixes --- .../modules/transformer/transformer_layer.py | 8 +- .../modules/transformer/transformer_module.py | 8 +- .../transformer/transformer_layer_test.py | 517 +++++++----------- 3 files changed, 218 insertions(+), 315 deletions(-) diff --git a/allennlp/modules/transformer/transformer_layer.py b/allennlp/modules/transformer/transformer_layer.py index ec19cd8f910..bb25262e3b4 100644 --- a/allennlp/modules/transformer/transformer_layer.py +++ b/allennlp/modules/transformer/transformer_layer.py @@ -29,7 +29,7 @@ class AttentionLayer(TransformerModule, FromParams): Dropout probability for the `OutputLayer`. """ - _relevant_module = "encoder.layers.0.attention" + _relevant_module = "encoder.layer.0.attention" _huggingface_mapping = {"layer": "layers"} def __init__( @@ -83,7 +83,7 @@ def _from_config(cls, config: "PretrainedConfig", **kwargs): final_kwargs["hidden_size"] = config.hidden_size final_kwargs["num_attention_heads"] = config.num_attention_heads - final_kwargs["attention_dropout"] = config.attention_probs_dropout_drop + final_kwargs["attention_dropout"] = config.attention_probs_dropout_prob final_kwargs["hidden_dropout"] = config.hidden_dropout_prob final_kwargs.update(**kwargs) @@ -112,7 +112,7 @@ class TransformerLayer(TransformerModule, FromParams): This is helpful when using the layer in a decoder. """ - _relevant_module = "encoder.layers.0" + _relevant_module = "encoder.layer.0" _huggingface_mapping = { "layer": "layers", "intermediate_act_fn": "act_fn", @@ -214,7 +214,7 @@ def _from_config(cls, config: "PretrainedConfig", **kwargs): final_kwargs = {} final_kwargs["hidden_size"] = config.hidden_size final_kwargs["num_attention_heads"] = config.num_attention_heads - final_kwargs["attention_dropout"] = config.attention_probs_dropout_drop + final_kwargs["attention_dropout"] = config.attention_probs_dropout_prob final_kwargs["hidden_dropout"] = config.hidden_dropout_prob final_kwargs["intermediate_size"] = config.intermediate_size final_kwargs["activation"] = config.hidden_act diff --git a/allennlp/modules/transformer/transformer_module.py b/allennlp/modules/transformer/transformer_module.py index 65e7a507a32..866d36c5f1c 100644 --- a/allennlp/modules/transformer/transformer_module.py +++ b/allennlp/modules/transformer/transformer_module.py @@ -178,7 +178,7 @@ def _get_pretrained_state_dict( state_dict = read_state_dict(weights_path) # Keep just the relevant_module, remove everything else. - state_dict = cls._get_relevant_submodule_state(state_dict) + state_dict = cls._get_relevant_submodule_state(state_dict, relevant_module=relevant_module) return state_dict @@ -275,9 +275,7 @@ def from_pretrained_module( ) # Now map keys from the HuggingFace state_dict to the corresponding keys from # this class. This is called recursively on each submodule of the current module. - state_dict = TransformerModule._get_mapped_state_dict( - model, pretrained_state_dict, mapping=mapping - ) + state_dict = model._get_mapped_state_dict(pretrained_state_dict, mapping=mapping) if not is_distributed() or loading_strategy == DistributedLoadingStrategy.FREE_FOR_ALL: assert state_dict is not None @@ -318,7 +316,7 @@ def _get_mapped_state_dict( } # Recursively call this function from the submodule to map this part # of the state_dict. - module_state_dict = TransformerModule._get_mapped_state_dict(submodule, module_state_dict) + module_state_dict = _get_mapped_state_dict(submodule, module_state_dict) # And then update the full state_dict. for key, value in module_state_dict.items(): state_dict[name + "." + key] = value diff --git a/tests/modules/transformer/transformer_layer_test.py b/tests/modules/transformer/transformer_layer_test.py index 1ecf183eace..10ab837a79c 100644 --- a/tests/modules/transformer/transformer_layer_test.py +++ b/tests/modules/transformer/transformer_layer_test.py @@ -1,13 +1,7 @@ import copy + import torch import pytest - -from allennlp.common import Params -from allennlp.common import cached_transformers -from allennlp.common.testing import assert_equal_parameters -from allennlp.modules.transformer import AttentionLayer, TransformerLayer -from allennlp.common.testing import AllenNlpTestCase - from transformers.models.bert.configuration_bert import BertConfig from transformers.models.bert.modeling_bert import BertAttention, BertLayer from transformers.models.roberta.configuration_roberta import RobertaConfig @@ -15,6 +9,11 @@ from transformers.models.electra.configuration_electra import ElectraConfig from transformers.models.electra.modeling_electra import ElectraAttention, ElectraLayer +from allennlp.common import Params +from allennlp.common import cached_transformers +from allennlp.modules.transformer import AttentionLayer, TransformerLayer + + ATTENTION_PARAMS_DICT = { "hidden_size": 6, "num_attention_heads": 2, @@ -23,141 +22,113 @@ } -def get_attention_modules(params_dict): - modules = {} - params = copy.deepcopy(params_dict) +@pytest.fixture +def attention_params(): + return Params(copy.deepcopy(ATTENTION_PARAMS_DICT)) + + +def test_attention(attention_params): + attention_layer = AttentionLayer.from_params(attention_params.duplicate()).eval() + + assert attention_layer.self.num_attention_heads == attention_params["num_attention_heads"] + assert attention_layer.self.attention_head_size == int( + attention_params["hidden_size"] / attention_params["num_attention_heads"] + ) + assert ( + attention_layer.self.all_head_size + == attention_params["num_attention_heads"] * attention_layer.self.attention_head_size + ) + assert attention_layer.self.query.in_features == attention_params["hidden_size"] + assert attention_layer.self.key.in_features == attention_params["hidden_size"] + assert attention_layer.self.value.in_features == attention_params["hidden_size"] + assert attention_layer.self.dropout.p == attention_params["attention_dropout"] + + assert attention_layer.output.dense.in_features == attention_params["hidden_size"] + assert attention_layer.output.dense.out_features == attention_params["hidden_size"] + assert attention_layer.output.layer_norm.normalized_shape[0] == attention_params["hidden_size"] + assert attention_layer.output.dropout.p == attention_params["hidden_dropout"] + + attention_mask = torch.tensor([[0, 1, 0], [1, 1, 0]]) + attention_layer(torch.randn(2, 3, 6), attention_mask=attention_mask) + + +def get_attention_modules(): + params = copy.deepcopy(ATTENTION_PARAMS_DICT) params["attention_probs_dropout_prob"] = params.pop("attention_dropout") params["hidden_dropout_prob"] = params.pop("hidden_dropout") torch.manual_seed(1234) - hf_module = BertAttention(BertConfig(**params)) - modules["bert"] = hf_module + yield "bert", BertAttention(BertConfig(**params)).eval() torch.manual_seed(1234) - hf_module = RobertaAttention(RobertaConfig(**params)) - modules["roberta"] = hf_module + yield "roberta", RobertaAttention(RobertaConfig(**params)).eval() torch.manual_seed(1234) - hf_module = ElectraAttention(ElectraConfig(**params)) - modules["electra"] = hf_module + yield "electra", ElectraAttention(ElectraConfig(**params)).eval() - return modules +@pytest.mark.parametrize("module_name, hf_module", get_attention_modules()) +def test_attention_matches_huggingface(attention_params, module_name, hf_module): + hidden_states = torch.randn(2, 3, 6) + attention_mask = torch.tensor([[0, 1, 0], [1, 1, 0]]) -class TestAttentionLayer(AllenNlpTestCase): - def setup_method(self): - super().setup_method() + attention = AttentionLayer.from_params(attention_params).eval() + state_dict = attention._get_mapped_state_dict(hf_module.state_dict()) + attention.load_state_dict(state_dict) - self.params_dict = { - "hidden_size": 6, - "num_attention_heads": 2, - "attention_dropout": 0.1, - "hidden_dropout": 0.2, - } + torch.manual_seed(1234) + output = attention(hidden_states, attention_mask=attention_mask) + # We do this because bert, roberta, electra process the attention_mask at the model level. + attention_mask_hf = (attention_mask == 0).view((2, 1, 1, 3)).expand(2, 2, 3, 3) * -10e5 - params = Params(copy.deepcopy(self.params_dict)) + torch.manual_seed(1234) + hf_output = hf_module(hidden_states, attention_mask=attention_mask_hf) - self.attention_layer = AttentionLayer.from_params(params) + assert torch.allclose(output[0], hf_output[0]) - def test_can_construct_from_params(self): - attention_layer = self.attention_layer +@pytest.mark.parametrize( + "pretrained_name, relevant_top_level_module", + [ + ("bert-base-cased", "bert"), + ("epwalsh/bert-xsmall-dummy", None), + ], +) +def test_attention_from_pretrained(pretrained_name, relevant_top_level_module): + torch.manual_seed(1234) + pretrained = cached_transformers.get(pretrained_name, False).eval() - assert attention_layer.self.num_attention_heads == self.params_dict["num_attention_heads"] - assert attention_layer.self.attention_head_size == int( - self.params_dict["hidden_size"] / self.params_dict["num_attention_heads"] - ) - assert ( - attention_layer.self.all_head_size - == self.params_dict["num_attention_heads"] * attention_layer.self.attention_head_size - ) - assert attention_layer.self.query.in_features == self.params_dict["hidden_size"] - assert attention_layer.self.key.in_features == self.params_dict["hidden_size"] - assert attention_layer.self.value.in_features == self.params_dict["hidden_size"] - assert attention_layer.self.dropout.p == self.params_dict["attention_dropout"] - - assert attention_layer.output.dense.in_features == self.params_dict["hidden_size"] - assert attention_layer.output.dense.out_features == self.params_dict["hidden_size"] - assert ( - attention_layer.output.layer_norm.normalized_shape[0] == self.params_dict["hidden_size"] - ) - assert attention_layer.output.dropout.p == self.params_dict["hidden_dropout"] + if "distilbert" in pretrained_name: + encoder = pretrained.transformer + else: + encoder = pretrained.encoder + # Hacky way to get a bert layer. + pretrained_module = list(encoder.layer.modules())[1].attention - def test_forward_runs(self): - attention_mask = torch.tensor([[0, 1, 0], [1, 1, 0]]) - self.attention_layer.forward(torch.randn(2, 3, 6), attention_mask=attention_mask) + torch.manual_seed(1234) + module = AttentionLayer.from_pretrained_module( + pretrained_name, + relevant_module=None + if relevant_top_level_module is None + else f"{relevant_top_level_module}.encoder.layer.0.attention", + ).eval() + + batch_size = 2 + seq_length = 15 + hidden_size = module.self.query.in_features + + hidden_states = torch.randn(batch_size, seq_length, hidden_size) + attention_mask = torch.randint(0, 2, (batch_size, seq_length)) + attention_mask_hf = attention_mask[:, None, None, :] + attention_mask_hf = (1.0 - attention_mask_hf) * -10e5 - @pytest.mark.parametrize( - "module_name, hf_module", get_attention_modules(ATTENTION_PARAMS_DICT).items() - ) - def test_forward_against_huggingface_outputs(self, module_name, hf_module): - hidden_states = torch.randn(2, 3, 6) - attention_mask = torch.tensor([[0, 1, 0], [1, 1, 0]]) - - attention = AttentionLayer.from_pretrained_module(hf_module) - - torch.manual_seed(1234) - output = attention.forward(hidden_states, attention_mask=attention_mask) - # We do this because bert, roberta, electra process the attention_mask at the model level. - attention_mask_hf = (attention_mask == 0).view((2, 1, 1, 3)).expand(2, 2, 3, 3) * -10e5 - torch.manual_seed(1234) - hf_output = hf_module.forward(hidden_states, attention_mask=attention_mask_hf) - - assert torch.allclose(output[0], hf_output[0]) - - @pytest.mark.parametrize( - "pretrained_name", - [ - "bert-base-uncased", - "roberta-base", - ], - ) - def test_loading_from_pretrained_weights_using_model_name(self, pretrained_name): - - torch.manual_seed(1234) - pretrained = cached_transformers.get(pretrained_name, False) - - if "distilbert" in pretrained_name: - encoder = pretrained.transformer - else: - encoder = pretrained.encoder - # Hacky way to get a bert layer. - for i, pretrained_module in enumerate(encoder.layer.modules()): - if i == 1: - break - - pretrained_module = pretrained_module.attention - - torch.manual_seed(1234) - module = AttentionLayer.from_pretrained_module(pretrained_name) - mapping = { - val: key - for key, val in module._construct_default_mapping( - pretrained_module, "huggingface", {} - ).items() - } - assert_equal_parameters(pretrained_module, module, mapping=mapping) - - batch_size = 2 - seq_len = 768 - dim = module.self.query.in_features - hidden_states = torch.randn(batch_size, seq_len, dim) - attention_mask = torch.randint(0, 2, (batch_size, seq_len)) - mask_reshp = (batch_size, 1, 1, dim) - attention_mask_hf = (attention_mask == 0).view(mask_reshp).expand( - batch_size, 12, seq_len, seq_len - ) * -10e5 - - # setting to eval mode to avoid non-deterministic dropout. - module = module.eval() - pretrained_module = pretrained_module.eval() - - torch.manual_seed(1234) - output = module.forward(hidden_states, attention_mask=attention_mask.squeeze())[0] - torch.manual_seed(1234) - hf_output = pretrained_module.forward(hidden_states, attention_mask=attention_mask_hf)[0] - - assert torch.allclose(output, hf_output, atol=1e-04) + torch.manual_seed(1234) + output = module(hidden_states, attention_mask=attention_mask.squeeze())[0] + + torch.manual_seed(1234) + hf_output = pretrained_module(hidden_states, attention_mask=attention_mask_hf)[0] + + assert torch.allclose(output, hf_output, atol=1e-04) LAYER_PARAMS_DICT = { @@ -170,213 +141,147 @@ def test_loading_from_pretrained_weights_using_model_name(self, pretrained_name) } -def get_layer_modules(params_dict): - modules = {} - params = copy.deepcopy(params_dict) - params["attention_probs_dropout_prob"] = params.pop("attention_dropout") - params["hidden_dropout_prob"] = params.pop("hidden_dropout") +@pytest.fixture +def layer_params(): + return Params(copy.deepcopy(LAYER_PARAMS_DICT)) - # bert, roberta, electra, layoutlm self attentions have the same code. - torch.manual_seed(1234) - hf_module = BertLayer(BertConfig(**params)) - modules["bert"] = hf_module +def test_layer(layer_params): + transformer_layer = TransformerLayer.from_params(layer_params.duplicate()).eval() - torch.manual_seed(1234) - hf_module = RobertaLayer(RobertaConfig(**params)) - modules["roberta"] = hf_module + assert ( + transformer_layer.attention.self.num_attention_heads == layer_params["num_attention_heads"] + ) + assert transformer_layer.attention.self.attention_head_size == int( + layer_params["hidden_size"] / layer_params["num_attention_heads"] + ) + assert ( + transformer_layer.attention.self.all_head_size + == layer_params["num_attention_heads"] + * transformer_layer.attention.self.attention_head_size + ) + assert transformer_layer.attention.self.query.in_features == layer_params["hidden_size"] + assert transformer_layer.attention.self.key.in_features == layer_params["hidden_size"] + assert transformer_layer.attention.self.value.in_features == layer_params["hidden_size"] + assert transformer_layer.attention.self.dropout.p == layer_params["attention_dropout"] + + assert transformer_layer.attention.output.dense.in_features == layer_params["hidden_size"] + assert transformer_layer.attention.output.dense.out_features == layer_params["hidden_size"] + assert ( + transformer_layer.attention.output.layer_norm.normalized_shape[0] + == layer_params["hidden_size"] + ) + assert transformer_layer.attention.output.dropout.p == layer_params["hidden_dropout"] - torch.manual_seed(1234) - hf_module = ElectraLayer(ElectraConfig(**params)) - modules["electra"] = hf_module + assert transformer_layer.intermediate.dense.in_features == layer_params["hidden_size"] + assert transformer_layer.intermediate.dense.out_features == layer_params["intermediate_size"] - return modules + assert transformer_layer.output.dense.in_features == layer_params["intermediate_size"] + assert transformer_layer.output.dense.out_features == layer_params["hidden_size"] + assert transformer_layer.output.layer_norm.normalized_shape[0] == layer_params["hidden_size"] -class TestTransformerLayer(AllenNlpTestCase): - def setup_method(self): - super().setup_method() + assert transformer_layer.output.dropout.p == layer_params["hidden_dropout"] - self.params_dict = { - "hidden_size": 6, - "intermediate_size": 3, - "num_attention_heads": 2, - "attention_dropout": 0.1, - "hidden_dropout": 0.2, - "activation": "relu", - } + attention_mask = torch.tensor([[0, 1, 0], [1, 1, 0]]) + transformer_layer(torch.randn(2, 3, 6), attention_mask=attention_mask) - params = Params(copy.deepcopy(self.params_dict)) + with pytest.raises(AssertionError): + transformer_layer( + torch.randn(2, 3, 6), + attention_mask=attention_mask, + encoder_hidden_states=torch.randn(2, 3, 6), + ) - self.transformer_layer = TransformerLayer.from_params(params) - self.pretrained_name = "bert-base-uncased" - self.pretrained = cached_transformers.get(self.pretrained_name, False) +def test_layer_with_cross_attention(layer_params): + layer_params["add_cross_attention"] = True - def test_can_construct_from_params(self): + transformer_layer = TransformerLayer.from_params(layer_params).eval() + assert hasattr(transformer_layer, "cross_attention") - transformer_layer = self.transformer_layer + attention_mask = torch.tensor([[0, 1, 0], [1, 1, 0]]) + transformer_layer( + torch.randn(2, 3, 6), + attention_mask=attention_mask, + encoder_hidden_states=torch.randn(2, 3, 6), + ) - assert ( - transformer_layer.attention.self.num_attention_heads - == self.params_dict["num_attention_heads"] - ) - assert transformer_layer.attention.self.attention_head_size == int( - self.params_dict["hidden_size"] / self.params_dict["num_attention_heads"] - ) - assert ( - transformer_layer.attention.self.all_head_size - == self.params_dict["num_attention_heads"] - * transformer_layer.attention.self.attention_head_size - ) - assert transformer_layer.attention.self.query.in_features == self.params_dict["hidden_size"] - assert transformer_layer.attention.self.key.in_features == self.params_dict["hidden_size"] - assert transformer_layer.attention.self.value.in_features == self.params_dict["hidden_size"] - assert transformer_layer.attention.self.dropout.p == self.params_dict["attention_dropout"] - assert ( - transformer_layer.attention.output.dense.in_features == self.params_dict["hidden_size"] - ) - assert ( - transformer_layer.attention.output.dense.out_features == self.params_dict["hidden_size"] - ) - assert ( - transformer_layer.attention.output.layer_norm.normalized_shape[0] - == self.params_dict["hidden_size"] - ) - assert transformer_layer.attention.output.dropout.p == self.params_dict["hidden_dropout"] +def get_layer_modules(): + params = copy.deepcopy(LAYER_PARAMS_DICT) + params["attention_probs_dropout_prob"] = params.pop("attention_dropout") + params["hidden_dropout_prob"] = params.pop("hidden_dropout") + params["hidden_act"] = params.pop("activation") - assert transformer_layer.intermediate.dense.in_features == self.params_dict["hidden_size"] - assert ( - transformer_layer.intermediate.dense.out_features - == self.params_dict["intermediate_size"] - ) + torch.manual_seed(1234) + yield "bert", BertLayer(BertConfig(**params)).eval() - assert transformer_layer.output.dense.in_features == self.params_dict["intermediate_size"] - assert transformer_layer.output.dense.out_features == self.params_dict["hidden_size"] + torch.manual_seed(1234) + yield "roberta", RobertaLayer(RobertaConfig(**params)).eval() - assert ( - transformer_layer.output.layer_norm.normalized_shape[0] - == self.params_dict["hidden_size"] - ) + torch.manual_seed(1234) + yield "electra", ElectraLayer(ElectraConfig(**params)).eval() - assert transformer_layer.output.dropout.p == self.params_dict["hidden_dropout"] - def test_forward_runs(self): - attention_mask = torch.tensor([[0, 1, 0], [1, 1, 0]]) - self.transformer_layer.forward(torch.randn(2, 3, 6), attention_mask=attention_mask) +@pytest.mark.parametrize("module_name, hf_module", get_layer_modules()) +def test_layer_matches_huggingface(layer_params, module_name, hf_module): + layer = TransformerLayer.from_params(layer_params).eval() + state_dict = layer._get_mapped_state_dict(hf_module.state_dict()) + layer.load_state_dict(state_dict) - with pytest.raises(AssertionError): - self.transformer_layer.forward( - torch.randn(2, 3, 6), - attention_mask=attention_mask, - encoder_hidden_states=torch.randn(2, 3, 6), - ) + hidden_states = torch.randn(2, 3, 6) + attention_mask = torch.tensor([[0, 1, 0], [1, 1, 0]]) - def test_cross_attention(self): - params = copy.deepcopy(self.params_dict) - params["add_cross_attention"] = True + torch.manual_seed(1234) + output = layer(hidden_states, attention_mask=attention_mask) + # We do this because bert, roberta, electra process the attention_mask at the model level. + attention_mask_hf = (attention_mask == 0).view((2, 1, 1, 3)).expand(2, 2, 3, 3) * -10e5 + torch.manual_seed(1234) + hf_output = hf_module(hidden_states, attention_mask=attention_mask_hf) - params = Params(params) + assert torch.allclose(output[0], hf_output[0]) - transformer_layer = TransformerLayer.from_params(params) - assert hasattr(transformer_layer, "cross_attention") - attention_mask = torch.tensor([[0, 1, 0], [1, 1, 0]]) - transformer_layer.forward( - torch.randn(2, 3, 6), - attention_mask=attention_mask, - encoder_hidden_states=torch.randn(2, 3, 6), - ) +@pytest.mark.parametrize( + "pretrained_name, relevant_top_level_module", + [ + ("bert-base-cased", "bert"), + ("epwalsh/bert-xsmall-dummy", None), + ], +) +def test_layer_from_pretrained(pretrained_name, relevant_top_level_module): + torch.manual_seed(1234) + pretrained = cached_transformers.get(pretrained_name, False).eval() - transformer_layer_new = TransformerLayer.from_pretrained_module( - transformer_layer, source="allennlp" - ) + if "distilbert" in pretrained_name: + encoder = pretrained.transformer + else: + encoder = pretrained.encoder + # Hacky way to get a bert layer. + pretrained_module = list(encoder.layer.modules())[1] - assert hasattr(transformer_layer_new, "cross_attention") - - def test_loading_from_pretrained_weights(self): - - # Hacky way to get a bert layer. - for i, pretrained_module in enumerate(self.pretrained.encoder.layer.modules()): - if i == 1: - break - - module = TransformerLayer.from_pretrained_module(pretrained_module) - mapping = { - val: key - for key, val in module._construct_default_mapping( - pretrained_module, "huggingface", {} - ).items() - } - assert_equal_parameters(pretrained_module, module, mapping=mapping) - - @pytest.mark.parametrize("module_name, hf_module", get_layer_modules(LAYER_PARAMS_DICT).items()) - def test_forward_against_huggingface_outputs(self, module_name, hf_module): - hidden_states = torch.randn(2, 3, 6) - attention_mask = torch.tensor([[0, 1, 0], [1, 1, 0]]) - - layer = TransformerLayer.from_pretrained_module(hf_module) - - torch.manual_seed(1234) - output = layer.forward(hidden_states, attention_mask=attention_mask) - # We do this because bert, roberta, electra process the attention_mask at the model level. - attention_mask_hf = (attention_mask == 0).view((2, 1, 1, 3)).expand(2, 2, 3, 3) * -10e5 - torch.manual_seed(1234) - hf_output = hf_module.forward(hidden_states, attention_mask=attention_mask_hf) - - assert torch.allclose(output[0], hf_output[0]) - - @pytest.mark.parametrize( - "pretrained_name", - [ - "bert-base-uncased", - "roberta-base", - ], - ) - def test_loading_from_pretrained_weights_using_model_name(self, pretrained_name): - - torch.manual_seed(1234) - pretrained = cached_transformers.get(pretrained_name, False) - - if "distilbert" in pretrained_name: - encoder = pretrained.transformer - else: - encoder = pretrained.encoder - # Hacky way to get a bert layer. - for i, pretrained_module in enumerate(encoder.layer.modules()): - if i == 1: - break - - pretrained_module = pretrained_module - - torch.manual_seed(1234) - module = TransformerLayer.from_pretrained_module(pretrained_name) - mapping = { - val: key - for key, val in module._construct_default_mapping( - pretrained_module, "huggingface", {} - ).items() - } - assert_equal_parameters(pretrained_module, module, mapping=mapping) - - batch_size = 2 - seq_len = 768 - dim = module.attention.self.query.in_features - hidden_states = torch.randn(batch_size, seq_len, dim) - attention_mask = torch.randint(0, 2, (batch_size, seq_len)) - mask_reshp = (batch_size, 1, 1, dim) - attention_mask_hf = (attention_mask == 0).view(mask_reshp).expand( - batch_size, 12, seq_len, seq_len - ) * -10e5 - - # setting to eval mode to avoid non-deterministic dropout. - module = module.eval() - pretrained_module = pretrained_module.eval() - - torch.manual_seed(1234) - output = module.forward(hidden_states, attention_mask=attention_mask.squeeze())[0] - torch.manual_seed(1234) - hf_output = pretrained_module.forward(hidden_states, attention_mask=attention_mask_hf)[0] - - assert torch.allclose(output, hf_output, atol=1e-04) + torch.manual_seed(1234) + module = TransformerLayer.from_pretrained_module( + pretrained_name, + relevant_module=None + if relevant_top_level_module is None + else f"{relevant_top_level_module}.encoder.layer.0", + ).eval() + + batch_size = 2 + seq_length = 15 + hidden_size = module.attention.self.query.in_features + + hidden_states = torch.randn(batch_size, seq_length, hidden_size) + attention_mask = torch.randint(0, 2, (batch_size, seq_length)) + attention_mask_hf = attention_mask[:, None, None, :] + attention_mask_hf = (1.0 - attention_mask_hf) * -10e5 + + torch.manual_seed(1234) + output = module(hidden_states, attention_mask=attention_mask.squeeze())[0] + + torch.manual_seed(1234) + hf_output = pretrained_module(hidden_states, attention_mask=attention_mask_hf)[0] + + assert torch.allclose(output, hf_output, atol=1e-04) From 96abee167785f9ecdbbc17dfcabba4a827f0658f Mon Sep 17 00:00:00 2001 From: epwalsh Date: Wed, 12 May 2021 10:32:00 -0700 Subject: [PATCH 05/23] fix embeddings --- allennlp/modules/transformer/__init__.py | 2 +- allennlp/modules/transformer/layer_norm.py | 7 + allennlp/modules/transformer/output_layer.py | 5 +- .../transformer/transformer_embeddings.py | 11 +- .../modules/transformer/transformer_module.py | 6 +- .../transformer_embeddings_test.py | 548 +++++++++--------- 6 files changed, 299 insertions(+), 280 deletions(-) create mode 100644 allennlp/modules/transformer/layer_norm.py diff --git a/allennlp/modules/transformer/__init__.py b/allennlp/modules/transformer/__init__.py index a5a64e45b84..e748cfa9989 100644 --- a/allennlp/modules/transformer/__init__.py +++ b/allennlp/modules/transformer/__init__.py @@ -123,8 +123,8 @@ def forward(self, token_ids: torch.LongTensor, mask: torch.BoolTensor): ``` """ +from allennlp.modules.transformer.layer_norm import LayerNorm from allennlp.modules.transformer.positional_encoding import SinusoidalPositionalEncoding - from allennlp.modules.transformer.transformer_module import ( TransformerModule, DistributedLoadingStrategy, diff --git a/allennlp/modules/transformer/layer_norm.py b/allennlp/modules/transformer/layer_norm.py new file mode 100644 index 00000000000..a7e40d2ad13 --- /dev/null +++ b/allennlp/modules/transformer/layer_norm.py @@ -0,0 +1,7 @@ +import torch + +from allennlp.modules.transformer.transformer_module import TransformerModule + + +class LayerNorm(torch.nn.LayerNorm, TransformerModule): + _huggingface_mapping = {"gamma": "weight", "beta": "bias"} diff --git a/allennlp/modules/transformer/output_layer.py b/allennlp/modules/transformer/output_layer.py index df79b1779ec..2f4b8afdede 100644 --- a/allennlp/modules/transformer/output_layer.py +++ b/allennlp/modules/transformer/output_layer.py @@ -3,10 +3,7 @@ from allennlp.common import FromParams from allennlp.modules.transformer.transformer_module import TransformerModule - - -class LayerNorm(torch.nn.LayerNorm, TransformerModule): - _huggingface_mapping = {"gamma": "weight", "beta": "bias"} +from allennlp.modules.transformer.layer_norm import LayerNorm class OutputLayer(TransformerModule, FromParams): diff --git a/allennlp/modules/transformer/transformer_embeddings.py b/allennlp/modules/transformer/transformer_embeddings.py index 29ab1f02b71..25004964e46 100644 --- a/allennlp/modules/transformer/transformer_embeddings.py +++ b/allennlp/modules/transformer/transformer_embeddings.py @@ -3,6 +3,7 @@ import torch from allennlp.common import FromParams +from allennlp.modules.transformer.layer_norm import LayerNorm from allennlp.modules.transformer.transformer_module import TransformerModule if TYPE_CHECKING: @@ -40,7 +41,7 @@ def __init__(self, embeddings: torch.nn.ModuleDict, embedding_size: int, dropout ) ) self.embeddings = embeddings - self.layer_norm = torch.nn.LayerNorm(embedding_size, eps=1e-12) + self.layer_norm = LayerNorm(embedding_size, eps=1e-12) self.dropout = torch.nn.Dropout(dropout) def forward(self, *inputs) -> torch.Tensor: @@ -187,7 +188,13 @@ def forward( # type: ignore def _from_config(cls, config: "PretrainedConfig", **kwargs): final_kwargs = {} final_kwargs["vocab_size"] = config.vocab_size - final_kwargs["embedding_size"] = config.hidden_size + # For Albert, the embedding size is different than the hidden size used + # in the model, so a linear transform is applied. + if hasattr(config, "embedding_size"): + final_kwargs["embedding_size"] = config.embedding_size + final_kwargs["output_size"] = config.hidden_size + else: + final_kwargs["embedding_size"] = config.hidden_size final_kwargs["pad_token_id"] = config.pad_token_id final_kwargs["max_position_embeddings"] = config.max_position_embeddings final_kwargs["type_vocab_size"] = config.type_vocab_size diff --git a/allennlp/modules/transformer/transformer_module.py b/allennlp/modules/transformer/transformer_module.py index 866d36c5f1c..d30050f987e 100644 --- a/allennlp/modules/transformer/transformer_module.py +++ b/allennlp/modules/transformer/transformer_module.py @@ -280,7 +280,11 @@ def from_pretrained_module( if not is_distributed() or loading_strategy == DistributedLoadingStrategy.FREE_FOR_ALL: assert state_dict is not None logger.info("Loading state_dict into module") - model.load_state_dict(state_dict, strict=strict) + missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=strict) + if missing_keys: + logger.warning("Missing keys from pretrained state dict: %s", missing_keys) + if unexpected_keys: + logger.warning("Unexpected keys in pretrained state dict: %s", unexpected_keys) else: # We're in distributed training. `state_dict` is `None` for all process groups # except the global primary. diff --git a/tests/modules/transformer/transformer_embeddings_test.py b/tests/modules/transformer/transformer_embeddings_test.py index d366f4732b4..3954db17a0d 100644 --- a/tests/modules/transformer/transformer_embeddings_test.py +++ b/tests/modules/transformer/transformer_embeddings_test.py @@ -1,23 +1,21 @@ -import pytest import copy + +import pytest import torch from torch.testing import assert_allclose - -from allennlp.common import Params, FromParams -from allennlp.common import cached_transformers - +from transformers import AutoModel from transformers.models.bert.configuration_bert import BertConfig from transformers.models.bert.modeling_bert import BertEmbeddings from transformers.models.albert.configuration_albert import AlbertConfig from transformers.models.albert.modeling_albert import AlbertEmbeddings -from allennlp.common.testing import assert_equal_parameters +from allennlp.common import Params, FromParams from allennlp.modules.transformer import ( TransformerEmbeddings, ImageFeatureEmbeddings, TransformerModule, ) -from allennlp.common.testing import AllenNlpTestCase + PARAMS_DICT = { "vocab_size": 20, @@ -29,9 +27,168 @@ } -def get_modules(params_dict): - modules = {} - params = copy.deepcopy(params_dict) +@pytest.fixture +def params_dict(): + return copy.deepcopy(PARAMS_DICT) + + +@pytest.fixture +def params(params_dict): + return Params(params_dict) + + +@pytest.fixture +def transformer_embeddings(params): + return TransformerEmbeddings.from_params(params.duplicate()) + + +def test_can_construct_from_params(params_dict, transformer_embeddings): + embeddings = transformer_embeddings.embeddings + assert embeddings.word_embeddings.num_embeddings == params_dict["vocab_size"] + assert embeddings.word_embeddings.embedding_dim == params_dict["embedding_size"] + assert embeddings.word_embeddings.padding_idx == params_dict["pad_token_id"] + + assert embeddings.position_embeddings.num_embeddings == params_dict["max_position_embeddings"] + assert embeddings.position_embeddings.embedding_dim == params_dict["embedding_size"] + + assert embeddings.token_type_embeddings.num_embeddings == params_dict["type_vocab_size"] + assert embeddings.token_type_embeddings.embedding_dim == params_dict["embedding_size"] + + assert transformer_embeddings.layer_norm.normalized_shape[0] == params_dict["embedding_size"] + + assert transformer_embeddings.dropout.p == params_dict["dropout"] + + +def test_sanity(): + class TextEmbeddings(TransformerModule, FromParams): + def __init__( + self, + vocab_size: int, + hidden_size: int, + pad_token_id: int, + max_position_embeddings: int, + type_vocab_size: int, + dropout: float, + ): + super().__init__() + self.word_embeddings = torch.nn.Embedding( + vocab_size, hidden_size, padding_idx=pad_token_id + ) + self.position_embeddings = torch.nn.Embedding(max_position_embeddings, hidden_size) + self.token_type_embeddings = torch.nn.Embedding(type_vocab_size, hidden_size) + + self.layer_norm = torch.nn.LayerNorm(hidden_size, eps=1e-12) + self.dropout = torch.nn.Dropout(dropout) + + def forward( + self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None + ): + if input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + seq_length = input_shape[1] + device = input_ids.device if input_ids is not None else inputs_embeds.device + if position_ids is None: + position_ids = torch.arange(seq_length, dtype=torch.long, device=device) + position_ids = position_ids.unsqueeze(0).expand(input_shape) + if token_type_ids is None: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + position_embeddings = self.position_embeddings(position_ids) + token_type_embeddings = self.token_type_embeddings(token_type_ids) + + embeddings = inputs_embeds + position_embeddings + token_type_embeddings + embeddings = self.layer_norm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + torch.manual_seed(23) + text = TextEmbeddings(10, 5, 2, 3, 7, 0.0) + torch.manual_seed(23) + transformer = TransformerEmbeddings(10, 5, 2, 3, 7, 0.0) + + input_ids = torch.tensor([[1, 2]]) + token_type_ids = torch.tensor([[1, 0]], dtype=torch.long) + position_ids = torch.tensor([[0, 1]]) + + text_output = text(input_ids, token_type_ids, position_ids) + transformer_output = transformer(input_ids, token_type_ids, position_ids) + + assert_allclose(text_output, transformer_output) + + +def test_forward_runs_with_inputs(transformer_embeddings): + input_ids = torch.tensor([[1, 2]]) + token_type_ids = torch.tensor([[1, 0]], dtype=torch.long) + position_ids = torch.tensor([[0, 1]]) + transformer_embeddings( + input_ids=input_ids, token_type_ids=token_type_ids, position_ids=position_ids + ) + + +def test_output_size(params): + input_ids = torch.tensor([[1, 2]]) + token_type_ids = torch.tensor([[1, 0]], dtype=torch.long) + position_ids = torch.tensor([[0, 1]]) + params["output_size"] = 7 + module = TransformerEmbeddings.from_params(params) + output = module(input_ids=input_ids, token_type_ids=token_type_ids, position_ids=position_ids) + + assert output.shape[-1] == 7 + + +def test_no_token_type_layer(params): + params["type_vocab_size"] = 0 + module = TransformerEmbeddings.from_params(params) + assert len(module.embeddings) == 2 + + +@pytest.mark.parametrize( + "pretrained_name, relevant_module", + [ + ("bert-base-cased", "bert.embeddings"), + ("epwalsh/bert-xsmall-dummy", None), + ], +) +def test_loading_from_pretrained_module(pretrained_name, relevant_module): + TransformerEmbeddings.from_pretrained_module(pretrained_name, relevant_module=relevant_module) + + +def test_loading_albert(): + """ + Albert is a special case because it includes a Linear layer in the encoder + that maps the embeddings to the encoder hidden size, but we include this linear + layer within our embedding layer. + """ + transformer_embedding = TransformerEmbeddings.from_pretrained_module( + "albert-base-v2", + mapping={ + "albert.embeddings.LayerNorm": "layer_norm", + "albert.embeddings.LayerNorm": "layer_norm", + "albert.embeddings.word_embeddings": "embeddings.word_embeddings", + "albert.embeddings.position_embeddings": "embeddings.position_embeddings", + "albert.embeddings.token_type_embeddings": "embeddings.token_type_embeddings", + "albert.encoder.embedding_hidden_mapping_in": "linear_transform", + }, + strict=False, + ) + albert = AutoModel.from_pretrained("albert-base-v2") + assert_allclose( + transformer_embedding.embeddings.word_embeddings.weight.data, + albert.embeddings.word_embeddings.weight.data, + ) + assert_allclose( + transformer_embedding.linear_transform.weight.data, + albert.encoder.embedding_hidden_mapping_in.weight.data, + ) + + +def get_modules(): + params = copy.deepcopy(PARAMS_DICT) params["hidden_dropout_prob"] = params.pop("dropout") params["hidden_size"] = params.pop("embedding_size") @@ -39,270 +196,117 @@ def get_modules(params_dict): # bert, roberta, electra self attentions have the same code. torch.manual_seed(1234) - hf_module = BertEmbeddings(BertConfig(**params)) - modules["bert"] = hf_module + yield "bert", BertEmbeddings(BertConfig(**params)) - albertparams = copy.deepcopy(params_dict) + albertparams = copy.deepcopy(PARAMS_DICT) albertparams["hidden_dropout_prob"] = albertparams.pop("dropout") torch.manual_seed(1234) - hf_module = AlbertEmbeddings(AlbertConfig(**albertparams)) - modules["albert"] = hf_module - - return modules - - -class TestTransformerEmbeddings(AllenNlpTestCase): - def setup_method(self): - super().setup_method() - - self.params_dict = {key: val for key, val in PARAMS_DICT.items()} - - params = Params(copy.deepcopy(self.params_dict)) - - self.transformer_embeddings = TransformerEmbeddings.from_params(params) - - def test_can_construct_from_params(self): - - transformer_embeddings = self.transformer_embeddings.embeddings - - assert ( - transformer_embeddings.word_embeddings.num_embeddings == self.params_dict["vocab_size"] - ) - assert ( - transformer_embeddings.word_embeddings.embedding_dim - == self.params_dict["embedding_size"] - ) - assert ( - transformer_embeddings.word_embeddings.padding_idx == self.params_dict["pad_token_id"] - ) - - assert ( - transformer_embeddings.position_embeddings.num_embeddings - == self.params_dict["max_position_embeddings"] - ) - assert ( - transformer_embeddings.position_embeddings.embedding_dim - == self.params_dict["embedding_size"] - ) - - assert ( - transformer_embeddings.token_type_embeddings.num_embeddings - == self.params_dict["type_vocab_size"] - ) - assert ( - transformer_embeddings.token_type_embeddings.embedding_dim - == self.params_dict["embedding_size"] - ) - - assert ( - self.transformer_embeddings.layer_norm.normalized_shape[0] - == self.params_dict["embedding_size"] - ) - - assert self.transformer_embeddings.dropout.p == self.params_dict["dropout"] - - def test_sanity(self): - class TextEmbeddings(TransformerModule, FromParams): - def __init__( - self, - vocab_size: int, - hidden_size: int, - pad_token_id: int, - max_position_embeddings: int, - type_vocab_size: int, - dropout: float, - ): - super().__init__() - self.word_embeddings = torch.nn.Embedding( - vocab_size, hidden_size, padding_idx=pad_token_id - ) - self.position_embeddings = torch.nn.Embedding(max_position_embeddings, hidden_size) - self.token_type_embeddings = torch.nn.Embedding(type_vocab_size, hidden_size) - - self.layer_norm = torch.nn.LayerNorm(hidden_size, eps=1e-12) - self.dropout = torch.nn.Dropout(dropout) - - def forward( - self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None - ): - if input_ids is not None: - input_shape = input_ids.size() - else: - input_shape = inputs_embeds.size()[:-1] - - seq_length = input_shape[1] - device = input_ids.device if input_ids is not None else inputs_embeds.device - if position_ids is None: - position_ids = torch.arange(seq_length, dtype=torch.long, device=device) - position_ids = position_ids.unsqueeze(0).expand(input_shape) - if token_type_ids is None: - token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) - - if inputs_embeds is None: - inputs_embeds = self.word_embeddings(input_ids) - position_embeddings = self.position_embeddings(position_ids) - token_type_embeddings = self.token_type_embeddings(token_type_ids) - - embeddings = inputs_embeds + position_embeddings + token_type_embeddings - embeddings = self.layer_norm(embeddings) - embeddings = self.dropout(embeddings) - return embeddings - - torch.manual_seed(23) - text = TextEmbeddings(10, 5, 2, 3, 7, 0.0) - torch.manual_seed(23) - transformer = TransformerEmbeddings(10, 5, 2, 3, 7, 0.0) - - input_ids = torch.tensor([[1, 2]]) - token_type_ids = torch.tensor([[1, 0]], dtype=torch.long) - position_ids = torch.tensor([[0, 1]]) - - text_output = text.forward(input_ids, token_type_ids, position_ids) - transformer_output = transformer.forward(input_ids, token_type_ids, position_ids) - - assert_allclose(text_output, transformer_output) - - def test_forward_runs_with_inputs(self): - input_ids = torch.tensor([[1, 2]]) - token_type_ids = torch.tensor([[1, 0]], dtype=torch.long) - position_ids = torch.tensor([[0, 1]]) - self.transformer_embeddings.forward( - input_ids=input_ids, token_type_ids=token_type_ids, position_ids=position_ids - ) - - def test_output_size(self): - input_ids = torch.tensor([[1, 2]]) - token_type_ids = torch.tensor([[1, 0]], dtype=torch.long) - position_ids = torch.tensor([[0, 1]]) - params = copy.deepcopy(self.params_dict) - params["output_size"] = 7 - params = Params(params) - module = TransformerEmbeddings.from_params(params) - output = module.forward( - input_ids=input_ids, token_type_ids=token_type_ids, position_ids=position_ids - ) - - assert output.shape[-1] == 7 - - def test_no_token_type_layer(self): - params = copy.deepcopy(self.params_dict) - params["type_vocab_size"] = 0 - params = Params(params) - module = TransformerEmbeddings.from_params(params) - - assert len(module.embeddings) == 2 - - @pytest.mark.parametrize( - "pretrained_name", - [ - "bert-base-uncased", - "albert-base-v2", - ], + yield "albert", AlbertEmbeddings(AlbertConfig(**albertparams)) + + +@pytest.mark.parametrize("module_name, hf_module", get_modules()) +def test_forward_against_huggingface_output(transformer_embeddings, module_name, hf_module): + input_ids = torch.tensor([[1, 2]]) + token_type_ids = torch.tensor([[1, 0]], dtype=torch.long) + position_ids = torch.tensor([[0, 1]]) + + state_dict = transformer_embeddings._get_mapped_state_dict(hf_module.state_dict()) + if "position_ids" in state_dict: + del state_dict["position_ids"] + transformer_embeddings.load_state_dict(state_dict) + + torch.manual_seed(1234) + transformer_embeddings = ( + transformer_embeddings.eval() + ) # setting to eval mode to avoid non-deterministic dropout. + output = transformer_embeddings( + input_ids=input_ids, token_type_ids=token_type_ids, position_ids=position_ids ) - def test_loading_from_pretrained_weights_using_model_name(self, pretrained_name): - pretrained_module = cached_transformers.get(pretrained_name, False).embeddings - module = TransformerEmbeddings.from_pretrained_module(pretrained_name) - mapping = { - val: key - for key, val in module._construct_default_mapping( - pretrained_module, "huggingface", {} - ).items() - } - missing = assert_equal_parameters(pretrained_module, module, mapping=mapping) - assert len(missing) == 0 - - @pytest.mark.parametrize("module_name, hf_module", get_modules(PARAMS_DICT).items()) - def test_forward_against_huggingface_output(self, module_name, hf_module): - input_ids = torch.tensor([[1, 2]]) - token_type_ids = torch.tensor([[1, 0]], dtype=torch.long) - position_ids = torch.tensor([[0, 1]]) - - torch.manual_seed(1234) - embeddings = TransformerEmbeddings.from_pretrained_module(hf_module) - - torch.manual_seed(1234) - embeddings = embeddings.eval() # setting to eval mode to avoid non-deterministic dropout. - output = embeddings.forward( - input_ids=input_ids, token_type_ids=token_type_ids, position_ids=position_ids - ) - - torch.manual_seed(1234) - hf_module = hf_module.eval() # setting to eval mode to avoid non-deterministic dropout. - hf_output = hf_module.forward( - input_ids=input_ids, token_type_ids=token_type_ids, position_ids=position_ids - ) - - assert torch.allclose(output, hf_output) - - -class TestImageFeatureEmbeddings(AllenNlpTestCase): - def setup_method(self): - super().setup_method() - - self.params_dict = {"feature_size": 3, "embedding_size": 5, "dropout": 0.1} - - params = Params(copy.deepcopy(self.params_dict)) - - self.img_embeddings = ImageFeatureEmbeddings.from_params(params) - - def test_can_construct_from_params(self): - assert ( - self.img_embeddings.embeddings.image_embeddings.in_features - == self.params_dict["feature_size"] - ) - assert ( - self.img_embeddings.embeddings.image_embeddings.out_features - == self.params_dict["embedding_size"] - ) - assert ( - self.img_embeddings.embeddings.location_embeddings.out_features - == self.params_dict["embedding_size"] - ) - assert self.img_embeddings.dropout.p == self.params_dict["dropout"] - - def test_forward_runs_with_inputs(self): - batch_size = 2 - feature_dim = self.params_dict["feature_size"] - image_feature = torch.randn(batch_size, feature_dim) - image_location = torch.randn(batch_size, 4) - self.img_embeddings.forward(image_feature, image_location) - - def test_sanity(self): - class OldImageFeatureEmbeddings(TransformerModule, FromParams): - """Construct the embeddings from image, spatial location (omit now) and - token_type embeddings. - """ - - def __init__(self, feature_size: int, embedding_size: int, dropout: float = 0.0): - super().__init__() - - self.image_embeddings = torch.nn.Linear(feature_size, embedding_size) - self.image_location_embeddings = torch.nn.Linear(4, embedding_size, bias=False) - self.layer_norm = torch.nn.LayerNorm(embedding_size, eps=1e-12) - self.dropout = torch.nn.Dropout(dropout) - - def forward(self, image_feature: torch.Tensor, image_location: torch.Tensor): - img_embeddings = self.image_embeddings(image_feature) - loc_embeddings = self.image_location_embeddings(image_location) - embeddings = self.layer_norm(img_embeddings + loc_embeddings) - embeddings = self.dropout(embeddings) - - return embeddings - - torch.manual_seed(23) - old = OldImageFeatureEmbeddings(**self.params_dict) - torch.manual_seed(23) - now = ImageFeatureEmbeddings(**self.params_dict) - - batch_size = 2 - - image_feature = torch.randn(batch_size, self.params_dict["feature_size"]) - image_location = torch.randn(batch_size, 4) - - torch.manual_seed(23) - old_output = old.forward(image_feature, image_location) - torch.manual_seed(23) - now_output = now.forward(image_feature, image_location) - - assert_allclose(old_output, now_output) + + torch.manual_seed(1234) + hf_module = hf_module.eval() # setting to eval mode to avoid non-deterministic dropout. + hf_output = hf_module( + input_ids=input_ids, token_type_ids=token_type_ids, position_ids=position_ids + ) + + assert torch.allclose(output, hf_output) + + +@pytest.fixture +def image_params_dict(): + return {"feature_size": 3, "embedding_size": 5, "dropout": 0.1} + + +@pytest.fixture +def image_params(image_params_dict): + return Params(image_params_dict) + + +@pytest.fixture +def image_embeddings(image_params): + return ImageFeatureEmbeddings.from_params(image_params.duplicate()) + + +def test_can_construct_image_embeddings_from_params(image_embeddings, image_params_dict): + assert ( + image_embeddings.embeddings.image_embeddings.in_features + == image_params_dict["feature_size"] + ) + assert ( + image_embeddings.embeddings.image_embeddings.out_features + == image_params_dict["embedding_size"] + ) + assert ( + image_embeddings.embeddings.location_embeddings.out_features + == image_params_dict["embedding_size"] + ) + assert image_embeddings.dropout.p == image_params_dict["dropout"] + + +def test_image_embedding_forward_runs_with_inputs(image_embeddings, image_params_dict): + batch_size = 2 + feature_dim = image_params_dict["feature_size"] + image_feature = torch.randn(batch_size, feature_dim) + image_location = torch.randn(batch_size, 4) + image_embeddings(image_feature, image_location) + + +def test_image_embeddings_sanity(image_params_dict): + class OldImageFeatureEmbeddings(TransformerModule, FromParams): + """Construct the embeddings from image, spatial location (omit now) and + token_type embeddings. + """ + + def __init__(self, feature_size: int, embedding_size: int, dropout: float = 0.0): + super().__init__() + + self.image_embeddings = torch.nn.Linear(feature_size, embedding_size) + self.image_location_embeddings = torch.nn.Linear(4, embedding_size, bias=False) + self.layer_norm = torch.nn.LayerNorm(embedding_size, eps=1e-12) + self.dropout = torch.nn.Dropout(dropout) + + def forward(self, image_feature: torch.Tensor, image_location: torch.Tensor): + img_embeddings = self.image_embeddings(image_feature) + loc_embeddings = self.image_location_embeddings(image_location) + embeddings = self.layer_norm(img_embeddings + loc_embeddings) + embeddings = self.dropout(embeddings) + + return embeddings + + torch.manual_seed(23) + old = OldImageFeatureEmbeddings(**image_params_dict) + torch.manual_seed(23) + now = ImageFeatureEmbeddings(**image_params_dict) + + batch_size = 2 + + image_feature = torch.randn(batch_size, image_params_dict["feature_size"]) + image_location = torch.randn(batch_size, 4) + + torch.manual_seed(23) + old_output = old(image_feature, image_location) + torch.manual_seed(23) + now_output = now(image_feature, image_location) + + assert_allclose(old_output, now_output) From d3ed13bd7bf12990b9aa787c970dc3f62eff5c3d Mon Sep 17 00:00:00 2001 From: epwalsh Date: Wed, 12 May 2021 11:26:58 -0700 Subject: [PATCH 06/23] fix toolkit tests --- .../modules/transformer/transformer_module.py | 17 ++++- tests/modules/transformer/toolkit_test.py | 68 +++++++++++++------ 2 files changed, 60 insertions(+), 25 deletions(-) diff --git a/allennlp/modules/transformer/transformer_module.py b/allennlp/modules/transformer/transformer_module.py index d30050f987e..b65bc2a617f 100644 --- a/allennlp/modules/transformer/transformer_module.py +++ b/allennlp/modules/transformer/transformer_module.py @@ -194,6 +194,7 @@ def _from_config(cls: Type[_T], config: "PretrainedConfig", **kwargs) -> _T: def from_pretrained_module( cls: Type[_T], model_name: str, + *, load_weights: bool = True, weights_path: Optional[Union[str, PathLike]] = None, auto_config_kwargs: Optional[Dict[str, Any]] = None, @@ -305,11 +306,21 @@ def _get_mapped_state_dict( ) -> StateDictType: # First fix all top-level keys according to `combined_mapping`. combined_mapping = module._get_mapping(mapping) if isinstance(module, TransformerModule) else {} - for hf_key, cls_key in combined_mapping.items(): - relevant_keys = set([key for key in state_dict.keys() if key.startswith(hf_key)]) + for hf_key, cls_key in sorted( + # Sort by most specific key first. + combined_mapping.items(), + key=lambda x: x[0].count("."), + reverse=True, + ): + relevant_keys = set( + [key for key in state_dict.keys() if (key == hf_key or key.startswith(hf_key + "."))] + ) for key in relevant_keys: new_key = key.replace(hf_key, cls_key, 1) - state_dict[new_key] = state_dict.pop(key) + # We have to be careful not to overwrite an entry that we might have updated + # on a previous iteration of this loop due to having a more specific key. + if new_key not in state_dict: + state_dict[new_key] = state_dict.pop(key) # Now loop through the submodules, calling this function on each submodule. for name, submodule in module.named_children(): diff --git a/tests/modules/transformer/toolkit_test.py b/tests/modules/transformer/toolkit_test.py index cd1bf60e9fd..ff59b9cf6b5 100644 --- a/tests/modules/transformer/toolkit_test.py +++ b/tests/modules/transformer/toolkit_test.py @@ -1,9 +1,10 @@ import torch +from torch.testing import assert_allclose from overrides import overrides +from transformers import AutoModel from transformers.models.albert.modeling_albert import AlbertEmbeddings from allennlp.common import cached_transformers -from allennlp.common.testing import assert_equal_parameters from allennlp.data.vocabulary import Vocabulary from allennlp.modules.token_embedders import Embedding, TokenEmbedder from allennlp.modules.transformer import TransformerStack, TransformerEmbeddings @@ -49,15 +50,19 @@ def forward(self, token_ids: torch.LongTensor): tiny.forward(torch.LongTensor([[0, 1, 2]])) def test_use_first_four_layers_of_pretrained(self): - pretrained = cached_transformers.get("bert-base-uncased", False) + pretrained = "bert-base-cased" class SmallTransformer(TokenEmbedder): def __init__(self): super().__init__() - self.embeddings = TransformerEmbeddings.from_pretrained_module(pretrained) - + self.embeddings = TransformerEmbeddings.from_pretrained_module( + pretrained, relevant_module="bert.embeddings" + ) self.transformer = TransformerStack.from_pretrained_module( - pretrained, num_hidden_layers=4 + pretrained, + num_hidden_layers=4, + relevant_module="bert.encoder", + strict=False, ) @overrides @@ -68,19 +73,27 @@ def forward(self, token_ids: torch.LongTensor): small = SmallTransformer() assert len(small.transformer.layers) == 4 - small.forward(torch.LongTensor([[0, 1, 2]])) + small(torch.LongTensor([[0, 1, 2]])) def test_use_selected_layers_of_bert_for_different_purposes(self): class MediumTransformer(torch.nn.Module): def __init__(self): super().__init__() - self.embeddings = TransformerEmbeddings.from_pretrained_module("bert-base-uncased") + self.embeddings = TransformerEmbeddings.from_pretrained_module( + "bert-base-cased", relevant_module="bert.embeddings" + ) self.separate_transformer = TransformerStack.from_pretrained_module( - "bert-base-uncased", num_hidden_layers=range(0, 8) + "bert-base-cased", + relevant_module="bert.encoder", + num_hidden_layers=8, + strict=False, ) self.combined_transformer = TransformerStack.from_pretrained_module( - "bert-base-uncased", - num_hidden_layers=range(8, 12), + "bert-base-cased", + relevant_module="bert.encoder", + num_hidden_layers=4, + mapping={f"layer.{l}": f"layers.{i}" for (i, l) in enumerate(range(8, 12))}, + strict=False, ) @overrides @@ -106,22 +119,31 @@ def forward( assert (len(medium.separate_transformer.layers)) == 8 assert (len(medium.combined_transformer.layers)) == 4 - pretrained = cached_transformers.get("bert-base-uncased", False) + pretrained = cached_transformers.get("bert-base-cased", False) pretrained_layers = dict(pretrained.encoder.layer.named_modules()) - medium_layers = dict(medium.combined_transformer.layers.named_modules()) + separate_layers = dict(medium.separate_transformer.layers.named_modules()) + assert_allclose( + separate_layers["0"].intermediate.dense.weight.data, + pretrained_layers["0"].intermediate.dense.weight.data, + ) - assert_equal_parameters( - medium_layers["0"], pretrained_layers["8"], TransformerStack._huggingface_mapping + combined_layers = dict(medium.combined_transformer.layers.named_modules()) + assert_allclose( + combined_layers["0"].intermediate.dense.weight.data, + pretrained_layers["8"].intermediate.dense.weight.data, ) - assert_equal_parameters( - medium_layers["1"], pretrained_layers["9"], TransformerStack._huggingface_mapping + assert_allclose( + combined_layers["1"].intermediate.dense.weight.data, + pretrained_layers["9"].intermediate.dense.weight.data, ) - assert_equal_parameters( - medium_layers["2"], pretrained_layers["10"], TransformerStack._huggingface_mapping + assert_allclose( + combined_layers["2"].intermediate.dense.weight.data, + pretrained_layers["10"].intermediate.dense.weight.data, ) - assert_equal_parameters( - medium_layers["3"], pretrained_layers["11"], TransformerStack._huggingface_mapping + assert_allclose( + combined_layers["3"].intermediate.dense.weight.data, + pretrained_layers["11"].intermediate.dense.weight.data, ) def test_combination_of_two_different_berts(self): @@ -130,8 +152,10 @@ def test_combination_of_two_different_berts(self): class AlmostRegularTransformer(TokenEmbedder): def __init__(self): super().__init__() - self.embeddings = TransformerEmbeddings.get_relevant_module("albert-base-v2") - self.transformer = TransformerStack.from_pretrained_module("bert-base-uncased") + self.embeddings = AutoModel.from_pretrained("albert-base-v2").embeddings + self.transformer = TransformerStack.from_pretrained_module( + "bert-base-cased", relevant_module="bert.encoder" + ) # We want to tune only the embeddings, because that's our experiment. self.transformer.requires_grad = False From daf4aa719757bdab195508248073dc25f30f64d7 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Wed, 12 May 2021 13:09:59 -0700 Subject: [PATCH 07/23] fix self attention --- .../modules/transformer/self_attention.py | 2 +- .../transformer/self_attention_test.py | 193 ++++++------------ 2 files changed, 60 insertions(+), 135 deletions(-) diff --git a/allennlp/modules/transformer/self_attention.py b/allennlp/modules/transformer/self_attention.py index 8be8c60c6fa..71c253f9573 100644 --- a/allennlp/modules/transformer/self_attention.py +++ b/allennlp/modules/transformer/self_attention.py @@ -156,4 +156,4 @@ def _from_config(cls, config: "PretrainedConfig", **kwargs): else: final_kwargs["dropout"] = config.attention_probs_dropout_prob final_kwargs.update(**kwargs) - return final_kwargs + return cls(**final_kwargs) diff --git a/tests/modules/transformer/self_attention_test.py b/tests/modules/transformer/self_attention_test.py index e29ae44cf9e..2bceb73d2b8 100644 --- a/tests/modules/transformer/self_attention_test.py +++ b/tests/modules/transformer/self_attention_test.py @@ -1,21 +1,13 @@ import copy + import torch import pytest +from transformers import AutoModel from allennlp.common import Params -from allennlp.common import cached_transformers -from allennlp.common.testing import assert_equal_parameters, AllenNlpTestCase from allennlp.modules.transformer import SelfAttention from allennlp.nn.util import min_value_of_dtype -from transformers.models.bert.configuration_bert import BertConfig -from transformers.models.bert.modeling_bert import BertSelfAttention -from transformers.models.roberta.configuration_roberta import RobertaConfig -from transformers.models.roberta.modeling_roberta import RobertaSelfAttention -from transformers.models.electra.configuration_electra import ElectraConfig -from transformers.models.electra.modeling_electra import ElectraSelfAttention -from transformers.models.distilbert.configuration_distilbert import DistilBertConfig -from transformers.models.distilbert.modeling_distilbert import MultiHeadSelfAttention PARAMS_DICT = { "hidden_size": 6, @@ -24,145 +16,78 @@ } -def get_modules(params_dict): - modules = {} - params = copy.deepcopy(params_dict) - params["attention_probs_dropout_prob"] = params.pop("dropout") +@pytest.fixture +def params_dict(): + return copy.deepcopy(PARAMS_DICT) - # bert, roberta, electra self attentions have the same code. - torch.manual_seed(1234) - hf_module = BertSelfAttention(BertConfig(**params)) - modules["bert"] = hf_module +@pytest.fixture +def params(params_dict): + return Params(params_dict) - torch.manual_seed(1234) - hf_module = RobertaSelfAttention(RobertaConfig(**params)) - modules["roberta"] = hf_module - torch.manual_seed(1234) - hf_module = ElectraSelfAttention(ElectraConfig(**params)) - modules["electra"] = hf_module +@pytest.fixture +def self_attention(params): + return SelfAttention.from_params(params.duplicate()) - torch.manual_seed(1234) - distilparams = copy.deepcopy(params_dict) - distilparams["n_heads"] = distilparams.pop("num_attention_heads") - distilparams["dim"] = distilparams.pop("hidden_size") - distilparams["attention_dropout"] = distilparams.pop("dropout") - hf_module = MultiHeadSelfAttention(DistilBertConfig(**distilparams)) - modules["distilbert"] = hf_module - return modules - - -class TestSelfAttention(AllenNlpTestCase): - def setup_method(self): - super().setup_method() - - self.params_dict = {key: val for key, val in PARAMS_DICT.items()} +def test_can_construct_from_params(self_attention, params_dict): + assert self_attention.num_attention_heads == params_dict["num_attention_heads"] + assert self_attention.attention_head_size == int( + params_dict["hidden_size"] / params_dict["num_attention_heads"] + ) - params = Params(copy.deepcopy(self.params_dict)) + assert ( + self_attention.all_head_size + == params_dict["num_attention_heads"] * self_attention.attention_head_size + ) - self.self_attention = SelfAttention.from_params(params) + assert self_attention.query.in_features == params_dict["hidden_size"] + assert self_attention.key.in_features == params_dict["hidden_size"] + assert self_attention.value.in_features == params_dict["hidden_size"] - def test_can_construct_from_params(self): - assert self.self_attention.num_attention_heads == self.params_dict["num_attention_heads"] - assert self.self_attention.attention_head_size == int( - self.params_dict["hidden_size"] / self.params_dict["num_attention_heads"] - ) + assert self_attention.dropout.p == params_dict["dropout"] - assert ( - self.self_attention.all_head_size - == self.params_dict["num_attention_heads"] * self.self_attention.attention_head_size - ) - assert self.self_attention.query.in_features == self.params_dict["hidden_size"] - assert self.self_attention.key.in_features == self.params_dict["hidden_size"] - assert self.self_attention.value.in_features == self.params_dict["hidden_size"] +@pytest.mark.parametrize( + "pretrained_name, relevant_module", + [ + ("bert-base-cased", "bert.encoder.layer.0.attention.self"), + ("google/electra-base-generator", "electra.encoder.layer.0.attention.self"), + ("distilbert-base-uncased", "distilbert.transformer.layer.0.attention"), + ], +) +def test_loading_from_pretrained_weights_using_model_name(pretrained_name, relevant_module): + torch.manual_seed(1234) + module = SelfAttention.from_pretrained_module(pretrained_name, relevant_module=relevant_module) - assert self.self_attention.dropout.p == self.params_dict["dropout"] + torch.manual_seed(1234) + pretrained_module = dict(AutoModel.from_pretrained(pretrained_name).named_modules())[ + # Module name will exclude the top-level part (e.g. 'bert.', 'electra.') for some reason. + relevant_module[relevant_module.index(".") + 1 :] + ] - @pytest.mark.skip("Takes up too much memory") - @pytest.mark.parametrize("module_name, hf_module", get_modules(PARAMS_DICT).items()) - def test_forward_against_huggingface_output(self, module_name, hf_module): - hidden_states = torch.randn(2, 3, 6) - attention_mask = torch.tensor([[0, 1, 0], [1, 1, 0]]) + batch_size = 2 + seq_len = 3 + dim = module.query.in_features + hidden_states = torch.randn(batch_size, seq_len, dim) + attention_mask = torch.randint(0, 2, (batch_size, 1, 1, seq_len)) - torch.manual_seed(1234) - self_attention = SelfAttention.from_pretrained_module(hf_module) - - output = self_attention.forward(hidden_states, attention_mask=attention_mask) - if module_name == "distilbert": - hf_output = hf_module.forward( - hidden_states, hidden_states, hidden_states, mask=attention_mask - ) - else: - # We do this because bert, roberta, electra process the attention_mask at the model level. - attention_mask_hf = (attention_mask == 0).view((2, 1, 1, 3)).expand(2, 2, 3, 3) * -10e5 - hf_output = hf_module.forward(hidden_states, attention_mask=attention_mask_hf) - - assert torch.allclose(output[0], hf_output[0]) - - @pytest.mark.skip("Takes up too much memory") - @pytest.mark.parametrize( - "pretrained_name", - [ - "bert-base-uncased", - "roberta-base", - "google/electra-base-generator", - "distilbert-base-uncased", - ], - ) - def test_loading_from_pretrained_weights_using_model_name(self, pretrained_name): + # setting to eval mode to avoid non-deterministic dropout. + module = module.eval() + pretrained_module = pretrained_module.eval() + torch.manual_seed(1234) + output = module(hidden_states, attention_mask=attention_mask.squeeze())[0] + if "distilbert" in pretrained_name: torch.manual_seed(1234) - pretrained = cached_transformers.get(pretrained_name, False) - - if "distilbert" in pretrained_name: - encoder = pretrained.transformer - else: - encoder = pretrained.encoder - # Hacky way to get a bert layer. - for i, pretrained_module in enumerate(encoder.layer.modules()): - if i == 1: - break - - # Get the self attention layer. - if "distilbert" in pretrained_name: - pretrained_module = pretrained_module.attention - else: - pretrained_module = pretrained_module.attention.self - + hf_output = pretrained_module( + hidden_states, hidden_states, hidden_states, mask=attention_mask + )[0] + else: + # The attn_mask is processed outside the self attention module in HF bert models. + attention_mask = (~(attention_mask == 1)) * min_value_of_dtype(hidden_states.dtype) torch.manual_seed(1234) - module = SelfAttention.from_pretrained_module(pretrained_name) - mapping = { - val: key - for key, val in module._construct_default_mapping( - pretrained_module, "huggingface", {} - ).items() - } - assert_equal_parameters(pretrained_module, module, mapping=mapping) - - batch_size = 2 - seq_len = 3 - dim = module.query.in_features - hidden_states = torch.randn(batch_size, seq_len, dim) - attention_mask = torch.randint(0, 2, (batch_size, 1, 1, seq_len)) - - # setting to eval mode to avoid non-deterministic dropout. - module = module.eval() - pretrained_module = pretrained_module.eval() + hf_output = pretrained_module(hidden_states, attention_mask=attention_mask)[0] - torch.manual_seed(1234) - output = module.forward(hidden_states, attention_mask=attention_mask.squeeze())[0] - if "distilbert" in pretrained_name: - torch.manual_seed(1234) - hf_output = pretrained_module.forward( - hidden_states, hidden_states, hidden_states, mask=attention_mask - )[0] - else: - # The attn_mask is processed outside the self attention module in HF bert models. - attention_mask = (~(attention_mask == 1)) * min_value_of_dtype(hidden_states.dtype) - torch.manual_seed(1234) - hf_output = pretrained_module.forward(hidden_states, attention_mask=attention_mask)[0] - - assert torch.allclose(output, hf_output) + assert torch.allclose(output, hf_output) From 8fb523370535e25beb07e369047bd4e57a94a0b5 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Wed, 12 May 2021 13:23:54 -0700 Subject: [PATCH 08/23] fix bimodal encoder tests --- .../transformer/bimodal_encoder_test.py | 183 +++++++++--------- 1 file changed, 91 insertions(+), 92 deletions(-) diff --git a/tests/modules/transformer/bimodal_encoder_test.py b/tests/modules/transformer/bimodal_encoder_test.py index b95af3bfa1f..3ac682cccbf 100644 --- a/tests/modules/transformer/bimodal_encoder_test.py +++ b/tests/modules/transformer/bimodal_encoder_test.py @@ -1,95 +1,94 @@ -import copy import torch +from torch.testing import assert_allclose +from transformers import AutoModel +import pytest + from allennlp.common import Params -from allennlp.common import cached_transformers -from allennlp.common.testing import assert_equal_parameters from allennlp.modules.transformer import BiModalEncoder -from allennlp.common.testing import AllenNlpTestCase - - -class TestBiModalEncoder(AllenNlpTestCase): - def setup_method(self): - super().setup_method() - - self.params_dict = { - "num_hidden_layers1": 3, - "num_hidden_layers2": 3, - "hidden_size1": 12, - "hidden_size2": 12, - "combined_hidden_size": 12, - "intermediate_size1": 3, - "intermediate_size2": 3, - "num_attention_heads1": 4, - "num_attention_heads2": 6, - "combined_num_attention_heads": 2, - "attention_dropout1": 0.1, - "hidden_dropout1": 0.2, - "attention_dropout2": 0.1, - "hidden_dropout2": 0.2, - "activation": "relu", - "biattention_id1": [1, 2], - "biattention_id2": [1, 2], - "fixed_layer1": 1, - "fixed_layer2": 1, - } - - params = Params(copy.deepcopy(self.params_dict)) - - self.bimodal_encoder = BiModalEncoder.from_params(params) - - self.pretrained = cached_transformers.get("bert-base-uncased", False) - - def test_can_construct_from_params(self): - - modules = dict(self.bimodal_encoder.named_modules()) - assert len(modules["layers1"]) == self.params_dict["num_hidden_layers1"] - assert len(modules["layers2"]) == self.params_dict["num_hidden_layers2"] - - def test_forward_runs(self): - - embedding1 = torch.randn(16, 34, self.params_dict["hidden_size1"]) - embedding2 = torch.randn(16, 2, self.params_dict["hidden_size2"]) - attn_mask1 = torch.randint(0, 2, (16, 1, 1, 34)) == 1 - attn_mask2 = torch.randint(0, 2, (16, 1, 1, 2)) == 1 - - self.bimodal_encoder.forward(embedding1, embedding2, attn_mask1, attn_mask2) - - def test_loading_from_pretrained_weights(self): - pretrained_module = self.pretrained.encoder - required_kwargs = [ - "num_hidden_layers2", - "hidden_size2", - "combined_hidden_size", - "intermediate_size2", - "num_attention_heads2", - "combined_num_attention_heads", - "attention_dropout2", - "hidden_dropout2", - "biattention_id1", - "biattention_id2", - "fixed_layer1", - "fixed_layer2", - ] - kwargs = {key: self.params_dict[key] for key in required_kwargs} - module = BiModalEncoder.from_pretrained_module(pretrained_module, **kwargs) - mapping = { - val: key - for key, val in module._construct_default_mapping( - pretrained_module, "huggingface", {} - ).items() - } - assert_equal_parameters( - pretrained_module, - module, - ignore_missing=True, - mapping=mapping, - ) - - def test_default_parameters(self): - encoder = BiModalEncoder() - embedding1 = torch.randn(16, 34, 1024) - embedding2 = torch.randn(16, 2, 1024) - attn_mask1 = torch.randint(0, 2, (16, 1, 1, 34)) == 1 - attn_mask2 = torch.randint(0, 2, (16, 1, 1, 2)) == 1 - - encoder.forward(embedding1, embedding2, attn_mask1, attn_mask2) + + +@pytest.fixture +def params_dict(): + return { + "num_hidden_layers1": 3, + "num_hidden_layers2": 3, + "hidden_size1": 12, + "hidden_size2": 12, + "combined_hidden_size": 12, + "intermediate_size1": 3, + "intermediate_size2": 3, + "num_attention_heads1": 4, + "num_attention_heads2": 6, + "combined_num_attention_heads": 2, + "attention_dropout1": 0.1, + "hidden_dropout1": 0.2, + "attention_dropout2": 0.1, + "hidden_dropout2": 0.2, + "activation": "relu", + "biattention_id1": [1, 2], + "biattention_id2": [1, 2], + "fixed_layer1": 1, + "fixed_layer2": 1, + } + + +@pytest.fixture +def params(params_dict): + return Params(params_dict) + + +@pytest.fixture +def bimodal_encoder(params): + return BiModalEncoder.from_params(params.duplicate()) + + +def test_can_construct_from_params(bimodal_encoder, params_dict): + modules = dict(bimodal_encoder.named_modules()) + assert len(modules["layers1"]) == params_dict["num_hidden_layers1"] + assert len(modules["layers2"]) == params_dict["num_hidden_layers2"] + + +def test_forward_runs(bimodal_encoder, params_dict): + embedding1 = torch.randn(16, 34, params_dict["hidden_size1"]) + embedding2 = torch.randn(16, 2, params_dict["hidden_size2"]) + attn_mask1 = torch.randint(0, 2, (16, 1, 1, 34)) == 1 + attn_mask2 = torch.randint(0, 2, (16, 1, 1, 2)) == 1 + bimodal_encoder(embedding1, embedding2, attn_mask1, attn_mask2) + + +def test_loading_from_pretrained_weights(params_dict): + pretrained_module = AutoModel.from_pretrained("bert-base-cased").encoder + + required_kwargs = [ + "num_hidden_layers2", + "hidden_size2", + "combined_hidden_size", + "intermediate_size2", + "num_attention_heads2", + "combined_num_attention_heads", + "attention_dropout2", + "hidden_dropout2", + "biattention_id1", + "biattention_id2", + "fixed_layer1", + "fixed_layer2", + ] + kwargs = {key: params_dict[key] for key in required_kwargs} + + module = BiModalEncoder.from_pretrained_module( + "bert-base-cased", relevant_module="bert.encoder", strict=False, **kwargs + ) + assert_allclose( + module.layers1[0].intermediate.dense.weight.data, + pretrained_module.layer[0].intermediate.dense.weight.data, + ) + + +def test_default_parameters(): + encoder = BiModalEncoder() + embedding1 = torch.randn(16, 34, 1024) + embedding2 = torch.randn(16, 2, 1024) + attn_mask1 = torch.randint(0, 2, (16, 1, 1, 34)) == 1 + attn_mask2 = torch.randint(0, 2, (16, 1, 1, 2)) == 1 + + encoder(embedding1, embedding2, attn_mask1, attn_mask2) From 0185e18bf5db9ecb1ff581d02c59b23fcbf4b7ba Mon Sep 17 00:00:00 2001 From: epwalsh Date: Wed, 12 May 2021 13:32:39 -0700 Subject: [PATCH 09/23] fix more tests --- .../transformer/activation_layer_test.py | 38 ++++--- .../transformer/bimodal_attention_test.py | 103 +++++++++--------- 2 files changed, 72 insertions(+), 69 deletions(-) diff --git a/tests/modules/transformer/activation_layer_test.py b/tests/modules/transformer/activation_layer_test.py index 8c1b7ebef26..2af0338a92e 100644 --- a/tests/modules/transformer/activation_layer_test.py +++ b/tests/modules/transformer/activation_layer_test.py @@ -1,32 +1,34 @@ -import copy import torch +import pytest from allennlp.common import Params from allennlp.modules.transformer import ActivationLayer -from allennlp.common.testing import AllenNlpTestCase -class TestActivationLayer(AllenNlpTestCase): - def setup_method(self): - super().setup_method() +@pytest.fixture +def params_dict(): + return { + "hidden_size": 5, + "intermediate_size": 3, + "activation": "relu", + } - self.params_dict = { - "hidden_size": 5, - "intermediate_size": 3, - "activation": "relu", - } - params = Params(copy.deepcopy(self.params_dict)) +@pytest.fixture +def params(params_dict): + return Params(params_dict) - self.activation_layer = ActivationLayer.from_params(params) - def test_can_construct_from_params(self): +@pytest.fixture +def activation_layer(params): + return ActivationLayer.from_params(params.duplicate()) - activation_layer = self.activation_layer - assert activation_layer.dense.in_features == self.params_dict["hidden_size"] - assert activation_layer.dense.out_features == self.params_dict["intermediate_size"] +def test_can_construct_from_params(activation_layer, params_dict): + activation_layer = activation_layer + assert activation_layer.dense.in_features == params_dict["hidden_size"] + assert activation_layer.dense.out_features == params_dict["intermediate_size"] - def test_forward_runs(self): - self.activation_layer.forward(torch.randn(7, 5)) +def test_forward_runs(activation_layer): + activation_layer.forward(torch.randn(7, 5)) diff --git a/tests/modules/transformer/bimodal_attention_test.py b/tests/modules/transformer/bimodal_attention_test.py index 40dc81f12de..270aefd23e7 100644 --- a/tests/modules/transformer/bimodal_attention_test.py +++ b/tests/modules/transformer/bimodal_attention_test.py @@ -1,55 +1,56 @@ -import copy import torch +import pytest from allennlp.common import Params from allennlp.modules.transformer import BiModalAttention -from allennlp.common.testing import AllenNlpTestCase - - -class TestBiModalAttention(AllenNlpTestCase): - def setup_method(self): - super().setup_method() - - self.params_dict = { - "hidden_size1": 6, - "hidden_size2": 4, - "combined_hidden_size": 16, - "num_attention_heads": 2, - "dropout1": 0.1, - "dropout2": 0.2, - } - - params = Params(copy.deepcopy(self.params_dict)) - - self.biattention = BiModalAttention.from_params(params) - - def test_can_construct_from_params(self): - - biattention = self.biattention - - assert biattention.num_attention_heads == self.params_dict["num_attention_heads"] - assert biattention.attention_head_size == int( - self.params_dict["combined_hidden_size"] / self.params_dict["num_attention_heads"] - ) - assert ( - biattention.all_head_size - == self.params_dict["num_attention_heads"] * biattention.attention_head_size - ) - assert biattention.query1.in_features == self.params_dict["hidden_size1"] - assert biattention.key1.in_features == self.params_dict["hidden_size1"] - assert biattention.value1.in_features == self.params_dict["hidden_size1"] - assert biattention.dropout1.p == self.params_dict["dropout1"] - - assert biattention.query2.in_features == self.params_dict["hidden_size2"] - assert biattention.key2.in_features == self.params_dict["hidden_size2"] - assert biattention.value2.in_features == self.params_dict["hidden_size2"] - assert biattention.dropout2.p == self.params_dict["dropout2"] - - def test_forward_runs(self): - - self.biattention.forward( - torch.randn(2, 3, 6), - torch.randn(2, 3, 4), - torch.randint(0, 2, (2, 2, 3, 3)) == 1, # creating boolean tensors - torch.randint(0, 2, (2, 2, 3, 3)) == 1, - ) + + +@pytest.fixture +def params_dict(): + return { + "hidden_size1": 6, + "hidden_size2": 4, + "combined_hidden_size": 16, + "num_attention_heads": 2, + "dropout1": 0.1, + "dropout2": 0.2, + } + + +@pytest.fixture +def params(params_dict): + return Params(params_dict) + + +@pytest.fixture +def biattention(params): + return BiModalAttention.from_params(params.duplicate()) + + +def test_can_construct_from_params(biattention, params_dict): + assert biattention.num_attention_heads == params_dict["num_attention_heads"] + assert biattention.attention_head_size == int( + params_dict["combined_hidden_size"] / params_dict["num_attention_heads"] + ) + assert ( + biattention.all_head_size + == params_dict["num_attention_heads"] * biattention.attention_head_size + ) + assert biattention.query1.in_features == params_dict["hidden_size1"] + assert biattention.key1.in_features == params_dict["hidden_size1"] + assert biattention.value1.in_features == params_dict["hidden_size1"] + assert biattention.dropout1.p == params_dict["dropout1"] + + assert biattention.query2.in_features == params_dict["hidden_size2"] + assert biattention.key2.in_features == params_dict["hidden_size2"] + assert biattention.value2.in_features == params_dict["hidden_size2"] + assert biattention.dropout2.p == params_dict["dropout2"] + + +def test_forward_runs(biattention): + biattention( + torch.randn(2, 3, 6), + torch.randn(2, 3, 4), + torch.randint(0, 2, (2, 2, 3, 3)) == 1, # creating boolean tensors + torch.randint(0, 2, (2, 2, 3, 3)) == 1, + ) From f0866f9afc46d9c201e0e237ad0e8838f7a38c16 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Wed, 12 May 2021 14:58:29 -0700 Subject: [PATCH 10/23] fix T5! --- allennlp/modules/transformer/t5.py | 14 +++- .../modules/transformer/transformer_module.py | 71 +++++++++++++++++-- allennlp/nn/util.py | 5 +- 3 files changed, 83 insertions(+), 7 deletions(-) diff --git a/allennlp/modules/transformer/t5.py b/allennlp/modules/transformer/t5.py index 1772fb5b217..ba2d6714318 100644 --- a/allennlp/modules/transformer/t5.py +++ b/allennlp/modules/transformer/t5.py @@ -386,9 +386,12 @@ def __init__( self_attention: Optional[T5Attention] = None, layer_norm: Optional[T5LayerNorm] = None, dropout: float = 0.1, + has_relative_attention_bias: bool = False, ): super().__init__() - self.self_attention = self_attention or T5Attention() + self.self_attention = self_attention or T5Attention( + has_relative_attention_bias=has_relative_attention_bias + ) self.layer_norm = layer_norm or T5LayerNorm(hidden_size=self.self_attention.hidden_size) self.dropout = nn.Dropout(dropout) @@ -963,6 +966,15 @@ class T5Output: class T5(TransformerModule, Registrable): _huggingface_mapping = {"shared": "token_embeddings"} + _tied_weights = { + "token_embeddings.weight": [ + "encoder.token_embeddings.weight", + "decoder.token_embeddings.weight", + "lm_head.weight", + ] + } + # Don't know why HF has this param in their state_dict. It's not used in their model. + _huggingface_ignore = ["decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight"] default_implementation = "default" diff --git a/allennlp/modules/transformer/transformer_module.py b/allennlp/modules/transformer/transformer_module.py index b65bc2a617f..65f6a3b92d0 100644 --- a/allennlp/modules/transformer/transformer_module.py +++ b/allennlp/modules/transformer/transformer_module.py @@ -78,6 +78,17 @@ class TransformerModule(torch.nn.Module): The default strategy for loading a state dictionary within a distributed process group. """ + _tied_weights: Optional[Dict[str, List[str]]] = None + """ + A mapping that defines any weights that need to be tied. Keys and values are parameter names. + The values will be tied to the corresponding key. + """ + + _huggingface_ignore: Optional[List[str]] = None + """ + An optional list of weights to ignore from a pretrained state_dict. + """ + @classmethod def _get_mapping( cls, @@ -190,6 +201,20 @@ def _from_config(cls: Type[_T], config: "PretrainedConfig", **kwargs) -> _T: """ raise NotImplementedError + def tie_weights(self) -> None: + """ + Tie weights according to the `_tied_weights` class attribute. + + This should always be called after loading a state dictionary. It will be called + automatically within `from_pretrained_module()`. + """ + if self._tied_weights: + param_dict = dict(self.named_parameters()) + param_dict.update(dict(self.named_buffers())) + for anchor_name, free_names in self._tied_weights.items(): + for free_name in free_names: + param_dict[free_name] = param_dict[anchor_name] + @classmethod def from_pretrained_module( cls: Type[_T], @@ -274,18 +299,21 @@ def from_pretrained_module( weights_path=weights_path, relevant_module=relevant_module, ) + # Remove weights we want to ignore. + for key in model._huggingface_ignore or []: + if key in pretrained_state_dict: + del pretrained_state_dict[key] # Now map keys from the HuggingFace state_dict to the corresponding keys from # this class. This is called recursively on each submodule of the current module. state_dict = model._get_mapped_state_dict(pretrained_state_dict, mapping=mapping) + error_msgs: List[str] = [] + missing_keys: List[str] = [] + unexpected_keys: List[str] = [] if not is_distributed() or loading_strategy == DistributedLoadingStrategy.FREE_FOR_ALL: assert state_dict is not None logger.info("Loading state_dict into module") - missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=strict) - if missing_keys: - logger.warning("Missing keys from pretrained state dict: %s", missing_keys) - if unexpected_keys: - logger.warning("Unexpected keys in pretrained state dict: %s", unexpected_keys) + missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) else: # We're in distributed training. `state_dict` is `None` for all process groups # except the global primary. @@ -296,6 +324,39 @@ def from_pretrained_module( logger.info("Loading state_dict into module (MEMORY_EFFICIENT strategy)") load_state_dict_distributed(model, state_dict, strict=strict) + # Allow missing keys in state_dict for params that are going to be tied. + for param_names in model._tied_weights.values(): + for param_name in param_names: + if param_name in missing_keys: + missing_keys.remove(param_name) + + if missing_keys: + error_msgs.append( + "Missing key(s) in state_dict: {}".format( + ", ".join(f'"{k}"' for k in missing_keys) + ) + ) + if unexpected_keys: + error_msgs.append( + "Unexpected key(s) in state_dict: {}".format( + ", ".join(f'"{k}"' for k in unexpected_keys) + ) + ) + + if error_msgs and strict: + raise RuntimeError( + "Error(s) in loading state_dict for {}:\n\t{}".format( + cls.__name__, "\n\t".join(error_msgs) + ) + ) + + # If there were error messages but we're not loading in 'strict' mode, + # we just issue warnings from the logger. + for msg in error_msgs: + logger.warning(msg) + + model.tie_weights() + return model diff --git a/allennlp/nn/util.py b/allennlp/nn/util.py index 0d8210323c9..a6a88480342 100644 --- a/allennlp/nn/util.py +++ b/allennlp/nn/util.py @@ -2251,7 +2251,10 @@ def load_state_dict_distributed( direct_member_state_dict, strict=False ) if strict and unexpected_keys: - raise ValueError(f"Unexpected keys in state dict: {unexpected_keys}") + raise RuntimeError( + f"Error(s) in loading state_dict for {module.__class__.__name__}:" + f"\tUnexpected key(s) in state dict: {unexpected_keys}" + ) # Okay, now for the recursive part. for name, submodule in submodules.items(): From ca2874300a9da926040bccb27fc85c5d3a0405ed Mon Sep 17 00:00:00 2001 From: epwalsh Date: Wed, 12 May 2021 16:00:13 -0700 Subject: [PATCH 11/23] fixes --- allennlp/modules/transformer/t5.py | 4 +++- .../modules/transformer/transformer_embeddings.py | 14 ++++++++++++++ allennlp/modules/transformer/transformer_module.py | 10 +++------- .../transformer/transformer_embeddings_test.py | 9 --------- 4 files changed, 20 insertions(+), 17 deletions(-) diff --git a/allennlp/modules/transformer/t5.py b/allennlp/modules/transformer/t5.py index ba2d6714318..b92bb97d490 100644 --- a/allennlp/modules/transformer/t5.py +++ b/allennlp/modules/transformer/t5.py @@ -974,7 +974,9 @@ class T5(TransformerModule, Registrable): ] } # Don't know why HF has this param in their state_dict. It's not used in their model. - _huggingface_ignore = ["decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight"] + _huggingface_ignore = [ + r"^decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight$" + ] default_implementation = "default" diff --git a/allennlp/modules/transformer/transformer_embeddings.py b/allennlp/modules/transformer/transformer_embeddings.py index 25004964e46..12f77c93af3 100644 --- a/allennlp/modules/transformer/transformer_embeddings.py +++ b/allennlp/modules/transformer/transformer_embeddings.py @@ -110,7 +110,21 @@ class TransformerEmbeddings(Embeddings): "word_embeddings": "embeddings.word_embeddings", "position_embeddings": "embeddings.position_embeddings", "token_type_embeddings": "embeddings.token_type_embeddings", + # Albert is a special case. A linear projection is applied to the embeddings, + # but that linear transformation lives in the encoder. + "albert.embeddings.LayerNorm": "layer_norm", + "albert.embeddings.LayerNorm": "layer_norm", + "albert.embeddings.word_embeddings": "embeddings.word_embeddings", + "albert.embeddings.position_embeddings": "embeddings.position_embeddings", + "albert.embeddings.token_type_embeddings": "embeddings.token_type_embeddings", + "albert.encoder.embedding_hidden_mapping_in": "linear_transform", } + _huggingface_ignore = [ + # Albert + f"^albert\.pooler\..*", + f"^albert\.encoder\.albert_layer_groups\..*", + f"^predictions\.*", + ] def __init__( self, diff --git a/allennlp/modules/transformer/transformer_module.py b/allennlp/modules/transformer/transformer_module.py index 65f6a3b92d0..1d6cdd8d3e2 100644 --- a/allennlp/modules/transformer/transformer_module.py +++ b/allennlp/modules/transformer/transformer_module.py @@ -86,7 +86,7 @@ class TransformerModule(torch.nn.Module): _huggingface_ignore: Optional[List[str]] = None """ - An optional list of weights to ignore from a pretrained state_dict. + An optional list of regular expressions that define which weights to ignore from a pretrained state_dict. """ @classmethod @@ -186,7 +186,7 @@ def _get_pretrained_state_dict( # Now load the state dict. logger.info("Reading state dict from %s", weights_path) - state_dict = read_state_dict(weights_path) + state_dict = read_state_dict(weights_path, ignore=cls._huggingface_ignore, strict=False) # Keep just the relevant_module, remove everything else. state_dict = cls._get_relevant_submodule_state(state_dict, relevant_module=relevant_module) @@ -299,10 +299,6 @@ def from_pretrained_module( weights_path=weights_path, relevant_module=relevant_module, ) - # Remove weights we want to ignore. - for key in model._huggingface_ignore or []: - if key in pretrained_state_dict: - del pretrained_state_dict[key] # Now map keys from the HuggingFace state_dict to the corresponding keys from # this class. This is called recursively on each submodule of the current module. state_dict = model._get_mapped_state_dict(pretrained_state_dict, mapping=mapping) @@ -325,7 +321,7 @@ def from_pretrained_module( load_state_dict_distributed(model, state_dict, strict=strict) # Allow missing keys in state_dict for params that are going to be tied. - for param_names in model._tied_weights.values(): + for param_names in (model._tied_weights or {}).values(): for param_name in param_names: if param_name in missing_keys: missing_keys.remove(param_name) diff --git a/tests/modules/transformer/transformer_embeddings_test.py b/tests/modules/transformer/transformer_embeddings_test.py index 3954db17a0d..73eb84fe908 100644 --- a/tests/modules/transformer/transformer_embeddings_test.py +++ b/tests/modules/transformer/transformer_embeddings_test.py @@ -166,15 +166,6 @@ def test_loading_albert(): """ transformer_embedding = TransformerEmbeddings.from_pretrained_module( "albert-base-v2", - mapping={ - "albert.embeddings.LayerNorm": "layer_norm", - "albert.embeddings.LayerNorm": "layer_norm", - "albert.embeddings.word_embeddings": "embeddings.word_embeddings", - "albert.embeddings.position_embeddings": "embeddings.position_embeddings", - "albert.embeddings.token_type_embeddings": "embeddings.token_type_embeddings", - "albert.encoder.embedding_hidden_mapping_in": "linear_transform", - }, - strict=False, ) albert = AutoModel.from_pretrained("albert-base-v2") assert_allclose( From 15e78a5c28dfc407b2060984cedbade91149cd8f Mon Sep 17 00:00:00 2001 From: epwalsh Date: Wed, 12 May 2021 16:03:21 -0700 Subject: [PATCH 12/23] fix backbone --- .../modules/backbones/vilbert_backbone.py | 52 ++++--------------- 1 file changed, 10 insertions(+), 42 deletions(-) diff --git a/allennlp/modules/backbones/vilbert_backbone.py b/allennlp/modules/backbones/vilbert_backbone.py index c1b9d1090b7..0f554a7a1d2 100644 --- a/allennlp/modules/backbones/vilbert_backbone.py +++ b/allennlp/modules/backbones/vilbert_backbone.py @@ -7,7 +7,12 @@ from allennlp.data.fields.text_field import TextFieldTensors from allennlp.data.vocabulary import Vocabulary from allennlp.modules.backbones.backbone import Backbone -from allennlp.modules.transformer import BiModalEncoder, ImageFeatureEmbeddings, Embeddings +from allennlp.modules.transformer import ( + BiModalEncoder, + ImageFeatureEmbeddings, + TransformerEmbeddings, + TransformerPooler, +) logger = logging.getLogger(__name__) @@ -23,7 +28,7 @@ class VilbertBackbone(Backbone): def __init__( self, vocab: Vocabulary, - text_embeddings: Embeddings, + text_embeddings: TransformerEmbeddings, image_embeddings: ImageFeatureEmbeddings, encoder: BiModalEncoder, pooled_output_dim: int, @@ -36,7 +41,6 @@ def __init__( self.text_embeddings = text_embeddings self.image_embeddings = image_embeddings self.encoder = encoder - from allennlp.modules.transformer import TransformerPooler self.t_pooler = TransformerPooler(encoder.hidden_size1, pooled_output_dim) self.v_pooler = TransformerPooler(encoder.hidden_size2, pooled_output_dim) @@ -66,44 +70,7 @@ def from_huggingface_model_name( image_fixed_layer: int, fusion_method: str = "sum", ): - from transformers import AutoModel - - transformer = AutoModel.from_pretrained(model_name) - - from copy import deepcopy - - text_embeddings = deepcopy(transformer.embeddings) - - # Albert (and maybe others?) has this "embedding_size", that's different from "hidden_size". - # To get them to the same dimensionality, it uses a linear transform after the embedding - # layer, which we need to pull out and copy here. - if hasattr(transformer.config, "embedding_size"): - config = transformer.config - - from transformers.models.albert.modeling_albert import AlbertModel - - if isinstance(transformer, AlbertModel): - linear_transform = deepcopy(transformer.encoder.embedding_hidden_mapping_in) - else: - logger.warning( - "Unknown model that uses separate embedding size; weights of the linear " - f"transform will not be initialized. Model type is: {transformer.__class__}" - ) - linear_transform = torch.nn.Linear(config.embedding_dim, config.hidden_dim) - - # We can't just use torch.nn.Sequential here, even though that's basically all this is, - # because Sequential doesn't accept *inputs, only a single argument. - - class EmbeddingsShim(torch.nn.Module): - def __init__(self, embeddings: torch.nn.Module, linear_transform: torch.nn.Module): - super().__init__() - self.linear_transform = linear_transform - self.embeddings = embeddings - - def forward(self, *inputs, **kwargs): - return self.linear_transform(self.embeddings(*inputs, **kwargs)) - - text_embeddings = EmbeddingsShim(text_embeddings, linear_transform) + text_embeddings = TransformerEmbeddings.from_pretrained_module(model_name) image_embeddings = ImageFeatureEmbeddings( feature_size=image_feature_dim, @@ -112,7 +79,7 @@ def forward(self, *inputs, **kwargs): ) encoder = BiModalEncoder.from_pretrained_module( - pretrained_module=transformer, + model_name, num_hidden_layers2=image_num_hidden_layers, hidden_size2=image_hidden_size, num_attention_heads2=image_num_attention_heads, @@ -126,6 +93,7 @@ def forward(self, *inputs, **kwargs): fixed_layer1=text_fixed_layer, fixed_layer2=image_fixed_layer, ) + return cls( vocab=vocab, text_embeddings=text_embeddings, From c8704c67585cc0e9fef7b54c1caae0c1f5ad7d78 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Wed, 12 May 2021 16:24:44 -0700 Subject: [PATCH 13/23] fix --- allennlp/modules/transformer/transformer_embeddings.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/allennlp/modules/transformer/transformer_embeddings.py b/allennlp/modules/transformer/transformer_embeddings.py index 12f77c93af3..5ab587d826a 100644 --- a/allennlp/modules/transformer/transformer_embeddings.py +++ b/allennlp/modules/transformer/transformer_embeddings.py @@ -120,7 +120,7 @@ class TransformerEmbeddings(Embeddings): "albert.encoder.embedding_hidden_mapping_in": "linear_transform", } _huggingface_ignore = [ - # Albert + # Ignore these for Albert case. f"^albert\.pooler\..*", f"^albert\.encoder\.albert_layer_groups\..*", f"^predictions\.*", From e09c1640de9dc360f36ffdf8795462c003522103 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Thu, 13 May 2021 11:30:10 -0700 Subject: [PATCH 14/23] fixes --- .../transformer/transformer_embeddings.py | 6 +- .../modules/transformer/transformer_module.py | 8 +- allennlp/nn/util.py | 134 +++++++++++++++--- tests/nn/util_test.py | 80 ++++++++++- 4 files changed, 199 insertions(+), 29 deletions(-) diff --git a/allennlp/modules/transformer/transformer_embeddings.py b/allennlp/modules/transformer/transformer_embeddings.py index 5ab587d826a..80634976846 100644 --- a/allennlp/modules/transformer/transformer_embeddings.py +++ b/allennlp/modules/transformer/transformer_embeddings.py @@ -121,9 +121,9 @@ class TransformerEmbeddings(Embeddings): } _huggingface_ignore = [ # Ignore these for Albert case. - f"^albert\.pooler\..*", - f"^albert\.encoder\.albert_layer_groups\..*", - f"^predictions\.*", + r"^albert\.pooler\..*", + r"^albert\.encoder\.albert_layer_groups\..*", + r"^predictions\.*", ] def __init__( diff --git a/allennlp/modules/transformer/transformer_module.py b/allennlp/modules/transformer/transformer_module.py index 1d6cdd8d3e2..10f357cbdf5 100644 --- a/allennlp/modules/transformer/transformer_module.py +++ b/allennlp/modules/transformer/transformer_module.py @@ -303,9 +303,9 @@ def from_pretrained_module( # this class. This is called recursively on each submodule of the current module. state_dict = model._get_mapped_state_dict(pretrained_state_dict, mapping=mapping) + missing_keys: List[str] + unexpected_keys: List[str] error_msgs: List[str] = [] - missing_keys: List[str] = [] - unexpected_keys: List[str] = [] if not is_distributed() or loading_strategy == DistributedLoadingStrategy.FREE_FOR_ALL: assert state_dict is not None logger.info("Loading state_dict into module") @@ -318,7 +318,9 @@ def from_pretrained_module( dist.barrier() # Now load the state dict into the model. logger.info("Loading state_dict into module (MEMORY_EFFICIENT strategy)") - load_state_dict_distributed(model, state_dict, strict=strict) + missing_keys, unexpected_keys = load_state_dict_distributed( + model, state_dict, strict=False + ) # Allow missing keys in state_dict for params that are going to be tied. for param_names in (model._tied_weights or {}).values(): diff --git a/allennlp/nn/util.py b/allennlp/nn/util.py index a6a88480342..ae0cb396df3 100644 --- a/allennlp/nn/util.py +++ b/allennlp/nn/util.py @@ -9,7 +9,7 @@ import logging from os import PathLike import re -from typing import Any, Dict, List, Optional, Sequence, Tuple, TypeVar, Union +from typing import Any, Dict, List, Optional, Sequence, Tuple, TypeVar, Union, NamedTuple import math import numpy @@ -2180,12 +2180,28 @@ def dist_reduce_sum(value: _V, **kwargs) -> _V: def _collect_state_dict( module: torch.nn.Module, state_dict: Optional[StateDictType], recurse: bool = True -) -> StateDictType: +) -> Tuple[StateDictType, List[str], List[str]]: """ Collect a module's state dict across distributed processes. + + Returns the syncronized state dictionary, which will always be a valid state dict, + and then the missing and unexpected keys corresponding to the original `state_dict`. + Parameters that missing from the original `state_dict` will be populated from the + corresponding parameter in the primary processes' module's state dict. + + !!! Note + + `missing_keys` and `unexpected_keys` are only populated in the primary process. """ # This is the device we'll use for the broadcast operation. - device = distributed_device() + dist_device = distributed_device() + # This is the device we'll put all tensors on in the returned state dict. + state_dict_device = ( + int_to_device(-1) if not state_dict else state_dict[list(state_dict.keys())[0]].device + ) + + missing_keys: List[str] = [] + unexpected_keys: List[str] = [] # Gather current state dict and prepare to iterator over it. # We iterate over this state dict instead of `state_dict` so we can be sure @@ -2203,61 +2219,114 @@ def _collect_state_dict( module.named_buffers(recurse=False), ) ) + keys = list(current_state_dict.keys()) + # Gather unexpected_keys. + if is_global_primary(): + assert state_dict is not None + module_keys = set(module.state_dict().keys()) + for key in state_dict: + if key not in module_keys: + unexpected_keys.append(key) + for key in keys: tensor = current_state_dict[key] if is_global_primary(): assert state_dict is not None if key in state_dict: + # Update `tensor` to the value in `state_dict`. tensor = state_dict[key] else: - logger.warning( - f"Missing key {key} from state_dict (available keys: {list(state_dict.keys())})" - ) - tensor = tensor.to(device) + missing_keys.append(key) + tensor = tensor.to(dist_device) dist.broadcast(tensor, 0) - current_state_dict[key] = tensor + current_state_dict[key] = tensor.to(state_dict_device) + + return current_state_dict, missing_keys, unexpected_keys + - return current_state_dict +class _LoadStateDictResult(NamedTuple): + missing_keys: List[str] + unexpected_keys: List[str] def load_state_dict_distributed( module: torch.nn.Module, state_dict: Optional[StateDictType], strict: bool = True -) -> None: +) -> _LoadStateDictResult: """ Load a `state_dict` to the `module` within a distributed process. Only the global primary process requires the `state_dict` to not be `None`. All other processes will have the state tensors broadcasted to them one-by-one. + + If `strict` is `True`, then the keys of `state_dict` must exactly match the keys + returned by `module.state_dict()`. + + !!! Note + The returned `missing_keys` and `unexpected_keys` will only be accurate + in the primary process. + + # Returns + + `_LoadStateDictResult` + A `NamedTuple` with `missing_keys` and `unexpected_keys` fields, both of which + are lists of strings. + + # Raises + + `RuntimeError` + If `strict` is `True` and there are missing or unexpected keys. + """ if is_global_primary(): assert state_dict is not None else: assert state_dict is None + missing_keys: List[str] = [] + unexpected_keys: List[str] = [] + submodules = dict(module.named_children()) + def update_key_list(original, updates): + for key in updates: + if key not in original: + original.append(key) + # If we've found a sharded module or there aren't any more submodules of the current module, # we collect the state_dict and load it now instead of recursing further. if getattr(module, _MODULE_SHARDED_FLAG, False) or not submodules: - state_dict = _collect_state_dict(module, state_dict) + # Collect. + state_dict, _missing_keys, _unexpected_keys = _collect_state_dict(module, state_dict) assert state_dict is not None - module.load_state_dict(state_dict, strict=strict) + update_key_list(missing_keys, _missing_keys) + update_key_list(unexpected_keys, _unexpected_keys) + # And load. + _missing_keys, _unexpected_keys = module.load_state_dict(state_dict, strict=False) + update_key_list(missing_keys, _missing_keys) + update_key_list(unexpected_keys, _unexpected_keys) else: # We'll recursively call this function on each submodule, but first we need # to collect any parameters that are direct members of this module. - direct_member_state_dict = _collect_state_dict(module, state_dict, recurse=False) - missing_keys, unexpected_keys = module.load_state_dict( + direct_member_state_dict, _missing_keys, _unexpected_keys = _collect_state_dict( + module, state_dict, recurse=False + ) + update_key_list(missing_keys, _missing_keys) + update_key_list(unexpected_keys, _unexpected_keys) + + # `_missing_keys` here will contain any keys corresponding to submodules, but + # we'll remove those below. + _missing_keys, _unexpected_keys = module.load_state_dict( direct_member_state_dict, strict=False ) - if strict and unexpected_keys: - raise RuntimeError( - f"Error(s) in loading state_dict for {module.__class__.__name__}:" - f"\tUnexpected key(s) in state dict: {unexpected_keys}" - ) + update_key_list(missing_keys, _missing_keys) + update_key_list(unexpected_keys, _unexpected_keys) # Okay, now for the recursive part. for name, submodule in submodules.items(): + # Update `missing_keys` to remove keys corresponding to this submodule. + # If they are actually missing after this step, we add them back in below. + missing_keys = [k for k in missing_keys if not k.startswith(name + ".")] submodule_state_dict: Optional[StateDictType] = None if is_global_primary(): assert state_dict is not None @@ -2266,4 +2335,29 @@ def load_state_dict_distributed( for key, value in state_dict.items() if key.startswith(name + ".") } - load_state_dict_distributed(submodule, submodule_state_dict, strict=strict) + _missing_keys, _unexpected_keys = load_state_dict_distributed( + submodule, submodule_state_dict, strict=False + ) + update_key_list(missing_keys, [f"{name}.{key}" for key in _missing_keys]) + update_key_list(unexpected_keys, [f"{name}.{key}" for key in _unexpected_keys]) + + if strict: + error_msgs: List[str] = [] + if missing_keys: + error_msgs.append( + "Missing key(s) in state_dict: {}".format(", ".join(f'"{k}"' for k in missing_keys)) + ) + if unexpected_keys: + error_msgs.append( + "Unexpected key(s) in state_dict: {}".format( + ", ".join(f'"{k}"' for k in unexpected_keys) + ) + ) + if error_msgs: + raise RuntimeError( + "Error(s) in loading state_dict for {}:\n\t{}".format( + module.__class__.__name__, "\n\t".join(error_msgs) + ) + ) + + return _LoadStateDictResult(missing_keys, unexpected_keys) diff --git a/tests/nn/util_test.py b/tests/nn/util_test.py index 7ca660ed04d..73a9952a11f 100644 --- a/tests/nn/util_test.py +++ b/tests/nn/util_test.py @@ -9,7 +9,7 @@ from flaky import flaky from allennlp.common.checks import ConfigurationError -from allennlp.common.testing import AllenNlpTestCase +from allennlp.common.testing import AllenNlpTestCase, run_distributed_test from allennlp.common.util import sanitize from allennlp.data import Token, Vocabulary from allennlp.data.fields import TextField @@ -1730,8 +1730,6 @@ def test_dist_reduce_sum(self): ret_value = util.dist_reduce_sum(value) assert (ret_value == value).all().item() - from allennlp.common.testing.distributed_test import run_distributed_test - func_kwargs = {"value": [torch.Tensor([1, 2, 3]), torch.Tensor([4, 5, 6])]} desired_values = torch.Tensor([5, 7, 9]) @@ -1761,3 +1759,79 @@ def global_distributed_func( output = function(**kwargs) assert (output == desired_values).all().item() + + +class DistributedFixtureModel(torch.nn.Module): + """ + Fake model for testing `load_state_dict_distributed()`. + """ + + def __init__(self): + super().__init__() + self.direct_param = torch.nn.Parameter(torch.randn(3, 5)) + self.register_buffer("direct_buffer", torch.randn(2, 2)) + self.custom_submodule = DistributedFixtureSubmodule() + self.custom_sharded_submodule = DistributedFixtureSubmodule(sharded=True) + self.linear_submodule = torch.nn.Linear(3, 5) + + def forward(self, x): + # This doesn't matter, we're not going to actually use it. + pass + + +class DistributedFixtureSubmodule(torch.nn.Module): + def __init__(self, sharded: bool = False): + super().__init__() + self.direct_param = torch.nn.Parameter(torch.randn(3, 5)) + self.register_buffer("direct_buffer", torch.randn(2, 2)) + self.linear_submodule = torch.nn.Linear(3, 5) + if sharded: + setattr(self, util._MODULE_SHARDED_FLAG, True) + + def forward(self, x): + # This doesn't matter, we're not going to actually use it. + pass + + +def _dist_load_ok(global_rank, world_size, gpu_id): + model = DistributedFixtureModel() + state_dict = None if global_rank != 0 else model.state_dict() + missing_keys, unexpected_keys = util.load_state_dict_distributed(model, state_dict) + assert not missing_keys + assert not unexpected_keys + + +def _dist_load_with_errors(global_rank, world_size, gpu_id): + model = DistributedFixtureModel() + state_dict = None if global_rank != 0 else model.state_dict() + _missing_keys = [ + "direct_buffer", + "custom_submodule.linear_submodule.bias", + "custom_submodule.direct_param", + "custom_sharded_submodule.linear_submodule.bias", + "custom_sharded_submodule.direct_buffer", + ] + _unexpected_keys = [ + "not_a_parameter", + "custom_submodule.not_a_parameter", + "custom_submodule.linear.not_a_parameter", + "custom_sharded_submodule.not_a_parameter", + "custom_sharded_submodule.linear.not_a_parameter", + "not_even_submodule.not_a_parameter", + ] + if state_dict is not None: + for key in _missing_keys: + del state_dict[key] + for key in _unexpected_keys: + state_dict[key] = torch.randn(2, 2) + missing_keys, unexpected_keys = util.load_state_dict_distributed( + model, state_dict, strict=False + ) + if global_rank == 0: + assert set(missing_keys) == set(_missing_keys) + assert set(unexpected_keys) == set(_unexpected_keys) + + +@pytest.mark.parametrize("test_func", [_dist_load_ok, _dist_load_with_errors]) +def test_load_state_dict_distributed(test_func): + run_distributed_test([-1, -1], func=test_func) From 3a35f94ab454ad874eed740fdf889295f4094d1a Mon Sep 17 00:00:00 2001 From: epwalsh Date: Thu, 13 May 2021 11:58:55 -0700 Subject: [PATCH 15/23] fix --- allennlp/common/testing/distributed_test.py | 9 +++++- .../modules/transformer/transformer_module.py | 1 - .../transformer/transformer_layer_test.py | 28 +++++++++++++++++-- 3 files changed, 33 insertions(+), 5 deletions(-) diff --git a/allennlp/common/testing/distributed_test.py b/allennlp/common/testing/distributed_test.py index 7ef00e2e0e8..2fae00ff635 100644 --- a/allennlp/common/testing/distributed_test.py +++ b/allennlp/common/testing/distributed_test.py @@ -61,12 +61,19 @@ def run_distributed_test( func: `Callable` `func` needs to be global for spawning the processes, so that it can be pickled. + + start_method: `Optional[str]`, optional (default = `None`) + The start method to use for starting the workers. Defaults to "spawn" for GPU + processes and fork otherwise. """ device_ids = device_ids or [-1, -1] check_for_gpu(device_ids) # "fork" start method is the default and should be preferred, except when we're # running the tests on GPU, in which case we need to use "spawn". - start_method = "spawn" if any(x >= 0 for x in device_ids) else "fork" + if "start_method" in kwargs: + start_method = kwargs.pop("start_method") + else: + start_method = "spawn" if any(x >= 0 for x in device_ids) else "fork" nprocs = world_size = len(device_ids) mp.start_processes( init_process, diff --git a/allennlp/modules/transformer/transformer_module.py b/allennlp/modules/transformer/transformer_module.py index 10f357cbdf5..2f5fa109d54 100644 --- a/allennlp/modules/transformer/transformer_module.py +++ b/allennlp/modules/transformer/transformer_module.py @@ -232,7 +232,6 @@ def from_pretrained_module( """ Initialize this module from a corresponding model on HuggingFace. - !!! Note This method is only available for subclasses that implement `from_config()`. Otherwise a `NotImplementedError` will be raised. diff --git a/tests/modules/transformer/transformer_layer_test.py b/tests/modules/transformer/transformer_layer_test.py index 10ab837a79c..b43d5b23070 100644 --- a/tests/modules/transformer/transformer_layer_test.py +++ b/tests/modules/transformer/transformer_layer_test.py @@ -9,9 +9,13 @@ from transformers.models.electra.configuration_electra import ElectraConfig from transformers.models.electra.modeling_electra import ElectraAttention, ElectraLayer -from allennlp.common import Params -from allennlp.common import cached_transformers -from allennlp.modules.transformer import AttentionLayer, TransformerLayer +from allennlp.common import Params, cached_transformers +from allennlp.common.testing import run_distributed_test +from allennlp.modules.transformer import ( + AttentionLayer, + TransformerLayer, + DistributedLoadingStrategy, +) ATTENTION_PARAMS_DICT = { @@ -285,3 +289,21 @@ def test_layer_from_pretrained(pretrained_name, relevant_top_level_module): hf_output = pretrained_module(hidden_states, attention_mask=attention_mask_hf)[0] assert torch.allclose(output, hf_output, atol=1e-04) + + +def _load_pretrained(global_rank, world_size, gpu_id): + TransformerLayer.from_pretrained_module( + "epwalsh/bert-xsmall-dummy", + ) + + +def _load_pretrained_mem_efficient(global_rank, world_size, gpu_id): + TransformerLayer.from_pretrained_module( + "epwalsh/bert-xsmall-dummy", + distributed_loading_strategy=DistributedLoadingStrategy.MEMORY_EFFICIENT, + ) + + +@pytest.mark.parametrize("test_func", [_load_pretrained, _load_pretrained_mem_efficient]) +def test_distributed(test_func): + run_distributed_test([-1, -1], func=test_func, start_method="spawn") From 2812b1b9d6276ab044f492d1e4948d4e6286f3d9 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Thu, 13 May 2021 13:31:26 -0700 Subject: [PATCH 16/23] doc fixes --- allennlp/modules/transformer/bimodal_attention.py | 5 +++-- allennlp/modules/transformer/positional_encoding.py | 3 +++ allennlp/modules/transformer/self_attention.py | 2 ++ allennlp/modules/transformer/t5.py | 2 +- allennlp/modules/transformer/transformer_embeddings.py | 2 ++ allennlp/modules/transformer/transformer_layer.py | 4 ++++ allennlp/modules/transformer/transformer_module.py | 6 +++--- allennlp/modules/transformer/transformer_stack.py | 2 ++ allennlp/nn/util.py | 3 +++ scripts/py2md.py | 6 ++++++ 10 files changed, 29 insertions(+), 6 deletions(-) diff --git a/allennlp/modules/transformer/bimodal_attention.py b/allennlp/modules/transformer/bimodal_attention.py index fc6bb4047f9..cc4bf11aa22 100644 --- a/allennlp/modules/transformer/bimodal_attention.py +++ b/allennlp/modules/transformer/bimodal_attention.py @@ -118,10 +118,12 @@ def forward( input_tensor2, attention_mask1=None, attention_mask2=None, - co_attention_mask=None, + co_attention_mask=None, # TODO: is this flag necessary? use_co_attention_mask=False, ): """ + # Parameters + input_tensor1 : `torch.Tensor` Shape `batch_size x seq_len1 x hidden_dim1` where `seq_len1` can be the sequence length @@ -143,7 +145,6 @@ def forward( if you know which words correspond to which regions in the image, this mask can be applied to limit the attention given the bias. use_co_attention_mask : `bool` - # TODO: is this flag necessary? Whether to use co_attention_mask or not, default = `False`. """ diff --git a/allennlp/modules/transformer/positional_encoding.py b/allennlp/modules/transformer/positional_encoding.py index 1cf63b15c91..b0abc2b91b2 100644 --- a/allennlp/modules/transformer/positional_encoding.py +++ b/allennlp/modules/transformer/positional_encoding.py @@ -42,6 +42,9 @@ def __init__(self, min_timescale: float = 1.0, max_timescale: float = 1.0e4): self.max_timescale = max_timescale def forward(self, input_tensor: torch.Tensor): + """ + Adds a positional encoding to `input_tensor`. + """ # TODO: Another option is to specify the expected size in init, so that we can construct # the positional encoding beforehand, and simply add it to the input tensor in forward. _, timesteps, hidden_dim = input_tensor.size() diff --git a/allennlp/modules/transformer/self_attention.py b/allennlp/modules/transformer/self_attention.py index 71c253f9573..16738b634ec 100644 --- a/allennlp/modules/transformer/self_attention.py +++ b/allennlp/modules/transformer/self_attention.py @@ -94,6 +94,8 @@ def forward( output_attentions: bool = False, ): """ + # Parameters + query_states : `torch.Tensor` Shape `batch_size x seq_len x hidden_dim` key_states : `torch.Tensor`, optional diff --git a/allennlp/modules/transformer/t5.py b/allennlp/modules/transformer/t5.py index b92bb97d490..14c6e734535 100644 --- a/allennlp/modules/transformer/t5.py +++ b/allennlp/modules/transformer/t5.py @@ -1,5 +1,5 @@ """ -Adapted from [HuggingFace] +An implementation of [T5](https://api.semanticscholar.org/CorpusID:204838007), adapted from [HuggingFace] (https://github.com/huggingface/transformers/blob/4c32f9f26e6a84f0d9843fec8757e6ce640bb44e/src/transformers/models/t5/modeling_t5.py). """ # noqa: E401 diff --git a/allennlp/modules/transformer/transformer_embeddings.py b/allennlp/modules/transformer/transformer_embeddings.py index 80634976846..bac80be0b31 100644 --- a/allennlp/modules/transformer/transformer_embeddings.py +++ b/allennlp/modules/transformer/transformer_embeddings.py @@ -166,6 +166,8 @@ def forward( # type: ignore ) -> torch.Tensor: """ + # Parameters + input_ids : `torch.Tensor` Shape `batch_size x seq_len` token_type_ids : `torch.Tensor`, optional diff --git a/allennlp/modules/transformer/transformer_layer.py b/allennlp/modules/transformer/transformer_layer.py index bb25262e3b4..5f0cfb1980a 100644 --- a/allennlp/modules/transformer/transformer_layer.py +++ b/allennlp/modules/transformer/transformer_layer.py @@ -53,6 +53,8 @@ def forward( output_attentions: bool = False, ): """ + # Parameters + input_tensor : `torch.Tensor` Shape `batch_size x seq_len x hidden_dim` attention_mask : `torch.BoolTensor`, optional @@ -166,6 +168,8 @@ def forward( output_attentions: bool = False, ): """ + # Parameters + hidden_states : `torch.Tensor` Shape `batch_size x seq_len x hidden_dim` attention_mask : `torch.BoolTensor`, optional diff --git a/allennlp/modules/transformer/transformer_module.py b/allennlp/modules/transformer/transformer_module.py index 2f5fa109d54..f4dd5c57556 100644 --- a/allennlp/modules/transformer/transformer_module.py +++ b/allennlp/modules/transformer/transformer_module.py @@ -233,7 +233,7 @@ def from_pretrained_module( Initialize this module from a corresponding model on HuggingFace. !!! Note - This method is only available for subclasses that implement `from_config()`. + This method is only available for subclasses that implement `_from_config()`. Otherwise a `NotImplementedError` will be raised. # Parameters @@ -258,7 +258,7 @@ def from_pretrained_module( between this module and the pretrained model from HuggingFace. If not given, the class's default is used: `cls._huggingface_mapping`. - relevant_module : `Optionall[str]`, optional (default = `None`) + relevant_module : `Optional[str]`, optional (default = `None`) An optional submodule of the HuggingFace module to initialize weights from. This is only relevant when `load_weights` is `True`. If not given, the class's default is used: `cls._relevant_module`. @@ -272,7 +272,7 @@ def from_pretrained_module( when `load_weights` is `True`. If not specified, this class's default is used: `cls._distributed_loading_strategy`. - **kwargs : Any + **kwargs : `Any` Key word arguments to pass to `cls.from_config()` when instantiating the module. """ # noqa: E501 from transformers import AutoConfig diff --git a/allennlp/modules/transformer/transformer_stack.py b/allennlp/modules/transformer/transformer_stack.py index 33c69703040..8343990e28d 100644 --- a/allennlp/modules/transformer/transformer_stack.py +++ b/allennlp/modules/transformer/transformer_stack.py @@ -89,6 +89,8 @@ def forward( output_hidden_states: bool = False, ): """ + # Parameters + hidden_states : `torch.Tensor` Shape `batch_size x seq_len x hidden_dim` attention_mask : `torch.BoolTensor`, optional diff --git a/allennlp/nn/util.py b/allennlp/nn/util.py index ae0cb396df3..d25239b27f3 100644 --- a/allennlp/nn/util.py +++ b/allennlp/nn/util.py @@ -2278,6 +2278,9 @@ def load_state_dict_distributed( If `strict` is `True` and there are missing or unexpected keys. """ + if not is_distributed(): + return module.load_state_dict(state_dict, strict=strict) + if is_global_primary(): assert state_dict is not None else: diff --git a/scripts/py2md.py b/scripts/py2md.py index 82a31565485..e587184c1a6 100755 --- a/scripts/py2md.py +++ b/scripts/py2md.py @@ -279,6 +279,12 @@ class AllenNlpFilterProcessor(Struct): "__call__", "__iter__", "InfluenceInterpreter._calculate_influence_scores", + "TransformerModule._from_config", + "TransformerModule._huggingface_mapping", + "TransformerModule._relevant_module", + "TransformerModule._distributed_loading_strategy", + "TransformerModule._tied_weights", + "TransformerModule._huggingface_ignore", } def process(self, graph, _resolver): From 267bbe7089f0a5c42c8768d06f4d1bd0095aec2d Mon Sep 17 00:00:00 2001 From: epwalsh Date: Thu, 13 May 2021 14:28:36 -0700 Subject: [PATCH 17/23] name changes --- .../transformer/bimodal_connection_layer.py | 2 +- .../modules/transformer/bimodal_encoder.py | 5 +- allennlp/modules/transformer/layer_norm.py | 2 +- allennlp/modules/transformer/output_layer.py | 2 +- .../modules/transformer/self_attention.py | 4 +- allennlp/modules/transformer/t5.py | 12 ++-- .../transformer/transformer_embeddings.py | 7 +- .../modules/transformer/transformer_layer.py | 8 +-- .../modules/transformer/transformer_module.py | 66 ++++++++++++++----- .../modules/transformer/transformer_stack.py | 4 +- scripts/py2md.py | 7 +- .../transformer/bimodal_encoder_test.py | 4 +- .../transformer_embeddings_test.py | 10 +-- .../transformer/transformer_module_test.py | 4 +- 14 files changed, 86 insertions(+), 51 deletions(-) diff --git a/allennlp/modules/transformer/bimodal_connection_layer.py b/allennlp/modules/transformer/bimodal_connection_layer.py index 5d7e4f7fc88..f9656c2b7a5 100644 --- a/allennlp/modules/transformer/bimodal_connection_layer.py +++ b/allennlp/modules/transformer/bimodal_connection_layer.py @@ -31,7 +31,7 @@ def forward(self, hidden_states1, input_tensor1, hidden_states2, input_tensor2): class BiModalConnectionLayer(TransformerModule, FromParams): - _huggingface_mapping = {"biAttention": "bimodal_attention", "biOutput": "bimodal_output"} + _pretrained_mapping = {"biAttention": "bimodal_attention", "biOutput": "bimodal_output"} def __init__( self, diff --git a/allennlp/modules/transformer/bimodal_encoder.py b/allennlp/modules/transformer/bimodal_encoder.py index 4cff9e1b3a3..acc993194df 100644 --- a/allennlp/modules/transformer/bimodal_encoder.py +++ b/allennlp/modules/transformer/bimodal_encoder.py @@ -48,8 +48,9 @@ class BiModalEncoder(TransformerModule, FromParams): in_batch_pairs: `bool` (default = `False`) """ - _huggingface_mapping = {"layer": "layers1"} - _relevant_module = "encoder" + _pretrained_mapping = {"layer": "layers1"} + _pretrained_relevant_module = ["encoder", "bert.encoder"] + _pretrained_allow_missing = [r"^layers2\..*", r"^c_layer\..*"] def __init__( self, diff --git a/allennlp/modules/transformer/layer_norm.py b/allennlp/modules/transformer/layer_norm.py index a7e40d2ad13..0302b705c1d 100644 --- a/allennlp/modules/transformer/layer_norm.py +++ b/allennlp/modules/transformer/layer_norm.py @@ -4,4 +4,4 @@ class LayerNorm(torch.nn.LayerNorm, TransformerModule): - _huggingface_mapping = {"gamma": "weight", "beta": "bias"} + _pretrained_mapping = {"gamma": "weight", "beta": "bias"} diff --git a/allennlp/modules/transformer/output_layer.py b/allennlp/modules/transformer/output_layer.py index 2f4b8afdede..ac38a1794b1 100644 --- a/allennlp/modules/transformer/output_layer.py +++ b/allennlp/modules/transformer/output_layer.py @@ -8,7 +8,7 @@ class OutputLayer(TransformerModule, FromParams): - _huggingface_mapping = {"LayerNorm": "layer_norm"} + _pretrained_mapping = {"LayerNorm": "layer_norm"} def __init__(self, input_size: int, hidden_size: int, dropout: float): super().__init__() diff --git a/allennlp/modules/transformer/self_attention.py b/allennlp/modules/transformer/self_attention.py index 16738b634ec..d464012de81 100644 --- a/allennlp/modules/transformer/self_attention.py +++ b/allennlp/modules/transformer/self_attention.py @@ -29,8 +29,8 @@ class SelfAttention(TransformerModule, FromParams): Eg. `additive`, `linear`, etc. For a complete list, please check :mod:`allennlp.modules.attention`. """ - _relevant_module = ["encoder.layers.0.attention.self", "encoder.layers.0.attention"] - _huggingface_mapping = { + _pretrained_relevant_module = ["encoder.layers.0.attention.self", "encoder.layers.0.attention"] + _pretrained_mapping = { "layer": "layers", "q_lin": "query", "k_lin": "key", diff --git a/allennlp/modules/transformer/t5.py b/allennlp/modules/transformer/t5.py index 14c6e734535..faf134f81c2 100644 --- a/allennlp/modules/transformer/t5.py +++ b/allennlp/modules/transformer/t5.py @@ -97,7 +97,7 @@ def forward(self, hidden_states) -> FloatT: class T5LayerFF(TransformerModule, FromParams): - _huggingface_mapping = {"DenseReluDense": "ff_proj"} + _pretrained_mapping = {"DenseReluDense": "ff_proj"} def __init__( self, @@ -379,7 +379,7 @@ class T5LayerSelfAttentionOutput: class T5LayerSelfAttention(TransformerModule, FromParams): - _huggingface_mapping = {"SelfAttention": "self_attention"} + _pretrained_mapping = {"SelfAttention": "self_attention"} def __init__( self, @@ -433,7 +433,7 @@ class T5LayerCrossAttentionOutput: class T5LayerCrossAttention(TransformerModule, FromParams): - _huggingface_mapping = {"EncDecAttention": "enc_dec_attention"} + _pretrained_mapping = {"EncDecAttention": "enc_dec_attention"} def __init__( self, @@ -624,7 +624,7 @@ class T5StackOutput: class T5Stack(TransformerModule, FromParams): - _huggingface_mapping = {"embed_tokens": "token_embeddings", "block": "blocks"} + _pretrained_mapping = {"embed_tokens": "token_embeddings", "block": "blocks"} def __init__( self, @@ -965,7 +965,7 @@ class T5Output: class T5(TransformerModule, Registrable): - _huggingface_mapping = {"shared": "token_embeddings"} + _pretrained_mapping = {"shared": "token_embeddings"} _tied_weights = { "token_embeddings.weight": [ "encoder.token_embeddings.weight", @@ -974,7 +974,7 @@ class T5(TransformerModule, Registrable): ] } # Don't know why HF has this param in their state_dict. It's not used in their model. - _huggingface_ignore = [ + _pretrained_ignore = [ r"^decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight$" ] diff --git a/allennlp/modules/transformer/transformer_embeddings.py b/allennlp/modules/transformer/transformer_embeddings.py index bac80be0b31..3712d9b0a3a 100644 --- a/allennlp/modules/transformer/transformer_embeddings.py +++ b/allennlp/modules/transformer/transformer_embeddings.py @@ -104,8 +104,8 @@ class TransformerEmbeddings(Embeddings): Optionally apply a linear transform after the dropout, projecting to `output_size`. """ - _relevant_module = "embeddings" - _huggingface_mapping = { + _pretrained_relevant_module = ["embeddings", "bert.embeddings"] + _pretrained_mapping = { "LayerNorm": "layer_norm", "word_embeddings": "embeddings.word_embeddings", "position_embeddings": "embeddings.position_embeddings", @@ -119,7 +119,7 @@ class TransformerEmbeddings(Embeddings): "albert.embeddings.token_type_embeddings": "embeddings.token_type_embeddings", "albert.encoder.embedding_hidden_mapping_in": "linear_transform", } - _huggingface_ignore = [ + _pretrained_ignore = [ # Ignore these for Albert case. r"^albert\.pooler\..*", r"^albert\.encoder\.albert_layer_groups\..*", @@ -167,7 +167,6 @@ def forward( # type: ignore """ # Parameters - input_ids : `torch.Tensor` Shape `batch_size x seq_len` token_type_ids : `torch.Tensor`, optional diff --git a/allennlp/modules/transformer/transformer_layer.py b/allennlp/modules/transformer/transformer_layer.py index 5f0cfb1980a..43a76d33144 100644 --- a/allennlp/modules/transformer/transformer_layer.py +++ b/allennlp/modules/transformer/transformer_layer.py @@ -29,8 +29,8 @@ class AttentionLayer(TransformerModule, FromParams): Dropout probability for the `OutputLayer`. """ - _relevant_module = "encoder.layer.0.attention" - _huggingface_mapping = {"layer": "layers"} + _pretrained_relevant_module = "encoder.layer.0.attention" + _pretrained_mapping = {"layer": "layers"} def __init__( self, @@ -114,8 +114,8 @@ class TransformerLayer(TransformerModule, FromParams): This is helpful when using the layer in a decoder. """ - _relevant_module = "encoder.layer.0" - _huggingface_mapping = { + _pretrained_relevant_module = "encoder.layer.0" + _pretrained_mapping = { "layer": "layers", "intermediate_act_fn": "act_fn", "crossattention": "cross_attention", diff --git a/allennlp/modules/transformer/transformer_module.py b/allennlp/modules/transformer/transformer_module.py index f4dd5c57556..85952b77551 100644 --- a/allennlp/modules/transformer/transformer_module.py +++ b/allennlp/modules/transformer/transformer_module.py @@ -3,6 +3,7 @@ import os from os import PathLike from typing import TYPE_CHECKING, Optional, Dict, Union, List, Any, TypeVar, Type +import re import warnings import torch @@ -57,20 +58,31 @@ class TransformerModule(torch.nn.Module): `from_pretrained_module()`. """ - _huggingface_mapping: Dict[str, str] = {} + _pretrained_mapping: Dict[str, str] = {} """ An optional mapping for each class that determines any differences in the module names between the class modules and the HuggingFace model's modules. Keys correspond to HuggingFace submodule names, values correspond to submodules names of this module. """ - _relevant_module: Optional[Union[str, List[str]]] = None + _pretrained_relevant_module: Optional[Union[str, List[str]]] = None """ An optional string or list of strings which contains the expected name of the module in the HuggingFace pretrained model. It can be a list to account for different names in different models. The search is carried out in the order of the list. """ + _pretrained_ignore: Optional[List[str]] = None + """ + An optional list of regular expressions that define which weights to ignore from a pretrained state_dict. + """ + + _pretrained_allow_missing: Optional[List[str]] = None + """ + An optional list of regular expressions that specifies which weights are allowed to be missing + from a pretrained state dictionary. + """ + _distributed_loading_strategy: DistributedLoadingStrategy = ( DistributedLoadingStrategy.FREE_FOR_ALL ) @@ -84,11 +96,6 @@ class TransformerModule(torch.nn.Module): The values will be tied to the corresponding key. """ - _huggingface_ignore: Optional[List[str]] = None - """ - An optional list of regular expressions that define which weights to ignore from a pretrained state_dict. - """ - @classmethod def _get_mapping( cls, @@ -99,7 +106,7 @@ def _get_mapping( and the default module-level mapping. """ combined_mapping = {} - combined_mapping.update(cls._huggingface_mapping) + combined_mapping.update(cls._pretrained_mapping) if mapping is not None: combined_mapping.update(mapping) return combined_mapping @@ -129,10 +136,10 @@ def _get_relevant_submodule_state( relevant_modules = ( [relevant_module] if isinstance(relevant_module, str) else relevant_module ) - elif isinstance(cls._relevant_module, str): - relevant_modules = [cls._relevant_module] - elif isinstance(cls._relevant_module, list): - relevant_modules = cls._relevant_module + elif isinstance(cls._pretrained_relevant_module, str): + relevant_modules = [cls._pretrained_relevant_module] + elif isinstance(cls._pretrained_relevant_module, list): + relevant_modules = cls._pretrained_relevant_module if relevant_modules: found = False @@ -163,6 +170,7 @@ def _get_pretrained_state_dict( model_name: str, weights_path: Optional[Union[str, PathLike]] = None, relevant_module: Optional[Union[str, List[str]]] = None, + ignore: Optional[List[str]] = None, ) -> StateDictType: """ Get a HuggingFace pretrained `state_dict` corresponding to this module. @@ -186,7 +194,11 @@ def _get_pretrained_state_dict( # Now load the state dict. logger.info("Reading state dict from %s", weights_path) - state_dict = read_state_dict(weights_path, ignore=cls._huggingface_ignore, strict=False) + state_dict = read_state_dict( + weights_path, + ignore=ignore if ignore is not None else cls._pretrained_ignore, + strict=False, + ) # Keep just the relevant_module, remove everything else. state_dict = cls._get_relevant_submodule_state(state_dict, relevant_module=relevant_module) @@ -225,6 +237,8 @@ def from_pretrained_module( auto_config_kwargs: Optional[Dict[str, Any]] = None, mapping: Optional[Dict[str, str]] = None, relevant_module: Optional[Union[str, List[str]]] = None, + ignore: Optional[List[str]] = None, + allow_missing: Optional[List[str]] = None, strict: bool = True, distributed_loading_strategy: Optional[Union[str, DistributedLoadingStrategy]] = None, **kwargs, @@ -256,12 +270,24 @@ def from_pretrained_module( mapping : `Optional[Dict[str, str]]`, optional (default = `None`) Optional mapping that determines any differences in the submodule names between this module and the pretrained model from HuggingFace. - If not given, the class's default is used: `cls._huggingface_mapping`. + If not given, the class's default is used: `cls._pretrained_mapping`. relevant_module : `Optional[str]`, optional (default = `None`) An optional submodule of the HuggingFace module to initialize weights from. This is only relevant when `load_weights` is `True`. - If not given, the class's default is used: `cls._relevant_module`. + If not given, the class's default is used: `cls._pretrained_relevant_module`. + + ignore : `Optional[List[str]]`, optional (default = `None`) + An optional list of regular expressions that define which weights to ignore + from a pretrained state_dict. + This is only relevant when `load_weights` is `True`. + If not specified, the class's default is used: `cls._pretrained_ignore`. + + allow_missing: `Optional[List[str]]`, optional (default = `None`) + An optional list of regular expressions that specifies which weights are allowed to be missing + from the pretrained state dictionary. + This is only relevant when `load_weights` is `True`. + If not specified, the class's default is used: `cls._pretrained_allow_missing`. strict : `bool`, optional (default = `True`) Whether to load the `state_dict` in "strict" model. This only applies @@ -297,6 +323,7 @@ def from_pretrained_module( model_name, weights_path=weights_path, relevant_module=relevant_module, + ignore=ignore, ) # Now map keys from the HuggingFace state_dict to the corresponding keys from # this class. This is called recursively on each submodule of the current module. @@ -321,6 +348,15 @@ def from_pretrained_module( model, state_dict, strict=False ) + # Exclude any keys in `missing_keys` that match with the `allow_missing` + # regular expressions. + if allow_missing is None: + allow_missing = cls._pretrained_allow_missing + if allow_missing: + missing_keys = [ + k for k in missing_keys if not any(re.match(p, k) for p in allow_missing) + ] + # Allow missing keys in state_dict for params that are going to be tied. for param_names in (model._tied_weights or {}).values(): for param_name in param_names: diff --git a/allennlp/modules/transformer/transformer_stack.py b/allennlp/modules/transformer/transformer_stack.py index 8343990e28d..7bc4a7247d3 100644 --- a/allennlp/modules/transformer/transformer_stack.py +++ b/allennlp/modules/transformer/transformer_stack.py @@ -41,8 +41,8 @@ class TransformerStack(TransformerModule, FromParams): This is helpful when using the `TransformerStack` as a decoder. """ - _huggingface_mapping = {"layer": "layers"} - _relevant_module = ["encoder", "bert.encoder"] + _pretrained_mapping = {"layer": "layers"} + _pretrained_relevant_module = ["encoder", "bert.encoder"] def __init__( self, diff --git a/scripts/py2md.py b/scripts/py2md.py index e587184c1a6..c8bc1ca1d43 100755 --- a/scripts/py2md.py +++ b/scripts/py2md.py @@ -280,11 +280,12 @@ class AllenNlpFilterProcessor(Struct): "__iter__", "InfluenceInterpreter._calculate_influence_scores", "TransformerModule._from_config", - "TransformerModule._huggingface_mapping", - "TransformerModule._relevant_module", + "TransformerModule._pretrained_mapping", + "TransformerModule._pretrained_relevant_module", + "TransformerModule._pretrained_ignore", + "TransformerModule._pretrained_allow_missing", "TransformerModule._distributed_loading_strategy", "TransformerModule._tied_weights", - "TransformerModule._huggingface_ignore", } def process(self, graph, _resolver): diff --git a/tests/modules/transformer/bimodal_encoder_test.py b/tests/modules/transformer/bimodal_encoder_test.py index 3ac682cccbf..39bd3b54e8c 100644 --- a/tests/modules/transformer/bimodal_encoder_test.py +++ b/tests/modules/transformer/bimodal_encoder_test.py @@ -75,9 +75,7 @@ def test_loading_from_pretrained_weights(params_dict): ] kwargs = {key: params_dict[key] for key in required_kwargs} - module = BiModalEncoder.from_pretrained_module( - "bert-base-cased", relevant_module="bert.encoder", strict=False, **kwargs - ) + module = BiModalEncoder.from_pretrained_module("bert-base-cased", **kwargs) assert_allclose( module.layers1[0].intermediate.dense.weight.data, pretrained_module.layer[0].intermediate.dense.weight.data, diff --git a/tests/modules/transformer/transformer_embeddings_test.py b/tests/modules/transformer/transformer_embeddings_test.py index 73eb84fe908..d37eae8629b 100644 --- a/tests/modules/transformer/transformer_embeddings_test.py +++ b/tests/modules/transformer/transformer_embeddings_test.py @@ -148,14 +148,14 @@ def test_no_token_type_layer(params): @pytest.mark.parametrize( - "pretrained_name, relevant_module", + "pretrained_name", [ - ("bert-base-cased", "bert.embeddings"), - ("epwalsh/bert-xsmall-dummy", None), + "bert-base-cased", + "epwalsh/bert-xsmall-dummy", ], ) -def test_loading_from_pretrained_module(pretrained_name, relevant_module): - TransformerEmbeddings.from_pretrained_module(pretrained_name, relevant_module=relevant_module) +def test_loading_from_pretrained_module(pretrained_name): + TransformerEmbeddings.from_pretrained_module(pretrained_name) def test_loading_albert(): diff --git a/tests/modules/transformer/transformer_module_test.py b/tests/modules/transformer/transformer_module_test.py index 307c8295ad8..4018229c41d 100644 --- a/tests/modules/transformer/transformer_module_test.py +++ b/tests/modules/transformer/transformer_module_test.py @@ -20,7 +20,7 @@ def forward(self, x): return x class InternalNew(TransformerModule): - _huggingface_mapping = {"ff": "linear", "p": "param", "b": "buffer"} + _pretrained_mapping = {"ff": "linear", "p": "param", "b": "buffer"} def __init__(self, inp, out): super().__init__() @@ -43,7 +43,7 @@ def forward(self, x): return x class ExternalNew(TransformerModule): - _huggingface_mapping = {"internal": "internal_layer", "p": "param"} + _pretrained_mapping = {"internal": "internal_layer", "p": "param"} def __init__(self, inp, out): super().__init__() From 3734353a093f5b005221fd68f5ac1fd26007acde Mon Sep 17 00:00:00 2001 From: epwalsh Date: Thu, 13 May 2021 14:43:42 -0700 Subject: [PATCH 18/23] patch models branch temporarily --- .github/workflows/ci.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index cdcebbfff40..6a39298c0d2 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -152,6 +152,8 @@ jobs: run: | git clone https://github.com/allenai/allennlp-models.git cd allennlp-models + # TODO: remove this + git checkout transformer-init pip install --upgrade --upgrade-strategy eager -e . -r dev-requirements.txt - name: Run models tests From d0f9f462ee8a0e3bcdb1429d4cc4fb940d2c8d3f Mon Sep 17 00:00:00 2001 From: epwalsh Date: Thu, 13 May 2021 14:50:01 -0700 Subject: [PATCH 19/23] update CHANGELOG --- CHANGELOG.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2ca16d3c4f8..479306bc318 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,12 +11,16 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Use `dist_reduce_sum` in distributed metrics. - Allow Google Cloud Storage paths in `cached_path` ("gs://..."). +- Renamed `nn.util.load_state_dict()` to `read_state_dict` to avoid confusion with `torch.nn.Module.load_state_dict()`. +- `TransformerModule.from_pretrained_module` now only accepts a pretrained model ID (e.g. "bert-base-case") instead of + an actual `torch.nn.Module`. Other parameters to this method have changed as well. - Print the first batch to the console by default. ### Added - Added `TaskSuite` base class and command line functionality for running [`checklist`](https://github.com/marcotcr/checklist) test suites, along with implementations for `SentimentAnalysisSuite`, `QuestionAnsweringSuite`, and `TextualEntailmentSuite`. These can be found in the `allennlp.sanity_checks.task_checklists` module. - Added `allennlp diff` command to compute a diff on model checkpoints, analogous to what `git diff` does on two files. +- Added `nn.util.distributed_device()` helper function. - Added `allennlp.nn.util.load_state_dict` helper function. - Added a way to avoid downloading and loading pretrained weights in modules that wrap transformers such as the `PretrainedTransformerEmbedder` and `PretrainedTransformerMismatchedEmbedder`. From 16c09e54d0af91f77426748287f8a2c038d0ccc6 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Thu, 13 May 2021 14:52:08 -0700 Subject: [PATCH 20/23] change default dist loading strategy to 'MEM_EFFICIENT' for T5 --- allennlp/modules/transformer/t5.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/allennlp/modules/transformer/t5.py b/allennlp/modules/transformer/t5.py index faf134f81c2..575656688ec 100644 --- a/allennlp/modules/transformer/t5.py +++ b/allennlp/modules/transformer/t5.py @@ -14,7 +14,10 @@ from allennlp.common import FromParams, Params, Lazy, Registrable from allennlp.common.checks import ConfigurationError -from allennlp.modules.transformer import TransformerModule +from allennlp.modules.transformer.transformer_module import ( + TransformerModule, + DistributedLoadingStrategy, +) from allennlp.modules.transformer.util import ( apply_mask, get_extended_attention_mask, @@ -977,6 +980,7 @@ class T5(TransformerModule, Registrable): _pretrained_ignore = [ r"^decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight$" ] + _distributed_loading_strategy = DistributedLoadingStrategy.MEMORY_EFFICIENT default_implementation = "default" From a30dd77584ea3a1866c6499b0acd297e75d8de9d Mon Sep 17 00:00:00 2001 From: epwalsh Date: Fri, 14 May 2021 08:10:08 -0700 Subject: [PATCH 21/23] fix distilbert test --- tests/modules/transformer/self_attention_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/modules/transformer/self_attention_test.py b/tests/modules/transformer/self_attention_test.py index 2bceb73d2b8..7a3dcb81ec8 100644 --- a/tests/modules/transformer/self_attention_test.py +++ b/tests/modules/transformer/self_attention_test.py @@ -71,7 +71,7 @@ def test_loading_from_pretrained_weights_using_model_name(pretrained_name, relev seq_len = 3 dim = module.query.in_features hidden_states = torch.randn(batch_size, seq_len, dim) - attention_mask = torch.randint(0, 2, (batch_size, 1, 1, seq_len)) + attention_mask = torch.tensor([[1, 1, 0], [1, 0, 1]])[:, None, None, :] # setting to eval mode to avoid non-deterministic dropout. module = module.eval() From 3bcbca2c8f2b2c686179ac92f9acdcfc1d78bb4b Mon Sep 17 00:00:00 2001 From: epwalsh Date: Fri, 14 May 2021 16:53:32 -0700 Subject: [PATCH 22/23] always use memory efficient distributed loading strategy --- allennlp/modules/transformer/__init__.py | 5 +- allennlp/modules/transformer/t5.py | 2 - .../modules/transformer/transformer_module.py | 55 +------------------ .../transformer/transformer_layer_test.py | 10 +--- 4 files changed, 4 insertions(+), 68 deletions(-) diff --git a/allennlp/modules/transformer/__init__.py b/allennlp/modules/transformer/__init__.py index e748cfa9989..9b944130c7c 100644 --- a/allennlp/modules/transformer/__init__.py +++ b/allennlp/modules/transformer/__init__.py @@ -125,10 +125,7 @@ def forward(self, token_ids: torch.LongTensor, mask: torch.BoolTensor): from allennlp.modules.transformer.layer_norm import LayerNorm from allennlp.modules.transformer.positional_encoding import SinusoidalPositionalEncoding -from allennlp.modules.transformer.transformer_module import ( - TransformerModule, - DistributedLoadingStrategy, -) +from allennlp.modules.transformer.transformer_module import TransformerModule from allennlp.modules.transformer.transformer_embeddings import ( Embeddings, TransformerEmbeddings, diff --git a/allennlp/modules/transformer/t5.py b/allennlp/modules/transformer/t5.py index 575656688ec..15d34f5b2b1 100644 --- a/allennlp/modules/transformer/t5.py +++ b/allennlp/modules/transformer/t5.py @@ -16,7 +16,6 @@ from allennlp.common.checks import ConfigurationError from allennlp.modules.transformer.transformer_module import ( TransformerModule, - DistributedLoadingStrategy, ) from allennlp.modules.transformer.util import ( apply_mask, @@ -980,7 +979,6 @@ class T5(TransformerModule, Registrable): _pretrained_ignore = [ r"^decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight$" ] - _distributed_loading_strategy = DistributedLoadingStrategy.MEMORY_EFFICIENT default_implementation = "default" diff --git a/allennlp/modules/transformer/transformer_module.py b/allennlp/modules/transformer/transformer_module.py index 85952b77551..2a0ffa092ce 100644 --- a/allennlp/modules/transformer/transformer_module.py +++ b/allennlp/modules/transformer/transformer_module.py @@ -1,4 +1,3 @@ -from enum import Enum import logging import os from os import PathLike @@ -22,34 +21,6 @@ _T = TypeVar("_T", bound="TransformerModule") -class DistributedLoadingStrategy(Enum): - """ - Strategy options for loading state dictionaries across distributed processes. - """ - - FREE_FOR_ALL = "FREE_FOR_ALL" - """ - Each process loads its own state dict from disk. - """ - - MEMORY_EFFICIENT = "MEMORY_EFFICIENT" - """ - Only the primary process loads the state dict from disk, then it broadcasts - each state tensor one-by-one to the other process groups. - - This is particularly useful when you have multiple distributed workers on the same - machine (shared CPU memory), and don't have enough memory for each process to load - its own copy of the state dict at the same time. - """ - - @classmethod - def from_str(cls, s: str) -> "DistributedLoadingStrategy": - for option in cls: - if option.value.lower() == s.lower(): - return option - raise ValueError(f"Unknown distributed loading strategy: '{s}'") - - class TransformerModule(torch.nn.Module): """ Base class to help with generalized loading of pretrained weights. @@ -83,13 +54,6 @@ class TransformerModule(torch.nn.Module): from a pretrained state dictionary. """ - _distributed_loading_strategy: DistributedLoadingStrategy = ( - DistributedLoadingStrategy.FREE_FOR_ALL - ) - """ - The default strategy for loading a state dictionary within a distributed process group. - """ - _tied_weights: Optional[Dict[str, List[str]]] = None """ A mapping that defines any weights that need to be tied. Keys and values are parameter names. @@ -240,7 +204,6 @@ def from_pretrained_module( ignore: Optional[List[str]] = None, allow_missing: Optional[List[str]] = None, strict: bool = True, - distributed_loading_strategy: Optional[Union[str, DistributedLoadingStrategy]] = None, **kwargs, ) -> _T: """ @@ -293,11 +256,6 @@ def from_pretrained_module( Whether to load the `state_dict` in "strict" model. This only applies when `load_weights` is `True`. - distributed_loading_strategy : `Optional[Union[str, DistributedLoadingStrategy]]`, optional (default = `None`) - The loading strategy to use within a distributed process group. This only applies - when `load_weights` is `True`. If not specified, this class's default is used: - `cls._distributed_loading_strategy`. - **kwargs : `Any` Key word arguments to pass to `cls.from_config()` when instantiating the module. """ # noqa: E501 @@ -307,17 +265,8 @@ def from_pretrained_module( model = cls._from_config(config, **kwargs) if load_weights: - # Resolve the loading strategy to use. - loading_strategy: DistributedLoadingStrategy - if isinstance(distributed_loading_strategy, DistributedLoadingStrategy): - loading_strategy = distributed_loading_strategy - elif isinstance(distributed_loading_strategy, str): - loading_strategy = DistributedLoadingStrategy.from_str(distributed_loading_strategy) - else: - loading_strategy = cls._distributed_loading_strategy - state_dict: Optional[StateDictType] = None - if is_global_primary() or loading_strategy == DistributedLoadingStrategy.FREE_FOR_ALL: + if is_global_primary(): # Load the pretrained HuggingFace state_dict. pretrained_state_dict = cls._get_pretrained_state_dict( model_name, @@ -332,7 +281,7 @@ def from_pretrained_module( missing_keys: List[str] unexpected_keys: List[str] error_msgs: List[str] = [] - if not is_distributed() or loading_strategy == DistributedLoadingStrategy.FREE_FOR_ALL: + if not is_distributed(): assert state_dict is not None logger.info("Loading state_dict into module") missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) diff --git a/tests/modules/transformer/transformer_layer_test.py b/tests/modules/transformer/transformer_layer_test.py index b43d5b23070..4c1e141a5a8 100644 --- a/tests/modules/transformer/transformer_layer_test.py +++ b/tests/modules/transformer/transformer_layer_test.py @@ -14,7 +14,6 @@ from allennlp.modules.transformer import ( AttentionLayer, TransformerLayer, - DistributedLoadingStrategy, ) @@ -297,13 +296,6 @@ def _load_pretrained(global_rank, world_size, gpu_id): ) -def _load_pretrained_mem_efficient(global_rank, world_size, gpu_id): - TransformerLayer.from_pretrained_module( - "epwalsh/bert-xsmall-dummy", - distributed_loading_strategy=DistributedLoadingStrategy.MEMORY_EFFICIENT, - ) - - -@pytest.mark.parametrize("test_func", [_load_pretrained, _load_pretrained_mem_efficient]) +@pytest.mark.parametrize("test_func", [_load_pretrained]) def test_distributed(test_func): run_distributed_test([-1, -1], func=test_func, start_method="spawn") From 36c7ad4d1658af00dcefd4843ae4f451212104f7 Mon Sep 17 00:00:00 2001 From: Akshita Bhagia Date: Mon, 17 May 2021 12:02:31 -0700 Subject: [PATCH 23/23] Update .github/workflows/ci.yml Co-authored-by: Pete --- .github/workflows/ci.yml | 2 -- 1 file changed, 2 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 6a39298c0d2..cdcebbfff40 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -152,8 +152,6 @@ jobs: run: | git clone https://github.com/allenai/allennlp-models.git cd allennlp-models - # TODO: remove this - git checkout transformer-init pip install --upgrade --upgrade-strategy eager -e . -r dev-requirements.txt - name: Run models tests