
Commit f349af4

Update pydantic support (#483)
* Changes from bump-pydantic
* update setup.py
* Style + use v1 basemodel for FeatureStatus
* Update pydantic to <2.8
* Delete extraneous file
1 parent 393db75 commit f349af4

11 files changed: +56 -50 lines


setup.py (+1 -1)
@@ -51,7 +51,7 @@
         "pyyaml>=5.1.0",
         "requests>=2.0.0",
         "tqdm>=4.0.0",
-        "pydantic>=1.8.2,<2.0.0",
+        "pydantic>=2.0.0,<2.8",
         "click>=7.1.2,!=8.0.0",  # latest version < 8.0 + blocked version with reported bug
         "protobuf>=3.12.2",
         "pandas>1.3",

src/sparsezoo/analyze_v1/analysis.py (+3 -4)
@@ -27,7 +27,7 @@
 import numpy
 import yaml
 from onnx import ModelProto, NodeProto
-from pydantic import BaseModel, Field, PositiveFloat, PositiveInt
+from pydantic import BaseModel, ConfigDict, Field, PositiveFloat, PositiveInt

 from sparsezoo import Model
 from sparsezoo.analyze_v1.utils.helpers import numpy_array_representer
@@ -200,6 +200,7 @@ class BenchmarkScenario(YAMLSerializableBaseModel):
     )

     num_cores: Optional[int] = Field(
+        None,
         description="The number of cores to use for benchmarking, can also take "
         "in a `None` value, which represents all cores",
     )
@@ -311,9 +312,7 @@ class NodeAnalysis(YAMLSerializableBaseModel):
     zero_point: Union[int, numpy.ndarray] = Field(
         description="Node zero point for quantization, default zero"
     )
-
-    class Config:
-        arbitrary_types_allowed = True
+    model_config = ConfigDict(arbitrary_types_allowed=True)

     @classmethod
     def from_node(
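
The analysis.py hunks show the two pydantic v2 patterns that recur through this commit: the inner class Config becomes model_config = ConfigDict(...), and Optional fields no longer receive an implicit default, so None is passed explicitly as the first argument to Field. A minimal sketch of both patterns (the ExampleNode model is hypothetical, not from sparsezoo; it assumes pydantic >= 2.0 and numpy are installed):

    from typing import Optional

    import numpy
    from pydantic import BaseModel, ConfigDict, Field


    class ExampleNode(BaseModel):
        # v2 replaces the inner `class Config` with model_config;
        # arbitrary_types_allowed is needed for non-pydantic types such as numpy arrays.
        model_config = ConfigDict(arbitrary_types_allowed=True)

        weights: numpy.ndarray = Field(description="Raw weights array")
        # In v2, Optional no longer implies a default, so None is passed
        # explicitly as the first positional argument to Field.
        num_cores: Optional[int] = Field(None, description="Cores to use; None means all")


    node = ExampleNode(weights=numpy.zeros(3))
    assert node.num_cores is None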

src/sparsezoo/analyze_v1/utils/models.py (+16 -14)
@@ -13,7 +13,7 @@
 # limitations under the License.
 import logging
 import textwrap
-from typing import Dict, List, Optional, Tuple, Union
+from typing import ClassVar, Dict, List, Optional, Tuple, Union

 from pydantic import BaseModel, Field

@@ -30,6 +30,7 @@
 ]

 _LOGGER = logging.getLogger(__name__)
+PrintOrderType = ClassVar[List[str]]


 class PropertyBaseModel(BaseModel):
@@ -104,11 +105,12 @@ class NodeIO(BaseModel):

     name: str = Field(description="Name of the input/output in onnx model graph")
     shape: Optional[List[Union[None, int]]] = Field(
+        None,
         description="Shape of the input/output in onnx model graph (assuming a "
-        "batch size of 1)"
+        "batch size of 1)",
     )
     dtype: Optional[str] = Field(
-        description="Data type of the values from the input/output"
+        None, description="Data type of the values from the input/output"
     )


@@ -220,9 +222,9 @@ class ParameterComponent(BaseModel):
     """

     alias: str = Field(description="The type of parameter (weight, bias)")
-    name: Optional[str] = Field(description="The name of the parameter")
+    name: Optional[str] = Field(None, description="The name of the parameter")
     shape: Optional[List[Union[None, int]]] = Field(
-        description="The shape of the parameter"
+        None, description="The shape of the parameter"
     )
     parameter_summary: ParameterSummary = Field(
         description="A summary of the parameter"
@@ -235,7 +237,7 @@ class Entry(BaseModel):
     A BaseModel with subtraction and pretty_print support
     """

-    _print_order: List[str] = []
+    _print_order: PrintOrderType = []

     def __sub__(self, other):
         """
@@ -306,7 +308,7 @@ class BaseEntry(Entry):
     sparsity: float
     quantized: float

-    _print_order = ["sparsity", "quantized"]
+    _print_order: PrintOrderType = ["sparsity", "quantized"]


 class NamedEntry(BaseEntry):
@@ -318,7 +320,7 @@ class NamedEntry(BaseEntry):
     total: float
     size: int

-    _print_order = ["name", "total", "size"] + BaseEntry._print_order
+    _print_order: PrintOrderType = ["name", "total", "size"] + BaseEntry._print_order


 class TypedEntry(BaseEntry):
@@ -329,7 +331,7 @@ class TypedEntry(BaseEntry):
     type: str
     size: int

-    _print_order = ["type", "size"] + BaseEntry._print_order
+    _print_order: PrintOrderType = ["type", "size"] + BaseEntry._print_order


 class ModelEntry(BaseEntry):
@@ -338,7 +340,7 @@ class ModelEntry(BaseEntry):
     """

     model: str
-    _print_order = ["model"] + BaseEntry._print_order
+    _print_order: PrintOrderType = ["model"] + BaseEntry._print_order


 class SizedModelEntry(ModelEntry):
@@ -347,8 +349,8 @@ class SizedModelEntry(ModelEntry):
     """

     count: int
-    size: int
-    _print_order = ModelEntry._print_order + ["count", "size"]
+    size: Union[int, float]
+    _print_order: PrintOrderType = ModelEntry._print_order + ["count", "size"]


 class PerformanceEntry(BaseEntry):
@@ -361,7 +363,7 @@
     throughput: float
     supported_graph: float

-    _print_order = [
+    _print_order: PrintOrderType = [
         "model",
         "latency",
         "throughput",
@@ -377,7 +379,7 @@ class NodeTimingEntry(Entry):
     node_name: str
     avg_runtime: float

-    _print_order = [
+    _print_order: PrintOrderType = [
         "node_name",
         "avg_runtime",
     ] + Entry._print_order
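
The _print_order annotations above rely on the module-level alias PrintOrderType = ClassVar[List[str]]: a ClassVar annotation tells pydantic v2 to leave the attribute alone, so it stays a regular class attribute that subclasses can read and extend instead of being picked up as a model field or private attribute. A minimal sketch of the idea (simplified Entry/NamedEntry stand-ins, not the repo's full classes):

    from typing import ClassVar, List

    from pydantic import BaseModel

    PrintOrderType = ClassVar[List[str]]


    class Entry(BaseModel):
        # ClassVar keeps _print_order as a plain class attribute: pydantic does
        # not treat it as a field, and subclasses can read it off the class.
        _print_order: PrintOrderType = []


    class NamedEntry(Entry):
        name: str
        _print_order: PrintOrderType = ["name"] + Entry._print_order


    print(NamedEntry._print_order)          # ['name']
    print(NamedEntry(name="conv1").name)    # 'conv1'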

src/sparsezoo/analyze_v2/schemas/distribution_analysis.py (+16 -15)
@@ -14,33 +14,34 @@

 from typing import Dict, List, Optional

-from pydantic import BaseModel, Field, validator
+from pydantic import BaseModel, Field, field_validator

 from sparsezoo.analyze_v2.schemas.utils import type_validator


 class DistributionAnalysisSchema(BaseModel):
     counts: Optional[int] = Field(..., description="Total number of parameters")
-    mean: Optional[float]
-    median: Optional[float]
-    modes: Optional[List]
-    sum_val: Optional[float]
-    min_val: Optional[float]
-    max_val: Optional[float]
-    percentiles: Optional[Dict[float, float]]
-    std_dev: Optional[float]
-    skewness: Optional[float]
-    kurtosis: Optional[float]
-    entropy: Optional[float]
-    bin_width: Optional[float]
-    num_bins: Optional[int]
+    mean: Optional[float] = None
+    median: Optional[float] = None
+    modes: Optional[List] = None
+    sum_val: Optional[float] = None
+    min_val: Optional[float] = None
+    max_val: Optional[float] = None
+    percentiles: Optional[Dict[float, float]] = None
+    std_dev: Optional[float] = None
+    skewness: Optional[float] = None
+    kurtosis: Optional[float] = None
+    entropy: Optional[float] = None
+    bin_width: Optional[float] = None
+    num_bins: Optional[int] = None
     hist: Optional[List[float]] = Field(
         ..., description="Frequency of the parameters, with respect to the bin edges"
     )
     bin_edges: Optional[List[float]] = Field(
         ..., description="Lower bound edges of each bin"
     )

-    @validator("*", pre=True)
+    @field_validator("*", mode="before")
+    @classmethod
     def validate_types(cls, value):
         return type_validator(value)
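
These schema changes follow the standard v1 to v2 validator migration: @validator("*", pre=True) becomes @field_validator("*", mode="before") stacked on @classmethod, and bare Optional fields gain explicit None defaults. A minimal sketch (hypothetical SummarySchema, with an assumed stand-in for sparsezoo's type_validator helper):

    from typing import Optional

    from pydantic import BaseModel, Field, field_validator


    class SummarySchema(BaseModel):
        counts: Optional[int] = Field(None, description="Total number of parameters")

        # v1: @validator("*", pre=True)  ->  v2: mode="before", explicit classmethod.
        @field_validator("*", mode="before")
        @classmethod
        def validate_types(cls, value):
            # Stand-in for type_validator: coerce numpy scalars via .item() when present.
            return value.item() if hasattr(value, "item") else value


    print(SummarySchema(counts=3).counts)  # 3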

src/sparsezoo/analyze_v2/schemas/node_analysis.py (+3 -2)
@@ -14,7 +14,7 @@

 from typing import List, Optional

-from pydantic import BaseModel, Field, validator
+from pydantic import BaseModel, Field, field_validator

 from sparsezoo.analyze_v2.schemas.memory_access_analysis import (
     MemoryAccessAnalysisSchema,
@@ -33,6 +33,7 @@ class NodeAnalysisSchema(BaseModel):
     params: ParameterAnalysisSchema
     mem_access: MemoryAccessAnalysisSchema

-    @validator("input", "output", pre=True)
+    @field_validator("input", "output", mode="before")
+    @classmethod
     def validate_types(cls, value):
         return [val for val in value]

src/sparsezoo/analyze_v2/schemas/quantization_analysis.py (+3 -2)
@@ -14,7 +14,7 @@

 from typing import Optional

-from pydantic import BaseModel, Field, validator
+from pydantic import BaseModel, Field, field_validator, validator

 from sparsezoo.analyze_v2.schemas.utils import type_validator

@@ -40,7 +40,8 @@ class QuantizationSummaryAnalysisSchema(BaseModel):
         None, description="Percentage of counts_sparse over counts"
     )

-    @validator("*", pre=True)
+    @field_validator("*", mode="before")
+    @classmethod
     def validate_types(cls, value):
         return type_validator(value)


src/sparsezoo/analyze_v2/schemas/sparsity_analysis.py (+3 -2)
@@ -14,7 +14,7 @@

 from typing import Optional

-from pydantic import BaseModel, Field, validator
+from pydantic import BaseModel, Field, field_validator, validator

 from sparsezoo.analyze_v2.schemas.utils import type_validator

@@ -28,7 +28,8 @@ class SparsitySummaryAnalysisSchema(BaseModel):
         None, description="Percentage of counts_sparse over counts"
     )

-    @validator("*", pre=True)
+    @field_validator("*", mode="before")
+    @classmethod
     def validate_types(cls, value):
         return type_validator(value)


src/sparsezoo/evaluation/results.py (+8 -7)
@@ -71,15 +71,15 @@ class Metric(BaseModel):


 class Dataset(BaseModel):
-    type: Optional[str] = Field(description="Type of dataset")
+    type: Optional[str] = Field(None, description="Type of dataset")
     name: str = Field(description="Name of the dataset")
-    config: Any = Field(description="Configuration for the dataset")
-    split: Optional[str] = Field(description="Split of the dataset")
+    config: Any = Field(None, description="Configuration for the dataset")
+    split: Optional[str] = Field(None, description="Split of the dataset")


 class EvalSample(BaseModel):
-    input: Any = Field(description="Sample input to the model")
-    output: Any = Field(description="Sample output from the model")
+    input: Any = Field(None, description="Sample input to the model")
+    output: Any = Field(None, description="Sample output from the model")


 class Evaluation(BaseModel):
@@ -90,7 +90,7 @@ class Evaluation(BaseModel):
     dataset: Dataset = Field(description="Dataset that the evaluation was performed on")
     metrics: List[Metric] = Field(description="List of metrics for the evaluation")
     samples: Optional[List[EvalSample]] = Field(
-        description="List of samples for the evaluation"
+        None, description="List of samples for the evaluation"
     )


@@ -99,8 +99,9 @@ class Result(BaseModel):
         description="Evaluation result represented in the unified, structured format"
     )
     raw: Any = Field(
+        None,
         description="Evaluation result represented in the raw format "
-        "(characteristic for the specific evaluation integration)"
+        "(characteristic for the specific evaluation integration)",
     )



src/sparsezoo/utils/standardization/feature_status_page.py (+1 -1)
@@ -22,7 +22,7 @@
 from typing import List

 import yaml
-from pydantic import BaseModel, Field
+from pydantic.v1 import BaseModel, Field

 from sparsezoo.utils.standardization.feature_status import FeatureStatus
 from sparsezoo.utils.standardization.feature_status_table import FeatureStatusTable
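
The feature-status modules are not migrated; instead they switch to the pydantic.v1 compatibility namespace that pydantic 2.x ships, matching the commit message's "use v1 basemodel for FeatureStatus". A minimal sketch of what that buys (hypothetical StatusTable model, not the real FeatureStatusTable):

    from typing import List

    from pydantic.v1 import BaseModel, Field


    class StatusTable(BaseModel):
        # v1 semantics still apply for models imported from pydantic.v1:
        # class Config, @validator, .dict()/.json(), and implicit defaults
        # behave as they did before the upgrade.
        name: str = Field(description="Name of the feature table")
        rows: List[str] = []


    print(StatusTable(name="example").dict())  # {'name': 'example', 'rows': []}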

src/sparsezoo/utils/standardization/feature_status_table.py (+1 -1)
@@ -19,7 +19,7 @@
 from abc import ABC, abstractmethod
 from typing import List, Tuple

-from pydantic import BaseModel, Field
+from pydantic.v1 import BaseModel, Field

 from sparsezoo.utils.standardization.feature_status import FeatureStatus
 from sparsezoo.utils.standardization.markdown_utils import create_markdown_table

tests/sparsezoo/utils/standardization/test_feature_status_page.py (+1 -1)
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from pydantic import Field
+from pydantic.v1 import Field

 from sparsezoo.utils.standardization import (
     FeatureStatus,
