This repository was archived by the owner on Dec 16, 2022. It is now read-only.

Changes and improvements to how we initialize transformer modules from pretrained models #5200

Merged
merged 27 commits into main from transformer-init on May 17, 2021

Conversation

epwalsh (Member) commented May 13, 2021

Changes proposed in this pull request:

  • Changed the way TransformerModule is initialized from a pretrained HuggingFace model. We initialize the architecture of the TransformerModule from the corresponding HuggingFace config and then load the state dictionary directly instead of initializing our architecture and weights from the corresponding HuggingFace model instance.

    In the non-distributed case, this cuts down the maximum memory usage of model weights by 1/3 since we only need 2 versions of the weights in memory while we are loading the pretrained weights (our uninitialized weights and the pretrained state_dict from HuggingFace) as opposed to 3 copies (our uninitialized weights, the uninitialized weights of the HuggingFace model, and the pretrained state_dict from HuggingFace).

  • Added an even more memory efficient way of loading a state dictionary in the distributed scenario. This works by loading the state dictionary in only the primary distributed worker, and then broadcasting the weights one-by-one to the other workers.

    The logic for this is agnostic to the model architecture, so we could easily use it in other models. It's wrapped up in a util function: allennlp.nn.util.load_state_dict_distributed (see the sketch below).
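A minimal sketch of how such a util can work, assuming the usual torch.distributed primitives. The helper name, signature, and the way the state dict is obtained are simplified stand-ins for the real allennlp.nn.util.load_state_dict_distributed, not its actual API:

import torch
import torch.distributed as dist

def load_state_dict_distributed_sketch(module: torch.nn.Module, state_dict=None) -> None:
    """Load `state_dict` (present only on the primary worker) into `module`,
    broadcasting each tensor to the other workers one at a time so that only
    the primary worker ever holds a full copy of the pretrained weights."""
    is_primary = not dist.is_initialized() or dist.get_rank() == 0
    assert not is_primary or state_dict is not None

    with torch.no_grad():
        for name, param in module.state_dict().items():
            if is_primary:
                tensor = state_dict[name].to(dtype=param.dtype, device=param.device)
            else:
                # Non-primary workers just allocate an empty buffer of the right
                # shape and dtype to receive the broadcast into.
                tensor = torch.empty_like(param)
            if dist.is_initialized():
                dist.broadcast(tensor, src=0)
            # The tensors returned by `module.state_dict()` share storage with the
            # module's parameters, so copying into them updates the module in place.
            param.copy_(tensor)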

Before submitting

  • I've read and followed all steps in the Making a pull request
    section of the CONTRIBUTING docs.
  • I've updated or added any relevant docstrings following the syntax described in the
    Writing docstrings section of the CONTRIBUTING docs.
  • If this PR fixes a bug, I've added a test that will fail without my fix.
  • If this PR adds a new feature, I've added tests that sufficiently cover my new functionality.

After submitting

  • All GitHub Actions jobs for my pull request have passed.
  • codecov/patch reports high test coverage (at least 90%).
    You can find this under the "Actions" tab of the pull request once the other checks have finished.

@epwalsh epwalsh requested a review from AkshitaB May 13, 2021 22:18
Comment on lines +73 to +76
if "start_method" in kwargs:
start_method = kwargs.pop("start_method")
else:
start_method = "spawn" if any(x >= 0 for x in device_ids) else "fork"
epwalsh (Member Author):

This is kind of hacky, but I had an issue on Mac where the only work-around was to use "spawn" as the start method, and I didn't want to change the signature of this function because it's used in a lot of places.

@AkshitaB AkshitaB (Contributor) left a comment

Looks great overall! I have a few initial thoughts/comments.

Strategy options for loading state dictionaries across distributed processes.
"""

FREE_FOR_ALL = "FREE_FOR_ALL"
AkshitaB (Contributor):

In which cases will we choose the FREE_FOR_ALL option? Or to put it differently, why can't we be memory efficient all the time?

epwalsh (Member Author):

Hmm, I guess I was thinking that since FREE_FOR_ALL is faster, it will probably be preferred when you have enough memory.

I should benchmark the two approaches. If there's not a significant difference for medium-sized models then I agree, we should use the memory-efficient strategy all the time. I'll get back to you on that.

epwalsh (Member Author):

There wasn't much of a difference for t5-base on my Mac with 2 workers (3.55 seconds vs 3.87 seconds). So I think it's worth it to always use the memory efficient strategy.
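For context, the difference between the two strategies boils down to roughly the following. This is an illustrative sketch, not the AllenNLP API: torch.load stands in for however the checkpoint is actually read, and it reuses the load_state_dict_distributed_sketch helper from the sketch after the PR description.

import torch
import torch.distributed as dist

def load_free_for_all(module: torch.nn.Module, weights_path: str) -> None:
    # FREE_FOR_ALL: every worker reads the full checkpoint itself. There is no
    # broadcast overhead, but peak host memory scales with the number of
    # workers on the machine.
    module.load_state_dict(torch.load(weights_path, map_location="cpu"))

def load_mem_efficient(module: torch.nn.Module, weights_path: str) -> None:
    # MEM_EFFICIENT: only the primary worker reads the checkpoint; the weights
    # are then broadcast tensor-by-tensor to the other workers.
    is_primary = not dist.is_initialized() or dist.get_rank() == 0
    state_dict = torch.load(weights_path, map_location="cpu") if is_primary else None
    load_state_dict_distributed_sketch(module, state_dict)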

kwargs = {key: params_dict[key] for key in required_kwargs}

module = BiModalEncoder.from_pretrained_module("bert-base-cased", **kwargs)
assert_allclose(
AkshitaB (Contributor):

We only compare a single parameter instead of all? I think comparing all of them is essential, especially when we account for cases where we are loading partial weights from a huggingface model.

epwalsh (Member Author):

Hmm, so the main reason I changed that here is that I ended up removing/modifying the _construct_default_mapping method, which was needed to do this comparison.

That said, I think we have enough test coverage because .from_pretrained_module() will throw an error if there are missing or unexpected keys or size mismatches in the part of the pretrained state dictionary corresponding to the current module. And we also have tests that ensure the outputs from .forward() match. What do you think?
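For reference, the missing/unexpected-key reporting being relied on here is essentially what torch.nn.Module.load_state_dict already surfaces. A generic illustration (plain PyTorch, not the AllenNLP code):

import torch

linear = torch.nn.Linear(4, 4)
partial_state = {"weight": torch.zeros(4, 4)}  # "bias" is deliberately missing

# With strict=False, PyTorch reports the discrepancies instead of raising;
# .from_pretrained_module() turns this kind of report into a hard error.
result = linear.load_state_dict(partial_state, strict=False)
print(result.missing_keys)     # ['bias']
print(result.unexpected_keys)  # []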

AkshitaB (Contributor):

Yeah, fair enough. I think the coverage is pretty good.

def from_pretrained_module(  # type: ignore
    cls,
    pretrained_module: Union[str, torch.nn.Module],
    num_hidden_layers: Optional[Union[int, range]] = None,
AkshitaB (Contributor):

I think we are losing the ability to load only some of the layers here. I thought we had a test for this; looks like we don't. But this is the functionality I'm talking about:

"bert-base-uncased", num_hidden_layers=range(0, 8)

epwalsh (Member Author):

Ah, sorry I should have left a comment about this. This is still possible, and we have a test for it here:

self.separate_transformer = TransformerStack.from_pretrained_module(
"bert-base-cased",
relevant_module="bert.encoder",
num_hidden_layers=8,
strict=False,
)
self.combined_transformer = TransformerStack.from_pretrained_module(
"bert-base-cased",
relevant_module="bert.encoder",
num_hidden_layers=4,
mapping={f"layer.{l}": f"layers.{i}" for (i, l) in enumerate(range(8, 12))},
strict=False,
)

I changed this up a bit so that we didn't have to override the from_pretrained_module method and change the signature. Also, I don't think passing a range type would work through config files anyway, right? But if this way is too clunky we could come up with an alternate, easier way to do this.

AkshitaB (Contributor):

Ok, this works!

@AkshitaB AkshitaB merged commit cf113d7 into main May 17, 2021
@AkshitaB AkshitaB deleted the transformer-init branch May 17, 2021 19:25
Abhishek-P pushed a commit to Abhishek-P/allennlp that referenced this pull request Aug 11, 2021
Changes and improvements to how we initialize transformer modules from pretrained models (allenai#5200)

* updates

* rename 'load_state_dict' -> 'read_state_dict'

* fix TransformerStack

* more fixes

* fix embeddings

* fix toolkit tests

* fix self attention

* fix bimodal encoder tests

* fix more tests

* fix T5!

* fixes

* fix backbone

* fix

* fixes

* fix

* doc fixes

* name changes

* patch models branch temporarily

* update CHANGELOG

* change default dist loading strategy to 'MEM_EFFICIENT' for T5

* fix distilbert test

* always use memory efficient distributed loading strategy

* Update .github/workflows/ci.yml

Co-authored-by: Pete <[email protected]>

Co-authored-by: Akshita Bhagia <[email protected]>