13
13
# limitations under the License.
14
14
import logging
15
15
import textwrap
16
- from typing import Dict , List , Optional , Tuple , Union
16
+ from typing import ClassVar , Dict , List , Optional , Tuple , Union
17
17
18
18
from pydantic import BaseModel , Field
19
19
30
30
]
31
31
32
32
_LOGGER = logging .getLogger (__name__ )
33
+ PrintOrderType = ClassVar [List [str ]]
33
34
34
35
35
36
class PropertyBaseModel (BaseModel ):
@@ -104,11 +105,12 @@ class NodeIO(BaseModel):
104
105
105
106
name : str = Field (description = "Name of the input/output in onnx model graph" )
106
107
shape : Optional [List [Union [None , int ]]] = Field (
108
+ None ,
107
109
description = "Shape of the input/output in onnx model graph (assuming a "
108
- "batch size of 1)"
110
+ "batch size of 1)" ,
109
111
)
110
112
dtype : Optional [str ] = Field (
111
- description = "Data type of the values from the input/output"
113
+ None , description = "Data type of the values from the input/output"
112
114
)
113
115
114
116
@@ -220,9 +222,9 @@ class ParameterComponent(BaseModel):
220
222
"""
221
223
222
224
alias : str = Field (description = "The type of parameter (weight, bias)" )
223
- name : Optional [str ] = Field (description = "The name of the parameter" )
225
+ name : Optional [str ] = Field (None , description = "The name of the parameter" )
224
226
shape : Optional [List [Union [None , int ]]] = Field (
225
- description = "The shape of the parameter"
227
+ None , description = "The shape of the parameter"
226
228
)
227
229
parameter_summary : ParameterSummary = Field (
228
230
description = "A summary of the parameter"
@@ -235,7 +237,7 @@ class Entry(BaseModel):
235
237
A BaseModel with subtraction and pretty_print support
236
238
"""
237
239
238
- _print_order : List [ str ] = []
240
+ _print_order : PrintOrderType = []
239
241
240
242
def __sub__ (self , other ):
241
243
"""
@@ -306,7 +308,7 @@ class BaseEntry(Entry):
306
308
sparsity : float
307
309
quantized : float
308
310
309
- _print_order = ["sparsity" , "quantized" ]
311
+ _print_order : PrintOrderType = ["sparsity" , "quantized" ]
310
312
311
313
312
314
class NamedEntry (BaseEntry ):
@@ -318,7 +320,7 @@ class NamedEntry(BaseEntry):
318
320
total : float
319
321
size : int
320
322
321
- _print_order = ["name" , "total" , "size" ] + BaseEntry ._print_order
323
+ _print_order : PrintOrderType = ["name" , "total" , "size" ] + BaseEntry ._print_order
322
324
323
325
324
326
class TypedEntry (BaseEntry ):
@@ -329,7 +331,7 @@ class TypedEntry(BaseEntry):
329
331
type : str
330
332
size : int
331
333
332
- _print_order = ["type" , "size" ] + BaseEntry ._print_order
334
+ _print_order : PrintOrderType = ["type" , "size" ] + BaseEntry ._print_order
333
335
334
336
335
337
class ModelEntry (BaseEntry ):
@@ -338,7 +340,7 @@ class ModelEntry(BaseEntry):
338
340
"""
339
341
340
342
model : str
341
- _print_order = ["model" ] + BaseEntry ._print_order
343
+ _print_order : PrintOrderType = ["model" ] + BaseEntry ._print_order
342
344
343
345
344
346
class SizedModelEntry (ModelEntry ):
@@ -347,8 +349,8 @@ class SizedModelEntry(ModelEntry):
347
349
"""
348
350
349
351
count : int
350
- size : int
351
- _print_order = ModelEntry ._print_order + ["count" , "size" ]
352
+ size : Union [ int , float ]
353
+ _print_order : PrintOrderType = ModelEntry ._print_order + ["count" , "size" ]
352
354
353
355
354
356
class PerformanceEntry (BaseEntry ):
@@ -361,7 +363,7 @@ class PerformanceEntry(BaseEntry):
361
363
throughput : float
362
364
supported_graph : float
363
365
364
- _print_order = [
366
+ _print_order : PrintOrderType = [
365
367
"model" ,
366
368
"latency" ,
367
369
"throughput" ,
@@ -377,7 +379,7 @@ class NodeTimingEntry(Entry):
377
379
node_name : str
378
380
avg_runtime : float
379
381
380
- _print_order = [
382
+ _print_order : PrintOrderType = [
381
383
"node_name" ,
382
384
"avg_runtime" ,
383
385
] + Entry ._print_order
0 commit comments