set transformer to evaluation mode #5073
Conversation
Nice job! I just have a couple comments.
CHANGELOG.md (outdated)
@@ -7,6 +7,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## Unreleased

### Changed
- If `train_parameters` in PretrainedTransformerEmbedder is `False`, the transformer's dropout and batch normalization layers are now set to evaluation mode.
We already have a `### Changed` section on line 13 below, you can just put this bullet point there.
Done!
for param in self.transformer_model.parameters():
    param.requires_grad = False
This might not be necessary anymore
Got it. This Stack Overflow post says that it's still best practice to turn off gradient computation. Thoughts?
Reading that, it sounds like we may want to use `with torch.no_grad()` in the forward pass?

I'm also curious what happens after this class is initialized as a submodule of a `Model`, when `Model.train()` is called. Will that revert the `.eval()` call within this submodule?
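(For reference, here's a toy sketch of why that's a real concern; the `Embedder`/`Model` class names are illustrative stand-ins, not the actual AllenNLP classes. `Module.train()` recurses into every child, so a later `Model.train()` call does revert an `.eval()` set at construction time.)

```python
import torch

class Embedder(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.transformer_model = torch.nn.Sequential(torch.nn.Dropout(0.1))
        self.transformer_model.eval()  # frozen submodule starts in eval mode

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.embedder = Embedder()

model = Model()
print(model.embedder.transformer_model.training)  # False right after __init__
model.train()
print(model.embedder.transformer_model.training)  # True -- the .eval() call was reverted
```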
Might be worth having a test for that
Nice catch! I fixed this issue, albeit in a hacky way (https://stackoverflow.com/questions/61980943/how-can-i-keep-a-pytorch-submodule-in-eval-mode, https://stackoverflow.com/questions/394770/override-a-method-at-instance-level). Let me know what you think. I also added a test case to ensure that PretrainedTransformerEmbedder's transformer model remains in eval mode even when the module that instantiates it is switched to training mode.
Using `with torch.no_grad()` temporarily sets `requires_grad=False`. Using it conditionally is pretty cumbersome because it would require setting up a ContextManager wrapper (https://stackoverflow.com/questions/22226708/can-a-with-statement-be-used-conditionally). I think this is why the original code explicitly sets `requires_grad=False` for all parameters. I also don't see any other conditional uses of `with torch.no_grad()`. That's why I'm against it. @AkshitaB do you have an opinion on this?
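(For reference, a minimal sketch of the conditional wrapper the linked answer describes; the `embed` helper and its arguments are hypothetical, not code from this PR.)

```python
import contextlib
import torch

def embed(transformer_model: torch.nn.Module,
          token_ids: torch.Tensor,
          train_parameters: bool) -> torch.Tensor:
    # Use a real no_grad() context when the transformer is frozen,
    # and a no-op context when its parameters are being trained.
    grad_context = contextlib.nullcontext() if train_parameters else torch.no_grad()
    with grad_context:
        return transformer_model(token_ids)
```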
One thing you could do is just rename the existing `.forward()` method to `._forward()`, and then implement a new `.forward()` like:

def forward(self, ...):
    if self.train_parameters:
        return self._forward(...)
    else:
        with torch.no_grad():
            return self._forward(...)
... that said, I just realized there are other places in our code-base that would expect `requires_grad` to be `False` when this module is initialized, if we're not training it. So yea, I think we do need to keep these lines.
…tps://github.com/allenai/allennlp into arjuns/pretrained_transformer_embedder_eval_mode
…former model remains in eval mode even when module that instantiates it is switched to training mode
tests/modules/token_embedders/pretrained_transformer_embedder_test.py (outdated; resolved)
# Override train in transformer_model to prevent it from changing modes
def _train(self, mode):
    return self

setattr(
    self.transformer_model, "train", types.MethodType(_train, self.transformer_model)
)
This makes me a bit nervous. Another option would be to add something like this to the `.forward()` method below (after line 182):

if not self.train_parameters and self.training and self.transformer_model.training:
    self.transformer_model.eval()
I like this approach much better, thanks for the suggestion :)
One issue with this, though, is that inspecting `.training` will not always produce the expected result.
# Calling transformer_model.eval() won't change anything now,
# so we have to explicitly set training = False
Hmm, why doesn't `.eval()` work here?
Because `.eval()` just calls `.train(False)`.
Calling `.eval()` or `.train(False)` will recursively call the same on all submodules, and also set `training = False` for all submodules (https://github.com/pytorch/pytorch/blob/d490e0120f32dcbb8b23e11eebd638b96b4b0898/torch/nn/modules/module.py#L1594). Isn't that what we want?
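(For context, a stripped-down, runnable paraphrase of the linked source; the real `torch.nn.Module` also validates `mode` and tracks children via registered submodules, but the recursion is the relevant part here.)

```python
class Module:
    """Simplified stand-in for torch.nn.Module to show how train()/eval() recurse."""

    def __init__(self):
        self.training = True
        self._children = []

    def children(self):
        return iter(self._children)

    def train(self, mode: bool = True):
        # Sets `training` on this module and recursively on every child.
        self.training = mode
        for module in self.children():
            module.train(mode)
        return self

    def eval(self):
        # eval() is literally just train(False).
        return self.train(False)
```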
Yes, you're right, that is what we want. I just updated the code to be more robust.
…test.py Co-authored-by: Evan Pete Walsh <[email protected]>
…tps://github.com/allenai/allennlp into arjuns/pretrained_transformer_embedder_eval_mode
…tps://github.com/allenai/allennlp into arjuns/pretrained_transformer_embedder_eval_mode
def _train(self, mode):
    self.training = False
    for module in self.children():
        module.train(False)
    return self

setattr(
    self.transformer_model, "train", types.MethodType(_train, self.transformer_model)
)
Instead of patching the `.train()` method on `self.transformer_model`, why not just override the `.train()` method of `PretrainedTransformerEmbedder`? We could implement it like:

@overrides
def train(self, mode: bool = True):
    self.training = mode
    for name, module in self.named_children():
        if not self.train_parameters and name == "transformer_model":
            module.eval()
        else:
            module.train(mode)
    return self
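(If it helps, a hypothetical check along the lines of the test mentioned earlier, assuming the override above is in place; the model name is arbitrary and instantiating the embedder will download weights.)

```python
from allennlp.modules.token_embedders import PretrainedTransformerEmbedder

embedder = PretrainedTransformerEmbedder("bert-base-uncased", train_parameters=False)
embedder.train()  # would normally flip every submodule into training mode

assert embedder.training is True
assert embedder.transformer_model.training is False  # frozen transformer stays in eval mode
```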
Ooh, okay, I agree that this is indeed a great solution. Thanks! Just incorporated it.
…s/pretrained_transformer_embedder_eval_mode
allennlp/modules/token_embedders/pretrained_transformer_embedder.py (outdated; resolved)
LGTM! Nice job!
I realize I should have spoken up sooner about this, but would it be possible to change this to add another parameter that controls eval vs. non-eval mode, versus overriding the default behavior of

There are also settings where you want the default behavior (i.e., non-eval mode, but frozen parameters). In particular, there's been recent work on trying to do parameter-efficient fine-tuning by, for instance, only fine-tuning the bias terms of the transformer (https://nlp.biu.ac.il/~yogo/bitfit.pdf). In this case, it's much more ergonomic to set

Lastly, I think that semantically
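(As a concrete illustration of the "frozen parameters, but not eval mode" setting described above, a hypothetical BitFit-style setup on the underlying Hugging Face model might look like the sketch below; none of this is code from the PR.)

```python
from transformers import AutoModel

transformer_model = AutoModel.from_pretrained("bert-base-uncased")

# Freeze everything except the bias terms (BitFit-style), while leaving the
# model in normal training mode so dropout still behaves as during training.
for name, param in transformer_model.named_parameters():
    param.requires_grad = name.endswith(".bias")
transformer_model.train()  # explicitly NOT eval mode
```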
Fixes #4895.

Changes proposed in this pull request:

- If `train_parameters` in PretrainedTransformerEmbedder is `False`, the transformer's dropout and batch normalization layers are now set to evaluation mode.