Recommended way to implement gradient accumulation #506

Closed
fabiocapsouza opened this issue Aug 8, 2019 · 4 comments

Comments

@fabiocapsouza

fabiocapsouza commented Aug 8, 2019

Hi,

I am new to Skorch and I am implementing BERT fine-tuning using Skorch.
One feature that is missing is gradient accumulation, where loss.backward() is called on every batch but the optimizer steps only once every N batches.
I believe I have to override train_step_single to remove the zero_grad() call before every step, and train_step to perform optimizer_.step() conditionally.
Also, the Net has to keep track of the number of batches seen during training, i.e. keep state that can be accessed inside train_step.
Where is the recommended place to save this kind of state?
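For concreteness, here is a rough, untested sketch of what I mean (acc_steps and the batch_count_ counter are made-up names; the counter is exactly the kind of state I am asking about, and I am assuming train_step_single runs forward and backward but does not zero the gradients itself):

```python
from skorch import NeuralNetClassifier

class GradAccNet(NeuralNetClassifier):
    # Untested sketch: step the optimizer only every acc_steps batches.
    def __init__(self, *args, acc_steps=2, **kwargs):
        super().__init__(*args, **kwargs)
        self.acc_steps = acc_steps
        self.batch_count_ = 0  # the state I am asking about

    def get_loss(self, *args, **kwargs):
        # scale the loss so that the accumulated gradients average over
        # the accumulation window rather than summing up
        return super().get_loss(*args, **kwargs) / self.acc_steps

    def train_step(self, Xi, yi, **fit_params):
        # train_step_single runs forward + backward, so gradients keep
        # accumulating in the parameters' .grad attributes
        step = self.train_step_single(Xi, yi, **fit_params)
        self.batch_count_ += 1
        if self.batch_count_ % self.acc_steps == 0:
            # note: plain .step(), so closure-based optimizers such as
            # LBFGS would not be supported by this sketch
            self.optimizer_.step()
            self.optimizer_.zero_grad()
        return step
```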

Thanks,

@BenjaminBossan
Collaborator

I didn't know about gradient accumulation, so I don't have any code for you. That being said, the number of training batches can be inferred from the history: len(self.history[-1, 'batches']) (this assumes that training steps are performed before validation steps, which is what normally happens).
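For illustration, something like this, inside a method of the net during training:

```python
# history[-1] is the row for the current epoch; its 'batches' list
# gets one entry per batch, so its length is the 1-based index of
# the training batch currently being processed
n_train_batches = len(self.history[-1, 'batches'])
```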

Regarding running loss.backward() several times before calling the optimizer: that, I think, is not so trivial. Apart from overriding train_step, you might even need to override fit_loop, because fit_loop assumes that the optimizer is called once per batch.

I don't know enough about the topic to give you a solution. But if you tinker a bit and find a working solution, please post it here; maybe we can then give you pointers for improving it.

@BenjaminBossan
Collaborator

@fabiocapsouza Did you make progress on this?

@fabiocapsouza
Author

@BenjaminBossan, thanks for your first reply and sorry for the delay.
After your comment, I decided to postpone my migration to Skorch for this task, since it reinforced my impression that it would require a lot of work. I would not have the time to test all the needed changes, which could easily introduce bugs.

Thanks again for your quick response, it really helped me.

@BenjaminBossan
Collaborator

since it reinforced my impression that it would require a lot of work.

I think it will be worth exploring whether we can refactor skorch to make this easier. Until then, you are probably better off using your current solution. If you want to try skorch again at some point and maybe help improve it, feel free to reopen this issue or start a new one.
