Feature request: ability to apply stop gradient to some parameters #1931
-
To motivate this feature request, I'll explain what I'm currently doing (without Flax) and the other solutions I've considered. Then I'll suggest a possible Flax solution.

Problem: In the process of inferring one of my modules, I need to mask varying subsets of the weights with stop-gradient in a single function:
def infer(encoding: EncodingElement,
observation: PoolingMessage,
prediction: PredictionMessage,
rng: Generator,
weights: FrozenVariableDict) -> TwoPassEncodingConfiguration:
sampler_rng, code_rng = rng.split()
# Create four copies of the weights:
# * weights_sg has stop_gradient applied to all weights, and
# * the other three have stop_gradient applied to different partitions of
# the weights.
weights_sg, weights_g, weights_c, weights_e = _stop_gradient_on_some_weights(weights)
# Inference ------------------------------------------------------------------------------------
# This function uses weights_sg so this calculation won't poison the weight cotangents.
# However, cotangents still propagate back to observation.
code_message = encoding.code_message(observation, weights_sg)
# GLN loss -------------------------------------------------------------------------------------
# The scan parameters depend on weights_g.
encoding_parameters_g = SamplerParameters(observation, prediction, weights_g)
# This use of stop_gradient prevents the cotangents from propagating back from the scan through
# to the observation.
initial_code_message = stop_gradient(code_message)
# This class manages an iterated function (a scan)
sampler = EncodingSampler(encoding)
sampler_iterations = encoding.inference_parameters.sampler_iterations
initial_sampler_state = SamplerState.initial_state(encoding, initial_code_message, sampler_rng)
# This is an extremely computationally expensive scan.
sampler_state, sampler_trajectory = sampler.sample_trajectory(
encoding_parameters_g, initial_sampler_state, sampler_iterations, None)
# We calculate a GLN loss, which only affects the subset of weights in weights_g.
gln_loss = ((sampler_state.total_gln_centering_loss + sampler_state.total_prediction_loss)
/ sampler_iterations)
iterative_code_message = sampler_state.code_message
# Code loss ------------------------------------------------------------------------------------
# The code loss trains the code and selection links to produce a code message that predicts the
# code message that we inferred by iteration.
# This is the same code_message function as above, but uses weights_c.
c_code_message = encoding.code_message(observation, weights_c, rng=code_rng,
use_code_signal_noise=True)
# When this loss is minimized only the weights that are not marked stop-gradient in weights_c
# are adjusted. Cotangents are also blocked from poisoning the scan by applying stop_gradient
# to its outputs.
code_presence_loss = jnp.sum(jnp.square(stop_gradient(iterative_code_message.log_presence)
- c_code_message.log_presence))
code_value_loss = jnp.sum(jnp.square(stop_gradient(iterative_code_message.code_value)
- c_code_message.code_value))
code_loss = code_presence_loss + code_value_loss
# Snipped a lot of code here that uses weights_e and produces output primals.
return TwoPassEncodingConfiguration(iterative_code_message, gln_loss, code_loss)
# Below is the code that uses Haiku to partition the weights and apply stop gradient to different
# partitions.
_module_classes = [{'gln'}, {'code_value', 'code_presence'}, {'explanation'}]
def _module_predicate(module_name: str,
name: str,
value: Array) -> int:
prefix = module_name.split('/')[0]
for i, prefix_set in enumerate(_module_classes):
if prefix in prefix_set:
return i
raise RuntimeError
# I was using Haiku before, but I'll have to port this to Flax somehow.
def _partition_by_module(weights: FrozenVariableDict) -> tuple[FrozenVariableDict, ...]:
return hk.data_structures.partition_n(_module_predicate, # type: ignore[arg-type]
weights, len(_module_classes))
def _stop_gradient_on_some_weights(weights: FrozenVariableDict) -> list[FrozenVariableDict]:
weights_sg = stop_gradient(weights)
weights_p = _partition_by_module(weights)
weights_sg_p = _partition_by_module(weights_sg)
return ([weights_sg]
+ [hk.data_structures.merge(weights_pi,
*[weights_sg_pi
for j, weights_sg_pi in enumerate(weights_sg_p)
if i != j])
for i, weights_pi in enumerate(weights_p)])

Non-solution: I discussed this with @cgarciae and brainstormed a non-solution: I could try to put the "C", "G", and "E" weights into different "collections" and then run inference three times. This doesn't work because:
Possible Flax interface: We came up with two Flax interfaces that might work. I suggested some kind of context manager:
def infer(encoding: EncodingElement,
observation: PoolingMessage,
prediction: PredictionMessage,
rng: Generator,
weights: FrozenVariableDict) -> TwoPassEncodingConfiguration:
sampler_rng, code_rng = rng.split()
# Inference ------------------------------------------------------------------------------------
# Inside this block, stop_gradient is applied to all of the weights, so this calculation won't poison the weight cotangents.
# However, cotangents still propagate back to observation.
with nn.stop_gradient(lambda c: True):
code_message = encoding.code_message(observation)
# GLN loss -------------------------------------------------------------------------------------
encoding_parameters_g = SamplerParameters(observation, prediction)
# This use of stop_gradient prevents the cotangents from propagating back from the scan through
# to the observation.
initial_code_message = stop_gradient(code_message)
sampler = EncodingSampler(encoding)
sampler_iterations = encoding.inference_parameters.sampler_iterations
initial_sampler_state = SamplerState.initial_state(encoding, initial_code_message, sampler_rng)
# The scan parameters depend on weights_g.
with nn.stop_gradient(lambda c: c.name.startswith('gln')):
# This class manages an iterated function (a scan)
# This is an extremely computationally expensive scan.
sampler_state, sampler_trajectory = sampler.sample_trajectory(
encoding_parameters_g, initial_sampler_state, sampler_iterations, None)
# We calculate a GLN loss, which only affects the subset of weights in weights_g.
gln_loss = ((sampler_state.total_gln_centering_loss + sampler_state.total_prediction_loss)
/ sampler_iterations)
iterative_code_message = sampler_state.code_message
# Code loss ------------------------------------------------------------------------------------
# The code loss trains the code and selection links to produce a code message that predicts the
# code message that we inferred by iteration.
# This is the same code_message function as above, but uses weights_c.
with nn.stop_gradient(lambda c: c.name.startswith('code')):
c_code_message = encoding.code_message(observation, rng=code_rng,
use_code_signal_noise=True)
# When this loss is minimized only the weights that are not marked stop-gradient in weights_c
# are adjusted. Cotangents are also blocked from poisoning the scan by applying stop_gradient
# to its outputs.
code_presence_loss = jnp.sum(jnp.square(stop_gradient(iterative_code_message.log_presence)
- c_code_message.log_presence))
code_value_loss = jnp.sum(jnp.square(stop_gradient(iterative_code_message.code_value)
- c_code_message.code_value))
code_loss = code_presence_loss + code_value_loss
# Snipped a lot of code here that uses weights_e and produces output primals.
return TwoPassEncodingConfiguration(iterative_code_message, gln_loss, code_loss)

Cristian suggested a lifting transformation like those found in

Possible side benefits: Besides applying stop-gradient, this kind of system may be able to do other things with parameters, such as:
Of course, that's beyond this feature request, but I mention these ideas as something to keep in mind when considering solutions.

Conclusion: Am I missing an easy solution to my problem? If not, I will need to solve this problem in order to use Flax, since this use of stop-gradient is integral to my research. Thanks for reading!
Replies: 7 comments 2 replies
-
So in the Haiku version of the code you are solving this "outside" of Haiku by operating on the variables dict directly and making 4 copies. In Flax you could do something similar, for example by using
-
Here's a sketch of what that would look like:
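(A minimal version of the idea — Model, the layer names, and stop_gradient_on are just placeholders: flatten the variables with flax.traverse_util, apply jax.lax.stop_gradient to the entries you want frozen, and call apply with the result.)

import flax.linen as nn
import jax.numpy as jnp
from flax import traverse_util
from jax.lax import stop_gradient
from jax.random import PRNGKey

def stop_gradient_on(variables, filter_f):
    # Apply stop_gradient to every leaf whose flattened path satisfies filter_f.
    # Paths look like ('params', 'gln', 'kernel').
    flat = traverse_util.flatten_dict(variables)
    mapped = {path: stop_gradient(value) if filter_f(path) else value
              for path, value in flat.items()}
    return traverse_util.unflatten_dict(mapped)

class Model(nn.Module):
    @nn.compact
    def __call__(self, x):
        return nn.Dense(3, name='gln')(x) + nn.Dense(3, name='code_value')(x)

model = Model()
variables = model.init(PRNGKey(0), jnp.ones(3))
# Analogue of weights_g: cotangents reach only the 'gln' parameters.
variables_g = stop_gradient_on(variables, lambda path: path[1] != 'gln')
y = model.apply(variables_g, jnp.ones(3))

From there you can build each of the variants (weights_sg, weights_g, weights_c, weights_e) with a different filter and hand the right one to each apply call.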
-
@jheek A HOWTO about freezing parameters using this strategy would be great.
-
I've been trying to implement your solution, but I can't seem to get it working for me. Here's roughly what I have:

from __future__ import annotations
from collections.abc import Callable
from dataclasses import asdict
from typing import Any, Generic, TypeVar
import flax.linen as nn
import jax.numpy as jnp
from flax import traverse_util
from flax.core.scope import FrozenVariableDict
from jax.lax import stop_gradient
from jax.random import PRNGKey
T = TypeVar('T', bound=nn.Module)
class StopGradientModule(nn.Module, Generic[T]):
filter_f: Callable[[tuple[str, ...]], bool]
submodule_cls: Callable[..., T]
def setup(self) -> None:
self.submodule = nn.map_variables(self.submodule_cls, True, self._selective_stop_gradient)
def __call__(self, module: T) -> T:
return self.submodule(**asdict(module))
def _selective_stop_gradient(self, variables: FrozenVariableDict) -> dict[str, Any]:
flat_vars = traverse_util.flatten_dict(variables) # type: ignore[no-untyped-call]
new_vars = {k: stop_gradient(v)
if self.filter_f(k) else v
for k, v in flat_vars.items()}
return traverse_util.unflatten_dict(new_vars) # type: ignore[no-untyped-call]
class X(nn.Module):
def setup(self) -> None:
self.dense = nn.Dense(10)
# stop_gradient_all is a copy of self whose parameters are identical, but whose parameter
# cotangents are always zero.
self.stop_gradient_all = StopGradientModule(lambda _: True, X)
def f(self, x: Any) -> Any:
return self.dense(x), self.stop_gradient_all(self).dense(x)
print(X().init_with_output({'params': PRNGKey(0)}, jnp.ones(3), method=X.f))

gives

Traceback (most recent call last):
File "/home/neil/src/cmm/a.py", line 44, in <module>
print(X().init_with_output({'params': PRNGKey(0)}, jnp.ones(3), method=X.f))
File "/home/neil/src/cmm/a.py", line 41, in f
return self.dense(x), self.stop_gradient_all(self).dense(x)
File "/home/neil/src/cmm/a.py", line 37, in setup
self.dense = nn.Dense(10)
ValueError: Duplicate use of scope name: "dense"

I realize that this is currently a recursive mess, and I'm exploring the simplest way of accomplishing what I'm after.
-
I'm still trying to get this working. Here's what I have now:

from __future__ import annotations
from collections.abc import Callable
from typing import Any, Generic, TypeVar
import flax.linen as nn
import jax.numpy as jnp
from flax import traverse_util
from flax.core.scope import FrozenVariableDict
from jax.lax import stop_gradient
from jax.random import PRNGKey
from tjax import print_generic
T = TypeVar('T', bound=nn.Module)
class StopGradientModule(nn.Module, Generic[T]):
filter_f: Callable[[tuple[str, ...]], bool]
submodule_cls: Callable[..., T]
def setup(self) -> None:
mapped_cls = nn.map_variables(self.submodule_cls, True, self._selective_stop_gradient,
methods=['f'])
self.submodule = mapped_cls()
def f(self, x: Any) -> Any:
print("Calling")
return self.submodule.f(x)
def _selective_stop_gradient(self, variables: FrozenVariableDict) -> dict[str, Any]:
flat_vars = traverse_util.flatten_dict(variables) # type: ignore[no-untyped-call]
new_vars = {k: stop_gradient(v)
if self.filter_f(k) else v
for k, v in flat_vars.items()}
return traverse_util.unflatten_dict(new_vars) # type: ignore[no-untyped-call]
def __call__(self):
assert False
class X(nn.Module):
def setup(self) -> None:
self.dense = nn.Dense(3)
def f(self, x: Any) -> Any:
return self.dense(x)
def __call__(self):
assert False
class Y(nn.Module):
def setup(self) -> None:
self.x = X()
# stop_gradient_all is a copy of x whose parameters are identical, but whose parameter
# cotangents are always zero.
self.stop_gradient_all = StopGradientModule(lambda _: True, X)
def f(self, x: Any) -> Any:
y = self.x.f(x)
return y, self.stop_gradient_all.f(x)
def __call__(self):
assert False
(y, y_prime), variables = Y().init_with_output({'params': PRNGKey(0)}, jnp.ones(3), method=Y.f)
print(y, y_prime)
print_generic(variables)

gives

Traceback (most recent call last):
File "/home/neil/src/cmm/a.py", line 66, in <module>
(y, y_prime), variables = Y().init_with_output({'params': PRNGKey(0)}, jnp.ones(3), method=Y.f)
File "/home/neil/src/cmm/a.py", line 61, in f
return y, self.stop_gradient_all.f(x)
File "/home/neil/src/cmm/a.py", line 28, in f
return self.submodule.f(x)
File "/home/neil/src/cmm/a.py", line 46, in f
return self.dense(x)
File "/home/neil/src/flax/flax/linen/linear.py", line 177, in __call__
kernel = self.param('kernel',
flax.errors.ScopeCollectionNotFound: Tried to access "kernel" from collection "params"" in "/stop_gradient_all/map_variables(submodule)/dense" but the collection is emtpy. (https://flax.readthedocs.io/en/latest/flax.errors.html#flax.errors.ScopeCollectionNotFound)
-
Ah, that's because map_variables makes the collections immutable. You need to provide a function that maps back the variables on output, or pass init=True so that map_variables isn't called during init. Can you try passing init=True?
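For example, StopGradientModule.setup above would become something like this (init is an existing argument of nn.map_variables):

def setup(self) -> None:
    mapped_cls = nn.map_variables(self.submodule_cls, True, self._selective_stop_gradient,
                                  init=True,  # don't apply the mapping while variables are being initialized
                                  methods=['f'])
    self.submodule = mapped_cls()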
-
@jheek Thanks, that gets it to run, but it's still not reflecting a copy of x's parameters? It outputs: