16
16
import typing as tp
17
17
from typing import Any
18
18
19
- from flax import nnx
20
19
from flax import linen
20
+ from flax import nnx
21
+ from flax .core import FrozenDict
21
22
from flax .core import meta
22
23
from flax .nnx import graph
23
24
from flax .nnx .bridge import variables as bv
24
25
from flax .nnx .module import GraphDef , Module
26
+ from flax .nnx .object import Object
25
27
from flax .nnx .rnglib import Rngs
26
28
from flax .nnx .statelib import State
27
- from flax .nnx .object import Object
28
29
import jax
29
30
from jax import tree_util as jtu
30
31
@@ -220,7 +221,7 @@ class ToLinen(linen.Module):
220
221
"""
221
222
nnx_class : tp .Callable [..., Module ]
222
223
args : tp .Sequence = ()
223
- kwargs : tp .Mapping = dataclasses . field ( default_factory = dict )
224
+ kwargs : tp .Mapping [ str , tp . Any ] = FrozenDict ({} )
224
225
skip_rng : bool = False
225
226
metadata_type : tp .Type = bv .NNXMeta
226
227
@@ -277,4 +278,4 @@ def _update_variables(self, module):
277
278
def to_linen (nnx_class : tp .Callable [..., Module ], * args ,
278
279
name : str | None = None , ** kwargs ):
279
280
"""Shortcut of `nnx.bridge.ToLinen` if user is not changing any of its default fields."""
280
- return ToLinen (nnx_class , args = args , kwargs = kwargs , name = name )
281
+ return ToLinen (nnx_class , args = args , kwargs = FrozenDict ( kwargs ) , name = name )
0 commit comments