Changes and improvements to how we initialize transformer modules from pretrained models #5200
Conversation
if "start_method" in kwargs: | ||
start_method = kwargs.pop("start_method") | ||
else: | ||
start_method = "spawn" if any(x >= 0 for x in device_ids) else "fork" |
This is kind of hacky, but I had an issue on Mac where the only work-around was to use "spawn" as the start method, and I didn't want to change the signature of this function because it's used in a lot of places.
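To spell out the pattern (the wrapper below is just a sketch with an illustrative name and signature, not the actual helper this lives in):

```python
import torch.multiprocessing as mp


def run_distributed(worker_fn, device_ids, nprocs, **kwargs):
    # Let callers force a particular start method (e.g. "spawn" on Mac),
    # otherwise pick one based on whether any GPUs are in use.
    if "start_method" in kwargs:
        start_method = kwargs.pop("start_method")
    else:
        # CUDA doesn't play well with "fork", so use "spawn" when any device
        # ID refers to a GPU; "fork" is faster for CPU-only runs.
        start_method = "spawn" if any(x >= 0 for x in device_ids) else "fork"
    mp.start_processes(
        worker_fn, args=(device_ids,), nprocs=nprocs, start_method=start_method
    )
```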
Looks great overall! I have a few initial thoughts/comments.
    Strategy options for loading state dictionaries across distributed processes.
    """

    FREE_FOR_ALL = "FREE_FOR_ALL"
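For context, this hunk is part of a small string enum of loading strategies; a minimal sketch (the class name here is illustrative, and only the `FREE_FOR_ALL` and `MEM_EFFICIENT` options come up in this PR):

```python
from enum import Enum


class DistributedLoadingStrategy(str, Enum):
    """
    Strategy options for loading state dictionaries across distributed processes.
    """

    # Every worker loads the full pretrained state dict on its own: fastest,
    # but each worker holds an extra full copy of the weights in memory.
    FREE_FOR_ALL = "FREE_FOR_ALL"

    # Only the primary worker loads the state dict, then broadcasts the
    # weights to the other workers one-by-one.
    MEM_EFFICIENT = "MEM_EFFICIENT"
```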
In which cases will we choose the `FREE_FOR_ALL` option? Or to put it differently, why can't we be memory efficient all the time?
Hmm, I guess I was thinking that since `FREE_FOR_ALL` is faster, it will probably be preferred when you have enough memory.

I should benchmark the two approaches. If there's not a significant difference for medium-sized models then I agree, we should use the memory efficient strategy all the time. I'll get back to you on that.
There wasn't much of a difference for `t5-base` on my Mac with 2 workers (3.55 seconds vs 3.87 seconds). So I think it's worth it to always use the memory efficient strategy.
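If anyone wants to reproduce this, a rough way to time the two strategies looks like the following (the `load_pretrained` call is a placeholder, not the real API):

```python
import time


def time_loading(load_fn, repeats=3):
    # Best wall-clock time over a few runs of a loading function.
    best = float("inf")
    for _ in range(repeats):
        start = time.monotonic()
        load_fn()
        best = min(best, time.monotonic() - start)
    return best


# Hypothetical usage:
# fast = time_loading(lambda: load_pretrained("t5-base", strategy="FREE_FOR_ALL"))
# lean = time_loading(lambda: load_pretrained("t5-base", strategy="MEM_EFFICIENT"))
# print(f"free-for-all: {fast:.2f}s, mem-efficient: {lean:.2f}s")
```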
kwargs = {key: params_dict[key] for key in required_kwargs}

module = BiModalEncoder.from_pretrained_module("bert-base-cased", **kwargs)
assert_allclose(
We only compare a single parameter instead of all? I think comparing all of them is essential, especially when we account for cases where we are loading partial weights from a huggingface model.
Hmm, so the main reason I changed that here is that I ended up removing/modifying the `_construct_default_mapping` method, which was needed to do this comparison.

That said, I think we have enough test coverage because `.from_pretrained_module()` will throw an error if there are missing or unexpected keys or size mismatches in the part of the pretrained state dictionary corresponding to the current module. And we also have tests that ensure the outputs from `.forward()` match. What do you think?
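To make that concrete, the safety net amounts to something like this in plain PyTorch (a simplified sketch; the real check lives inside `.from_pretrained_module()`, and the helper name here is made up):

```python
import torch


def load_and_verify(module: torch.nn.Module, state_dict: dict) -> None:
    # strict=False returns the keys that didn't line up instead of raising,
    # so both problems can be surfaced in one error message. Size mismatches
    # still raise inside load_state_dict itself.
    missing, unexpected = module.load_state_dict(state_dict, strict=False)
    if missing or unexpected:
        raise ValueError(f"missing keys: {missing}, unexpected keys: {unexpected}")

    # Optional extra paranoia: confirm every loaded tensor matches the source.
    for name, tensor in module.state_dict().items():
        assert torch.allclose(tensor, state_dict[name]), f"mismatch in {name}"
```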
Yeah, fair enough. I think the coverage is pretty good.
def from_pretrained_module(  # type: ignore
    cls,
    pretrained_module: Union[str, torch.nn.Module],
    num_hidden_layers: Optional[Union[int, range]] = None,
I think we are losing the ability to load only some of the layers here. I thought we had a test for this; looks like we don't. But this is the functionality I'm talking about:
"bert-base-uncased", num_hidden_layers=range(0, 8) |
Ah, sorry, I should have left a comment about this. This is still possible, and we have a test for it here (allennlp/tests/modules/transformer/toolkit_test.py, lines 85 to 97 in d2f43c7):
self.separate_transformer = TransformerStack.from_pretrained_module(
    "bert-base-cased",
    relevant_module="bert.encoder",
    num_hidden_layers=8,
    strict=False,
)
self.combined_transformer = TransformerStack.from_pretrained_module(
    "bert-base-cased",
    relevant_module="bert.encoder",
    num_hidden_layers=4,
    mapping={f"layer.{l}": f"layers.{i}" for (i, l) in enumerate(range(8, 12))},
    strict=False,
)
I changed this up a bit so that we didn't have to override the `from_pretrained_module` method and change the signature. Also, I don't think passing a `range` type would work through config files anyway, right? But if this way is too clunky we could come up with an alternate, easier way to do this.
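In case the mapping idiom is unclear, here's a toy illustration of the key renaming it implies (the `apply_mapping` helper below is just for illustration, not the actual implementation):

```python
# Maps pretrained layers 8-11 onto local layers 0-3.
mapping = {f"layer.{l}": f"layers.{i}" for (i, l) in enumerate(range(8, 12))}
# -> {"layer.8": "layers.0", "layer.9": "layers.1", "layer.10": "layers.2", "layer.11": "layers.3"}


def apply_mapping(state_dict: dict, mapping: dict) -> dict:
    # Rewrite each key by swapping in the mapped prefix, so e.g.
    # "layer.8.attention.self.query.weight" becomes "layers.0.attention.self.query.weight".
    out = {}
    for key, tensor in state_dict.items():
        for old, new in mapping.items():
            if key == old or key.startswith(old + "."):
                key = new + key[len(old):]
                break
        out[key] = tensor
    return out
```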
Ok, this works!
Co-authored-by: Pete <[email protected]>
…m pretrained models (allenai#5200)

* updates
* rename 'load_state_dict' -> 'read_state_dict'
* fix TransformerStack
* more fixes
* fix embeddings
* fix toolkit tests
* fix self attention
* fix bimodal encoder tests
* fix more tests
* fix T5!
* fixes
* fix backbone
* fix
* fixes
* fix
* doc fixes
* name changes
* patch models branch temporarily
* update CHANGELOG
* change default dist loading strategy to 'MEM_EFFICIENT' for T5
* fix distilbert test
* always use memory efficient distributed loading strategy
* Update .github/workflows/ci.yml

Co-authored-by: Pete <[email protected]>
Co-authored-by: Akshita Bhagia <[email protected]>
Changes proposed in this pull request:
- Changed the way `TransformerModule` is initialized from a pretrained HuggingFace model: we initialize the architecture of the `TransformerModule` from the corresponding HuggingFace config and then load the state dictionary directly, instead of initializing our architecture and weights from the corresponding HuggingFace model instance. In the non-distributed case, this cuts down the maximum memory usage of model weights by 1/3, since we only need 2 versions of the weights in memory while we are loading the pretrained weights (our uninitialized weights and the pretrained `state_dict` from HuggingFace) as opposed to 3 copies (our uninitialized weights, the uninitialized weights of the HuggingFace model, and the pretrained `state_dict` from HuggingFace).

- Added an even more memory efficient way of loading a state dictionary in the distributed scenario. This works by loading the state dictionary in only the primary distributed worker, and then broadcasting the weights one-by-one to the other workers (see the sketch below). The logic for this is agnostic to the model architecture, so we could easily use it in other models. It's wrapped up in a util function: `allennlp.nn.util.load_state_dict_distributed`.
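Roughly, the distributed loading described above works like this (a simplified sketch, not the actual `allennlp.nn.util.load_state_dict_distributed` implementation; device placement and backend details are omitted):

```python
import torch
import torch.distributed as dist


def load_state_dict_distributed_sketch(module: torch.nn.Module, checkpoint_path: str) -> None:
    # Only the primary worker reads the (potentially huge) checkpoint from
    # disk and loads it into its copy of the module.
    if dist.get_rank() == 0:
        state_dict = torch.load(checkpoint_path, map_location="cpu")
        module.load_state_dict(state_dict)
        del state_dict  # drop the extra copy as soon as possible

    # Broadcast the weights one-by-one: the primary worker sends its (now
    # initialized) tensors, and every other worker receives them in place,
    # so no worker ever holds a second full copy of the state dict.
    for tensor in module.state_dict().values():
        dist.broadcast(tensor, src=0)
```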
Before submitting

- … section of the `CONTRIBUTING` docs.
- … Writing docstrings section of the `CONTRIBUTING` docs.

After submitting

- `codecov/patch` reports high test coverage (at least 90%). You can find this under the "Actions" tab of the pull request once the other checks have finished.