
set transformer to evaluation mode #5073

Merged

Conversation

ArjunSubramonian (Contributor)

Fixes #4895.

Changes proposed in this pull request:

  • If train_parameters in PretrainedTransformerEmbedder is False, the transformer's dropout and batch normalization layers are now set to evaluation mode.
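
For reference, a simplified sketch of the idea (not the exact PR code; it assumes the underlying Hugging Face model is stored on self.transformer_model, as PretrainedTransformerEmbedder does):

# Inside PretrainedTransformerEmbedder.__init__, simplified sketch:
if not train_parameters:
    # Freeze the weights so the optimizer never updates them.
    for param in self.transformer_model.parameters():
        param.requires_grad = False
    # Put dropout and batch-norm layers into inference behavior.
    self.transformer_model.eval()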

@epwalsh (Member) left a comment

Nice job! I just have a couple comments.

CHANGELOG.md Outdated
@@ -7,6 +7,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## Unreleased

### Changed
- If `train_parameters` in PretrainedTransformerEmbedder is `False`, the transformer's dropout and batch normalization layers are now set to evaluation mode.
Member

We already have a ### Changed section on line 13 below, you can just put this bullet point there.

Contributor Author

Done!

Comment on lines 128 to 129
for param in self.transformer_model.parameters():
    param.requires_grad = False
Member

This might not be necessary anymore

Contributor Author

Got it. This Stack Overflow post says that it's still best practice to turn off gradient computation. Thoughts?

Member

Reading that, sounds like we may want to use with torch.no_grad() in the forward pass?

I'm also curious what happens after this class is initialized as part of a submodule for a Model, when Model.train() is called. Will that revert the .eval() call within this submodule?

Member

Might be worth having a test for that

Contributor Author

Nice catch! I fixed this issue, albeit in a hacky way (https://stackoverflow.com/questions/61980943/how-can-i-keep-a-pytorch-submodule-in-eval-mode, https://stackoverflow.com/questions/394770/override-a-method-at-instance-level). Let me know what you think. I also added a test case to ensure that PretrainedTransformerEmbedder's transformer model remains in eval mode even when the module that instantiates it is switched to training mode.
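
A hypothetical sketch of what such a test could look like (the actual test added in the PR may differ; the model name is just an example):

import torch
from allennlp.modules.token_embedders import PretrainedTransformerEmbedder

def test_transformer_stays_in_eval_mode_when_frozen():
    embedder = PretrainedTransformerEmbedder("bert-base-uncased", train_parameters=False)
    # Stand-in for a Model that owns the embedder as a submodule.
    wrapper = torch.nn.Sequential(embedder)
    wrapper.train()  # switches submodules to training mode...
    # ...but the frozen transformer should stay in eval mode.
    assert not embedder.transformer_model.training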

Contributor Author

Using with torch.no_grad() temporarily disables gradient computation. Using it conditionally is pretty cumbersome because it would require setting up a ContextManager wrapper (https://stackoverflow.com/questions/22226708/can-a-with-statement-be-used-conditionally). I think this is why the original code explicitly sets requires_grad=False for all parameters. I also don't see any other conditional uses of with torch.no_grad(). That's why I'm against it. @AkshitaB, do you have an opinion on this?
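
(For context, the kind of conditional wrapper being discussed could look roughly like this inside the embedder's forward method, using contextlib.nullcontext from Python 3.7+; this is just a sketch and the variable names are illustrative:)

import contextlib
import torch

# Disable gradient tracking only when the transformer is frozen;
# otherwise use a no-op context so gradients flow normally.
grad_context = torch.no_grad() if not self.train_parameters else contextlib.nullcontext()
with grad_context:
    output = self.transformer_model(input_ids=token_ids, attention_mask=mask)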

Member

One thing you could do is just rename the existing .forward() method to ._forward(), and then implement a new .forward() like:

def forward(self, ...):
    if self.train_parameters:
        return self._forward(...)
    else:
        with torch.no_grad():
            return self._forward(...)

Member

... that said, I just realized there are other places in our code-base that would expect requires_grad to be False when this module is initialized, if we're not training it. So yea, I think we do need to keep these lines.

AkshitaB and others added 4 commits March 30, 2021 16:51
…former model remains in eval mode even when module that instantiates it is switched to training mode
Comment on lines 132 to 138
# Override train in transformer_model to prevent it from changing modes
def _train(self, mode):
    return self

setattr(
    self.transformer_model, "train", types.MethodType(_train, self.transformer_model)
)
Member

This makes me a bit nervous. Another option would be to add something like this to the .forward() method below (after line 182):

if not self.train_parameters and self.training and self.transformer_model.training:
    self.transformer_model.eval()

Contributor Author

I like this approach much better, thanks for the suggestion :)

Contributor Author

One issue with this, though, is that inspecting .training will not always produce the expected result.

Comment on lines 128 to 129
# Calling transformer_model.eval() won't change anything now,
# so we have to explicitly set training = False
Member

Hmm why doesn't .eval() work here?

Contributor Author

Because .eval() just calls .train(False).

Member

Calling .eval() or .train(False) will recursively call the same on all submodules, and also set training = False for all submodules (https://github.com/pytorch/pytorch/blob/d490e0120f32dcbb8b23e11eebd638b96b4b0898/torch/nn/modules/module.py#L1594). Isn't that what we want?
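
(A quick illustration of that recursive behavior:)

import torch

model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Dropout(p=0.5))
model.eval()  # equivalent to model.train(False)
assert model.training is False
assert all(m.training is False for m in model.modules())  # submodules are flipped too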

Contributor Author

Yes, you're right, that is what we want. I just updated the code to be more robust.

Comment on lines 129 to 137
def _train(self, mode):
    self.training = False
    for module in self.children():
        module.train(False)
    return self

setattr(
    self.transformer_model, "train", types.MethodType(_train, self.transformer_model)
)
Member

Instead of patching the .train() method on self.transformer_model, why not just override the .train() method of PretrainedTransformerEmbedder?

We could implement it like:

@overrides
def train(self, mode: bool = True):
    self.training = mode
    for name, module in self.named_children():
        if not self.train_parameters and name == "transformer_model":
            module.eval()
        else:
            module.train(mode)
    return self

Contributor Author

Ooh, okay, I agree that this is indeed a great solution. Thanks! Just incorporated it.

Arjun Subramonian added 2 commits March 31, 2021 12:17
@epwalsh (Member) left a comment

LGTM! Nice job!

@ArjunSubramonian ArjunSubramonian merged commit 63a3b48 into main Mar 31, 2021
@ArjunSubramonian ArjunSubramonian deleted the arjuns/pretrained_transformer_embedder_eval_mode branch March 31, 2021 20:33
nelson-liu (Contributor) commented Apr 8, 2021

I realize I should have spoken up sooner about this, but would it be possible to change this to add another parameter that controls eval vs. non-eval mode, rather than overriding the default behavior of train_parameters? (Asking before the next release goes out, after which this would become a breaking change.) In particular, this behavior is pretty different from the setting where you use Elmo with requires_grad=False. In that setting, the Elmo weights actually aren't in eval mode.

There are also settings where you want the default behavior (i.e., non-eval mode, but frozen parameters). In particular, there's been recent work on trying to do parameter-efficient fine-tuning by, for instance, only fine-tuning the bias terms of the transformer (https://nlp.biu.ac.il/~yogo/bitfit.pdf). In this case, it's much more ergonomic to set train_parameters to False and then have a regex for [["^_text_field_embedder.token_embedder_tokens.transformer_model.*bias"], {"requires_grad": true}], versus having to write a regex like [["^_text_field_embedder.token_embedder_tokens.transformer_model.*(?<!bias)$"], {"requires_grad": false}].

Lastly, I think that semantically train_parameters doesn't necessarily imply eval mode---they seem like distinct ways of modifying the model (e.g., how the Elmo embedder works right now).
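
(For context, a plain-PyTorch sketch of the bias-only, BitFit-style setup described above; transformer_model here stands for the underlying Hugging Face model, and in AllenNLP this would normally be expressed through parameter-group regexes in the config rather than code:)

# Freeze everything except bias terms (BitFit-style parameter-efficient fine-tuning).
for name, param in transformer_model.named_parameters():
    param.requires_grad = name.endswith("bias")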


Successfully merging this pull request may close these issues.

Option to run PretrainedTransformer in eval mode