
lowering / cost analysis of @nnx.jit functions #4094

Open
@cgarciae

Description


Discussed in #4093

Originally posted by lzanini July 19, 2024
Hi, is there a way to use JAX's cost analysis API for nnx.jit-compiled functions? My typical case is a training step that takes a module, an optimizer, and a metric. It is unclear to me how to lower the inner jax.jit-compiled function.

Labels: Priority: P1 - soon (response within 5 business days, resolution within 30 days, assignee required)
