15
15
import textwrap
16
16
from typing import ClassVar , Dict , List , Optional , Tuple , Union
17
17
18
- from pydantic import BaseModel , Field
18
+ from pydantic import BaseModel , Field , computed_field
19
19
20
20
21
21
__all__ = [
33
33
PrintOrderType = ClassVar [List [str ]]
34
34
35
35
36
- class PropertyBaseModel (BaseModel ):
37
- """
38
- https://github.com/samuelcolvin/pydantic/issues/935#issuecomment-1152457432
39
-
40
- Workaround for serializing properties with pydantic until
41
- https://github.com/samuelcolvin/pydantic/issues/935
42
- is solved
43
- """
44
-
45
- @classmethod
46
- def get_properties (cls ):
47
- return [
48
- prop
49
- for prop in dir (cls )
50
- if isinstance (getattr (cls , prop ), property )
51
- and prop not in ("__values__" , "fields" )
52
- ]
53
-
54
- def dict (
55
- self ,
56
- * ,
57
- include : Union ["AbstractSetIntStr" , "MappingIntStrAny" ] = None , # noqa: F821
58
- exclude : Union ["AbstractSetIntStr" , "MappingIntStrAny" ] = None , # noqa: F821
59
- by_alias : bool = False ,
60
- skip_defaults : bool = None ,
61
- exclude_unset : bool = False ,
62
- exclude_defaults : bool = False ,
63
- exclude_none : bool = False ,
64
- ) -> "DictStrAny" : # noqa: F821
65
- attribs = super ().dict (
66
- include = include ,
67
- exclude = exclude ,
68
- by_alias = by_alias ,
69
- skip_defaults = skip_defaults ,
70
- exclude_unset = exclude_unset ,
71
- exclude_defaults = exclude_defaults ,
72
- exclude_none = exclude_none ,
73
- )
74
- props = self .get_properties ()
75
- # Include and exclude properties
76
- if include :
77
- props = [prop for prop in props if prop in include ]
78
- if exclude :
79
- props = [prop for prop in props if prop not in exclude ]
80
-
81
- # Update the attribute dict with the properties
82
- if props :
83
- attribs .update ({prop : getattr (self , prop ) for prop in props })
84
-
85
- return attribs
86
-
87
-
88
36
class NodeCounts (BaseModel ):
89
37
"""
90
38
Pydantic model for specifying the number zero and non-zero operations and the
@@ -114,7 +62,7 @@ class NodeIO(BaseModel):
114
62
)
115
63
116
64
117
- class ZeroNonZeroParams (PropertyBaseModel ):
65
+ class ZeroNonZeroParams (BaseModel ):
118
66
"""
119
67
Pydantic model for specifying the number zero and non-zero operations and the
120
68
associated sparsity
@@ -127,20 +75,22 @@ class ZeroNonZeroParams(PropertyBaseModel):
127
75
description = "The number of parameters whose value is zero" , default = 0
128
76
)
129
77
78
+ @computed_field (repr = True , return_type = Union [int , float ])
130
79
@property
131
80
def sparsity (self ):
132
81
total_values = self .total
133
82
if total_values > 0 :
134
83
return self .zero / total_values
135
84
else :
136
- return 0
85
+ return 0.0
137
86
87
+ @computed_field (repr = True , return_type = int )
138
88
@property
139
89
def total (self ):
140
90
return self .non_zero + self .zero
141
91
142
92
143
- class DenseSparseOps (PropertyBaseModel ):
93
+ class DenseSparseOps (BaseModel ):
144
94
"""
145
95
Pydantic model for specifying the number dense and sparse operations and the
146
96
associated operation sparsity
@@ -155,6 +105,7 @@ class DenseSparseOps(PropertyBaseModel):
155
105
default = 0 ,
156
106
)
157
107
108
+ @computed_field (repr = True , return_type = Union [int , float ])
158
109
@property
159
110
def sparsity (self ):
160
111
total_ops = self .sparse + self .dense
0 commit comments