
Marking a non-jittable inner function as non-jittable. #10444

Answered by YouJiacheng
ozencgungor asked this question in Q&A

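You can use host_callback (jax.experimental.host_callback) to run the non-jittable function on the host from inside jitted code, provided the shape and dtype of its output are known in advance.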
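In newer JAX releases host_callback is deprecated in favour of jax.pure_callback (and jax.experimental.io_callback for side-effecting calls). The following is a minimal sketch that assumes a recent JAX providing jax.pure_callback. It uses a hypothetical helper, host_filter_padded, which zero-pads the filtered values back to the input length. That padding is needed because a callback has to declare a static output shape, and y[y >= x] on its own has no static shape. The sketch also assumes a and b are 1-D arrays with the same shape and dtype.

```python
import jax
import jax.numpy as jnp
import numpy as np


def host_filter_padded(x, a, b):
    # Hypothetical host-side helper. It receives concrete arrays, so
    # boolean-mask indexing (a data-dependent shape) is allowed here.
    y = np.asarray(a) * np.asarray(b)
    kept = y[y >= np.asarray(x)]
    # Zero-pad back to len(y) so the callback's output shape is static.
    # Assumes 1-D inputs.
    out = np.zeros_like(y)
    out[: kept.shape[0]] = kept
    return out


@jax.jit
def f(x, a, b):
    # The callback's result shape and dtype must be declared up front.
    out_spec = jax.ShapeDtypeStruct(a.shape, a.dtype)
    return jax.pure_callback(host_filter_padded, out_spec, x, a, b)


res = f(2.5, jnp.array([1.0, 2.0, 3.0, 4.0]), jnp.ones(4))
```

With these inputs, res is [3., 4., 0., 0.]. The callback runs outside XLA on every call, so this approach trades speed for flexibility.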
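Alternatively, call the non-jittable function eagerly (outside of jit) and close over its concrete result when you build the Flax model: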

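import jax
import flax.linen as nn
from typing import Any

Array = Any


def some_non_jit_func(x, a, b):
    # Boolean-mask indexing gives a data-dependent output shape,
    # so this cannot be traced by jax.jit; call it eagerly instead.
    y = a * b
    return y[y >= x]


def make_model(x, a, b):
    # Runs eagerly (outside jit), so z is a concrete array.
    z = some_non_jit_func(x, a, b)  # do anything

    class inner_layer(nn.Module):
        # Named `scale` rather than `param` to avoid shadowing nn.Module.param.
        scale: Array

        @nn.compact
        def __call__(self, inputs):
            return self.scale * inputs

    class wrapper_layer(nn.Module):
        @nn.nowrap
        def get_layer(self, scale):
            # Helper that builds the submodule; not wrapped by Flax.
            return inner_layer(scale)

        @nn.compact
        def __call__(self, inputs):
            # z is captured by the closure and becomes a constant when traced.
            layer = self.get_layer(z)
            return layer(inputs)

    # Returns the Module class; instantiate it with wrapper_layer().
    return wrapper_layer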

Replies: 3 comments 3 replies

Answer selected by ozencgungor
Further replies from @ozencgungor and @VinaySingh561.
Category: Q&A
4 participants