Thoughts on the on_grad_computed API #378

Closed
thomasjpfan opened this issue Oct 29, 2018 · 2 comments

thomasjpfan commented Oct 29, 2018

The `on_grad_computed` API passes both `net` and `named_parameters`, where `named_parameters` is built by calling `list(self.module_.named_parameters())`. The callback can already access these parameters through `net.module_.named_parameters()`.

Calling `list` on the named parameters adds measurable overhead to each training loop:

```python
import torchvision

m = torchvision.models.densenet201()
%timeit b = list(m.named_parameters())
# 2.47 ms ± 33.1 µs per loop
```

This overhead is incurred per batch, so for an epoch with 500 batches it adds roughly 1.24 seconds (500 × 2.47 ms) to that epoch.

I propose simplifying `on_grad_computed` by removing `named_parameters`. It is a fairly small change, but it does break backwards compatibility.
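
Under the proposal, a callback would fetch the parameters itself from the net. A minimal sketch of what that could look like (the callback and the `Fake*` classes below are simplified stand-ins for illustration, not skorch's real classes):

```python
class ParamNameLogger:
    """Hypothetical callback under the proposed simplified signature:
    it receives only `net` and pulls parameters from the module itself."""

    def on_grad_computed(self, net, **kwargs):
        # Iterate lazily; no list() is materialized on every batch.
        return [name for name, _ in net.module_.named_parameters()]


class FakeModule:
    """Stand-in for a torch.nn.Module exposing named_parameters()."""

    def named_parameters(self):
        yield "weight", [1.0, 2.0]
        yield "bias", [0.5]


class FakeNet:
    """Stand-in for a fitted skorch NeuralNet holding a module_."""

    def __init__(self):
        self.module_ = FakeModule()


names = ParamNameLogger().on_grad_computed(FakeNet())
print(names)  # ['weight', 'bias']
```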

ottonemo commented Oct 29, 2018

The reason we do it this way is that the receiver does not know how many modules there are, whether the training loop was overridden, or whether there is more than one gradient update per loop for different parameter sets. For example, in a generator/discriminator cycle you have two gradient computation steps, each over a different set of parameters.
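
That situation can be sketched as follows (the `notify` helper and `FakeModule` class here are simplified stand-ins, not skorch's real dispatch): each gradient step reports its own parameter set, which the callback could not recover from `net.module_.named_parameters()` alone.

```python
def notify(event, named_parameters, log):
    # Stand-in for skorch's callback dispatch: record which
    # parameters each gradient step touched.
    log.append((event, [name for name, _ in named_parameters]))


class FakeModule:
    """Stand-in module exposing named_parameters()."""

    def __init__(self, names):
        self._names = names

    def named_parameters(self):
        for name in self._names:
            yield name, [0.0]


generator = FakeModule(["gen.weight"])
discriminator = FakeModule(["disc.weight"])
log = []

# One GAN-style training iteration: two separate gradient
# computations, each over a different parameter set.
notify("on_grad_computed", discriminator.named_parameters(), log)
notify("on_grad_computed", generator.named_parameters(), log)

print(log)
# [('on_grad_computed', ['disc.weight']), ('on_grad_computed', ['gen.weight'])]
```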

Another approach would be a lazy wrapper that yields the named parameters to anyone who asks, something like this:

```python
class LazyGenerator:
    def __init__(self, fn):
        self.fn = fn      # zero-arg callable returning (name, param) pairs
        self.params = None

    def __iter__(self):
        # Materialize the list on first iteration only, then reuse it.
        if self.params is None:
            self.params = list(self.fn())
        yield from self.params

notify('on_grad_computed', LazyGenerator(lambda: m.named_parameters()))
```
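
A quick check that such a wrapper only materializes the parameter list once, no matter how many callbacks iterate over it (the class is repeated here so the snippet is self-contained):

```python
class LazyGenerator:
    def __init__(self, fn):
        self.fn = fn
        self.params = None

    def __iter__(self):
        if self.params is None:
            self.params = list(self.fn())
        yield from self.params


calls = []

def fake_named_parameters():
    calls.append(1)  # count how often the expensive call actually runs
    return [("weight", [1.0]), ("bias", [0.5])]

lg = LazyGenerator(fake_named_parameters)
first = [name for name, _ in lg]   # triggers the single materialization
second = [name for name, _ in lg]  # reuses the cached list

print(first, second, len(calls))  # ['weight', 'bias'] ['weight', 'bias'] 1
```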

thomasjpfan commented:

I agree, the LazyGenerator idea works well to solve the overhead of calling list, without changing the API.

> For example there might be a generator/discriminator cycle

Setting up skorch to train a GAN would be a good tutorial to write. It should help resolve #295.
