Skip to content

Commit 6ae1b94

Browse files
authored
[BugFix] Fix Serialization of Computed Properties in BaseModel (#485)
* Add failing test * Fix computed field serialization
1 parent f349af4 commit 6ae1b94

File tree

3 files changed

+44
-59
lines changed

3 files changed

+44
-59
lines changed

src/sparsezoo/analyze_v1/analysis.py

+10-3
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,13 @@ class YAMLSerializableBaseModel(BaseModel):
103103
A BaseModel that adds a .yaml(...) function to all child classes
104104
"""
105105

106+
model_config = ConfigDict(protected_namespaces=())
107+
108+
def dict(self, *args, **kwargs) -> Dict[str, Any]:
109+
# alias for model_dump for pydantic v2 upgrade
110+
# to allow for easier migration
111+
return self.model_dump(*args, **kwargs)
112+
106113
def yaml(self, file_path: Optional[str] = None) -> Union[str, None]:
107114
"""
108115
:param file_path: optional file path to save yaml to
@@ -111,7 +118,7 @@ def yaml(self, file_path: Optional[str] = None) -> Union[str, None]:
111118
"""
112119
file_stream = None if file_path is None else open(file_path, "w")
113120
ret = yaml.dump(
114-
self.dict(), stream=file_stream, allow_unicode=True, sort_keys=False
121+
self.model_dump(), stream=file_stream, allow_unicode=True, sort_keys=False
115122
)
116123

117124
if file_stream is not None:
@@ -127,7 +134,7 @@ def parse_yaml_file(cls, file_path: str):
127134
"""
128135
with open(file_path, "r") as file:
129136
dict_obj = yaml.safe_load(file)
130-
return cls.parse_obj(dict_obj)
137+
return cls.model_validate(dict_obj)
131138

132139
@classmethod
133140
def parse_yaml_raw(cls, yaml_raw: str):
@@ -136,7 +143,7 @@ def parse_yaml_raw(cls, yaml_raw: str):
136143
:return: instance of ModelAnalysis class
137144
"""
138145
dict_obj = yaml.safe_load(yaml_raw) # unsafe: needs to load numpy
139-
return cls.parse_obj(dict_obj)
146+
return cls.model_validate(dict_obj)
140147

141148

142149
@dataclass

src/sparsezoo/analyze_v1/utils/models.py

+7-56
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import textwrap
1616
from typing import ClassVar, Dict, List, Optional, Tuple, Union
1717

18-
from pydantic import BaseModel, Field
18+
from pydantic import BaseModel, Field, computed_field
1919

2020

2121
__all__ = [
@@ -33,58 +33,6 @@
3333
PrintOrderType = ClassVar[List[str]]
3434

3535

36-
class PropertyBaseModel(BaseModel):
37-
"""
38-
https://github.com/samuelcolvin/pydantic/issues/935#issuecomment-1152457432
39-
40-
Workaround for serializing properties with pydantic until
41-
https://github.com/samuelcolvin/pydantic/issues/935
42-
is solved
43-
"""
44-
45-
@classmethod
46-
def get_properties(cls):
47-
return [
48-
prop
49-
for prop in dir(cls)
50-
if isinstance(getattr(cls, prop), property)
51-
and prop not in ("__values__", "fields")
52-
]
53-
54-
def dict(
55-
self,
56-
*,
57-
include: Union["AbstractSetIntStr", "MappingIntStrAny"] = None, # noqa: F821
58-
exclude: Union["AbstractSetIntStr", "MappingIntStrAny"] = None, # noqa: F821
59-
by_alias: bool = False,
60-
skip_defaults: bool = None,
61-
exclude_unset: bool = False,
62-
exclude_defaults: bool = False,
63-
exclude_none: bool = False,
64-
) -> "DictStrAny": # noqa: F821
65-
attribs = super().dict(
66-
include=include,
67-
exclude=exclude,
68-
by_alias=by_alias,
69-
skip_defaults=skip_defaults,
70-
exclude_unset=exclude_unset,
71-
exclude_defaults=exclude_defaults,
72-
exclude_none=exclude_none,
73-
)
74-
props = self.get_properties()
75-
# Include and exclude properties
76-
if include:
77-
props = [prop for prop in props if prop in include]
78-
if exclude:
79-
props = [prop for prop in props if prop not in exclude]
80-
81-
# Update the attribute dict with the properties
82-
if props:
83-
attribs.update({prop: getattr(self, prop) for prop in props})
84-
85-
return attribs
86-
87-
8836
class NodeCounts(BaseModel):
8937
"""
9038
Pydantic model for specifying the number zero and non-zero operations and the
@@ -114,7 +62,7 @@ class NodeIO(BaseModel):
11462
)
11563

11664

117-
class ZeroNonZeroParams(PropertyBaseModel):
65+
class ZeroNonZeroParams(BaseModel):
11866
"""
11967
Pydantic model for specifying the number zero and non-zero operations and the
12068
associated sparsity
@@ -127,20 +75,22 @@ class ZeroNonZeroParams(PropertyBaseModel):
12775
description="The number of parameters whose value is zero", default=0
12876
)
12977

78+
@computed_field(repr=True, return_type=Union[int, float])
13079
@property
13180
def sparsity(self):
13281
total_values = self.total
13382
if total_values > 0:
13483
return self.zero / total_values
13584
else:
136-
return 0
85+
return 0.0
13786

87+
@computed_field(repr=True, return_type=int)
13888
@property
13989
def total(self):
14090
return self.non_zero + self.zero
14191

14292

143-
class DenseSparseOps(PropertyBaseModel):
93+
class DenseSparseOps(BaseModel):
14494
"""
14595
Pydantic model for specifying the number dense and sparse operations and the
14696
associated operation sparsity
@@ -155,6 +105,7 @@ class DenseSparseOps(PropertyBaseModel):
155105
default=0,
156106
)
157107

108+
@computed_field(repr=True, return_type=Union[int, float])
158109
@property
159110
def sparsity(self):
160111
total_ops = self.sparse + self.dense
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing,
10+
# software distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import pytest
16+
17+
from sparsezoo.analyze_v1.utils.models import DenseSparseOps, ZeroNonZeroParams
18+
19+
20+
@pytest.mark.parametrize("model", [DenseSparseOps, ZeroNonZeroParams])
21+
@pytest.mark.parametrize("computed_fields", [["sparsity"]])
22+
def test_model_dump_has_computed_fields(model, computed_fields):
23+
model = model()
24+
model_dict = model.model_dump()
25+
for computed_field in computed_fields:
26+
assert computed_field in model_dict
27+
assert model_dict[computed_field] == getattr(model, computed_field)

0 commit comments

Comments
 (0)