
Linear combinations of parameters across models #1401

Answered by jheek
pharringtonp19 asked this question in Q&A

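To do this, you can use the jax.tree_util.tree_map utility (older JAX versions also exposed it as jax.tree_map) to apply a function, here jnp.mean, to every leaf of the parameter pytree. Assuming params holds the parameters of all models stacked along a leading axis, taking the mean over axis 0 averages each parameter across models:

avg_params = jax.tree_util.tree_map(lambda p: jnp.mean(p, axis=0), params)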
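A minimal sketch of the full workflow, assuming every model's parameters share the same pytree structure. The helper make_params, the three toy models, and the weights are hypothetical and exist only for illustration. The sketch first stacks corresponding leaves so each leaf gains a leading "model" axis, then applies the uniform mean from the answer, and finally forms a general linear combination sum_i w_i * params_i with jnp.tensordot.

```python
import jax
import jax.numpy as jnp

# Hypothetical helper: builds a toy parameter pytree filled with a constant.
def make_params(scale):
    return {
        "dense": {"kernel": jnp.full((2, 3), scale), "bias": jnp.full((3,), scale)},
    }

# Three hypothetical models with identical pytree structure.
models = [make_params(s) for s in (1.0, 2.0, 6.0)]

# Stack corresponding leaves so every leaf gets a leading "model" axis of size 3.
params = jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves), *models)

# Uniform average over the model axis (the approach from the answer).
avg_params = jax.tree_util.tree_map(lambda p: jnp.mean(p, axis=0), params)

# General linear combination: contract the weight vector with the model axis.
weights = jnp.array([0.5, 0.25, 0.25])  # hypothetical weights, one per model
combo_params = jax.tree_util.tree_map(
    lambda p: jnp.tensordot(weights, p, axes=1), params
)
```

The mean is the special case where every weight is 1/n. The weights do not have to sum to 1; when they do, the result is an affine combination of the models' parameters, which keeps the combined values on the same scale as the originals.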


Answer selected by marcvanzee
Category: Q&A
3 participants