Skip to content

Commit c9cebee

Browse files
QwlouseFlax Authors
authored andcommitted
support passing arguments directly to the struct.dataclass decorator
PiperOrigin-RevId: 694517927
1 parent 0f631a2 commit c9cebee

File tree

2 files changed

+29
-7
lines changed

2 files changed

+29
-7
lines changed

flax/struct.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,10 @@
1414

1515
"""Utilities for defining custom classes that can be used with jax transformations."""
1616

17+
from collections.abc import Callable
1718
import dataclasses
18-
from typing import TypeVar
19+
import functools
20+
from typing import TypeVar, overload
1921

2022
import jax
2123
from typing_extensions import (
@@ -33,7 +35,22 @@ def field(pytree_node=True, *, metadata=None, **kwargs):
3335

3436

3537
@dataclass_transform(field_specifiers=(field,)) # type: ignore[literal-required]
38+
@overload
3639
def dataclass(clz: _T, **kwargs) -> _T:
40+
...
41+
42+
43+
@dataclass_transform(field_specifiers=(field,)) # type: ignore[literal-required]
44+
@overload
45+
def dataclass(**kwargs) -> Callable[[_T], _T]:
46+
...
47+
48+
49+
@dataclass_transform(field_specifiers=(field,)) # type: ignore[literal-required]
50+
def dataclass(
51+
clz: _T | None = None,
52+
**kwargs,
53+
) -> _T | Callable[[_T], _T]:
3754
"""Create a class which can be passed to functional transformations.
3855
3956
.. note::
@@ -99,9 +116,15 @@ class method that provides the smart constructor.
99116
100117
Args:
101118
clz: the class that will be transformed by the decorator.
119+
**kwargs: arguments to pass to the dataclass constructor.
120+
102121
Returns:
103122
The new class.
104123
"""
124+
# Support passing arguments to the decorator (e.g. @dataclass(kw_only=True))
125+
if clz is None:
126+
return functools.partial(dataclass, **kwargs)
127+
105128
# check if already a flax dataclass
106129
if '_flax_dataclass' in clz.__dict__:
107130
return clz
@@ -119,7 +142,7 @@ class method that provides the smart constructor.
119142
meta_fields.append(field_info.name)
120143

121144
def replace(self, **updates):
122-
""" "Returns a new object replacing the specified fields with new values."""
145+
"""Returns a new object replacing the specified fields with new values."""
123146
return dataclasses.replace(self, **updates)
124147

125148
data_clz.replace = replace

tests/struct_test.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
"""Tests for flax.struct."""
1616

1717
import dataclasses
18-
import functools
1918
from typing import Any
2019

2120
import jax
@@ -49,8 +48,8 @@ def test_mutation(self):
4948
p.y = 3
5049

5150
def test_slots(self):
52-
slots_dataclass = functools.partial(struct.dataclass, frozen=False, slots=True)
53-
@slots_dataclass
51+
52+
@struct.dataclass(frozen=False, slots=True)
5453
class SlotsPoint:
5554
x: float
5655
y: float
@@ -100,7 +99,7 @@ def test_kw_only(self, mode):
10099
class A:
101100
a: int = 1
102101

103-
@functools.partial(struct.dataclass, kw_only=True)
102+
@struct.dataclass(kw_only=True)
104103
class B(A):
105104
b: int
106105
elif mode == 'pytreenode':
@@ -139,7 +138,7 @@ def test_mutable(self, mode):
139138
class A:
140139
a: int = 1
141140

142-
@functools.partial(struct.dataclass, frozen=False)
141+
@struct.dataclass(frozen=False)
143142
class B:
144143
b: int = 1
145144
elif mode == 'pytreenode':

0 commit comments

Comments
 (0)