14
14
15
15
"""Utilities for defining custom classes that can be used with jax transformations."""
16
16
17
+ from collections .abc import Callable
17
18
import dataclasses
18
- from typing import TypeVar
19
+ import functools
20
+ from typing import TypeVar , overload
19
21
20
22
import jax
21
23
from typing_extensions import (
@@ -33,7 +35,22 @@ def field(pytree_node=True, *, metadata=None, **kwargs):
33
35
34
36
35
37
@dataclass_transform (field_specifiers = (field ,)) # type: ignore[literal-required]
38
+ @overload
36
39
def dataclass (clz : _T , ** kwargs ) -> _T :
40
+ ...
41
+
42
+
43
+ @dataclass_transform (field_specifiers = (field ,)) # type: ignore[literal-required]
44
+ @overload
45
+ def dataclass (** kwargs ) -> Callable [[_T ], _T ]:
46
+ ...
47
+
48
+
49
+ @dataclass_transform (field_specifiers = (field ,)) # type: ignore[literal-required]
50
+ def dataclass (
51
+ clz : _T | None = None ,
52
+ ** kwargs ,
53
+ ) -> _T | Callable [[_T ], _T ]:
37
54
"""Create a class which can be passed to functional transformations.
38
55
39
56
.. note::
@@ -99,9 +116,15 @@ class method that provides the smart constructor.
99
116
100
117
Args:
101
118
clz: the class that will be transformed by the decorator.
119
+ **kwargs: arguments to pass to the dataclass constructor.
120
+
102
121
Returns:
103
122
The new class.
104
123
"""
124
+ # Support passing arguments to the decorator (e.g. @dataclass(kw_only=True))
125
+ if clz is None :
126
+ return functools .partial (dataclass , ** kwargs )
127
+
105
128
# check if already a flax dataclass
106
129
if '_flax_dataclass' in clz .__dict__ :
107
130
return clz
@@ -119,7 +142,7 @@ class method that provides the smart constructor.
119
142
meta_fields .append (field_info .name )
120
143
121
144
def replace (self , ** updates ):
122
- """ " Returns a new object replacing the specified fields with new values."""
145
+ """Returns a new object replacing the specified fields with new values."""
123
146
return dataclasses .replace (self , ** updates )
124
147
125
148
data_clz .replace = replace
0 commit comments