
Feature request: ability to apply stop gradient to some parameters #1931

Answered by jheek
NeilGirdhar asked this question in Q&A

Here's a sketch of what that would look like:

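import flax.linen as nn
from flax import traverse_util
from jax import lax


def selective_stop_grad(variables):
  # Flatten the nested variables so each key is a path tuple,
  # e.g. ("params", "Dense_0", "kernel").
  flat_vars = traverse_util.flatten_dict(variables)
  # some_filter_fn(path) -> bool is user-supplied; True means "stop the gradient here".
  new_vars = {k: lax.stop_gradient(v) if some_filter_fn(k) else v for k, v in flat_vars.items()}
  return traverse_util.unflatten_dict(new_vars)


class MySGModule(nn.Module):
  @nn.compact
  def __call__(self, x):
    # Wrap MySubModule so selective_stop_grad is applied to its "params" collection.
    MySGSubModule = nn.map_variables(MySubModule, "params", selective_stop_grad, init=True)
    # Replace ... with MySubModule's constructor arguments.
    return MySGSubModule(...)(x)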

Replies: 7 comments 2 replies

Answer selected by marcvanzee
3 participants

This discussion was converted from issue #1857 on February 22, 2022 09:56.