Skip to content

Allow specifying method as a string #2809

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jan 25, 2023

Conversation

cgarciae
Copy link
Collaborator

@cgarciae cgarciae commented Jan 18, 2023

What does this PR do?

Allows passing method as a string to apply. This will simply get the specified attribute from self.

Rationale

Being able to specify the method without having direct access to the instance is important to improve the ergonomics of abstractions like TrainState where you only have access to apply via TrainState.apply_fn.

Currently we promote the use of state like this:

@jax.jit
def train_step(state, batch):
  ...
  logits = state.apply_fn({'params': state.params, ...)
  ...

However, if you happen to need to use a method you now have to pass it somehow e.g:

@partial(jax.jit, static_argnums=(2,))
def train_step(state, batch, method):
  ...
  logits = state.apply_fn({'params': state.params, ..., method=method)
  ...

This involves a couple of changes throughout the code base for a relatively simple operation. With the proposed change you can now specify it very simply as:

@jax.jit
def train_step(state, batch):
  ...
  logits = state.apply_fn({'params': state.params, ..., method='some_method')
  ...

@codecov-commenter
Copy link

Codecov Report

Merging #2809 (07c769d) into main (e51d017) will increase coverage by 0.03%.
The diff coverage is 100.00%.

@@            Coverage Diff             @@
##             main    #2809      +/-   ##
==========================================
+ Coverage   81.23%   81.26%   +0.03%     
==========================================
  Files          53       53              
  Lines        5659     5669      +10     
==========================================
+ Hits         4597     4607      +10     
  Misses       1062     1062              
Impacted Files Coverage Δ
flax/linen/module.py 92.34% <100.00%> (+0.11%) ⬆️

Help us with your feedback. Take ten seconds to tell us how you rate us. Have a feature suggestion? Share it here.

@cgarciae cgarciae requested a review from jheek January 18, 2023 16:26
Copy link
Member

@jheek jheek left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@cgarciae cgarciae self-assigned this Jan 19, 2023
@cgarciae
Copy link
Collaborator Author

@jheek added the same behavior for init

@copybara-service copybara-service bot merged commit a309273 into google:main Jan 25, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants