
Thoughts on the on_grad_computed API #378

Closed
@thomasjpfan

The on_grad_computed API receives both net and named_parameters, where named_parameters is built by calling list(self.module_.named_parameters()). This is redundant: a callback can already reach the parameters through net.module_.named_parameters().
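As a rough illustration of that point, here is a minimal sketch of a callback that ignores the named_parameters argument and pulls the parameters from net.module_ only when it needs them. ParamNameRecorder, fake_named_parameters, and the SimpleNamespace-based net are hypothetical stand-ins so the snippet runs without skorch or torch; they are not part of either library.

```python
from types import SimpleNamespace


class ParamNameRecorder:
    """Hypothetical callback that does not rely on ``named_parameters``."""

    def __init__(self):
        self.seen = []

    def on_grad_computed(self, net, named_parameters=None, **kwargs):
        # Fetch the parameters from the module on demand instead of
        # relying on a list prepared by the caller.
        for name, _param in net.module_.named_parameters():
            self.seen.append(name)


def fake_named_parameters():
    # Stand-in for torch.nn.Module.named_parameters(), which likewise
    # yields (name, parameter) pairs lazily.
    yield "linear.weight", [0.1, 0.2]
    yield "linear.bias", [0.0]


# Stand-in for a fitted net that exposes ``module_``.
net = SimpleNamespace(
    module_=SimpleNamespace(named_parameters=fake_named_parameters)
)

callback = ParamNameRecorder()
callback.on_grad_computed(net)
print(callback.seen)  # ['linear.weight', 'linear.bias']
```

A callback that never touches the parameters pays nothing in this arrangement, since the generator is only consumed inside callbacks that iterate over it.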

Calling list on named_parameters adds overhead to every training step:

import torchvision

m = torchvision.models.densenet201()
%timeit list(m.named_parameters())
# 2.47 ms ± 33.1 µs per loop

This overhead is incurred on every batch, so for an epoch with 500 batches it adds roughly 1.24 seconds (500 × 2.47 ms) to that epoch.
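The per-epoch figure is simply the per-call cost multiplied by the number of batches. A small sketch of that arithmetic, using the timing measured above; epoch_overhead_seconds is a hypothetical helper name introduced only for illustration:

```python
def epoch_overhead_seconds(per_call_ms: float, n_batches: int) -> float:
    """Total time spent building the parameter list over one epoch."""
    # Convert milliseconds per call to seconds, then scale by batch count.
    return per_call_ms / 1000.0 * n_batches


# 2.47 ms per list(...) call, 500 batches per epoch.
overhead = epoch_overhead_seconds(2.47, 500)
print(f"{overhead:.3f} s per epoch")  # 1.235 s per epoch
```

The cost grows linearly with the number of batches, so longer epochs or models with more parameters would pay proportionally more.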

I propose simplifying on_grad_computed by removing the named_parameters argument. This is a fairly small change, but it does break backwards compatibility.
