Skip to content

Commit 86ff7af

Browse files
Cristian GarciaFlax Authors
authored andcommitted
[nnx] fix ToLinen kwargs
PiperOrigin-RevId: 695920522
1 parent d31f290 commit 86ff7af

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

flax/nnx/bridge/wrappers.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,16 @@
1616
import typing as tp
1717
from typing import Any
1818

19-
from flax import nnx
2019
from flax import linen
20+
from flax import nnx
21+
from flax.core import FrozenDict
2122
from flax.core import meta
2223
from flax.nnx import graph
2324
from flax.nnx.bridge import variables as bv
2425
from flax.nnx.module import GraphDef, Module
26+
from flax.nnx.object import Object
2527
from flax.nnx.rnglib import Rngs
2628
from flax.nnx.statelib import State
27-
from flax.nnx.object import Object
2829
import jax
2930
from jax import tree_util as jtu
3031

@@ -220,7 +221,7 @@ class ToLinen(linen.Module):
220221
"""
221222
nnx_class: tp.Callable[..., Module]
222223
args: tp.Sequence = ()
223-
kwargs: tp.Mapping = dataclasses.field(default_factory=dict)
224+
kwargs: tp.Mapping[str, tp.Any] = FrozenDict({})
224225
skip_rng: bool = False
225226
metadata_type: tp.Type = bv.NNXMeta
226227

@@ -277,4 +278,4 @@ def _update_variables(self, module):
277278
def to_linen(nnx_class: tp.Callable[..., Module], *args,
278279
name: str | None = None, **kwargs):
279280
"""Shortcut of `nnx.bridge.ToLinen` if user is not changing any of its default fields."""
280-
return ToLinen(nnx_class, args=args, kwargs=kwargs, name=name)
281+
return ToLinen(nnx_class, args=args, kwargs=FrozenDict(kwargs), name=name)

0 commit comments

Comments
 (0)