Marking a non-jittable inner function as non-jittable. #10444
-
Hi there, I've been wondering if there is a way in Flax to mark a non-jittable inner function as non-jittable. For example:

import jax
import flax.linen as nn
from typing import Any
Array = Any
def some_non_jit_func(x, a, b):
    """Return the entries of ``a * b`` that are greater than or equal to ``x``.

    The boolean-mask indexing yields a result whose length depends on the data
    values, which is what makes this function non-jittable.
    """
    product = a * b
    keep = product >= x
    return product[keep]
class inner_layer(nn.Module):
    """Scale the inputs elementwise by a fixed, externally supplied ``param``.

    Attributes:
        param: multiplier applied to the inputs; assumed to broadcast against
            ``inputs`` — TODO confirm the expected shape.
    """

    param: Array

    @nn.compact
    def __call__(self, inputs):
        # ``param`` is a field of the module, so it must be read via ``self``;
        # the result is returned so callers actually receive the scaled inputs.
        return self.param * inputs
class wrapper_layer(nn.Module):
    """Compute a parameter with ``some_non_jit_func`` and apply ``inner_layer``.

    Attributes:
        a, b: operands multiplied inside ``some_non_jit_func``.
        x: threshold passed to ``some_non_jit_func``.

    NOTE(review): ``a``/``b`` are annotated ``int``, but ``some_non_jit_func``
    indexes ``a * b`` with a boolean mask, so arrays are presumably intended —
    TODO confirm.
    """

    a: int
    b: int
    x: int

    @nn.nowrap
    def get_layer(self, param):
        """Build the inner layer for ``param``; ``nn.nowrap`` keeps this helper unwrapped."""
        return inner_layer(param)

    @nn.compact
    def __call__(self, inputs):
        # NOTE(review): this still calls the data-dependent filter inside
        # __call__, so it will fail if __call__ is traced under jax.jit —
        # which is the open question of this post.
        param = some_non_jit_func(self.x, self.a, self.b)
        layer = self.get_layer(param)
        # Feed the actual inputs through the layer; ``param`` is already baked
        # into the layer itself.
        return layer(inputs)


# Post text following this snippet (inline code partly lost in extraction):
# "after which the ... The helper function ... Any help would be greatly appreciated."
Beta Was this translation helpful? Give feedback.
Replies: 3 comments 3 replies
-
You can use host_callback. Alternatively:

import jax
import flax.linen as nn
from typing import Any
Array = Any
def some_non_jit_func(x, a, b):
    """Filter ``a * b`` down to the entries no smaller than the threshold ``x``.

    The output size depends on the data, so this cannot run under ``jax.jit``.
    """
    scaled = a * b
    at_or_above = x <= scaled
    return scaled[at_or_above]
def make_model(x, a, b):
    """Build a ``wrapper_layer`` Module class around a value computed up front.

    ``some_non_jit_func(x, a, b)`` is evaluated eagerly, in plain Python, when
    ``make_model`` runs. The returned classes only close over the result ``z``,
    so the non-jittable computation never appears inside a module's
    ``__call__`` (assumes ``make_model`` itself is called outside ``jax.jit``).

    Args:
        x: threshold passed to ``some_non_jit_func``.
        a, b: operands multiplied inside ``some_non_jit_func``.

    Returns:
        The ``wrapper_layer`` class (not an instance).
    """
    z = some_non_jit_func(x, a, b)  # do anything

    class inner_layer(nn.Module):
        # Multiplier applied elementwise to the inputs.
        param: Array

        @nn.compact
        def __call__(self, inputs):
            # ``param`` is a module field, so it must be read via ``self``.
            return self.param * inputs

    class wrapper_layer(nn.Module):
        @nn.nowrap
        def get_layer(self, param):
            """Build the inner layer for ``param``."""
            return inner_layer(param)

        @nn.compact
        def __call__(self, inputs):
            # ``z`` comes from the enclosing make_model call, not from tracing.
            layer = self.get_layer(z)
            return layer(inputs)

    return wrapper_layer
Beta Was this translation helpful? Give feedback.
-
Thanks for the question! You might be able to use […] (the rest of this reply was lost in extraction).
Beta Was this translation helpful? Give feedback.
-
Thanks for the replies! After some thinking, I think I might be able to come up with a jittable version of the function in question, which leads to another question. The easiest way for me to do that would be to use fancy indexing on […]. I'll try to post a shortened, working example today.
Beta Was this translation helpful? Give feedback.
You can use host_callback.
Alternatively: