Skip to content

Commit 0c1d1c0

Browse files
authored
chore: sys-model (#20)
1 parent d2f99b7 commit 0c1d1c0

File tree

8 files changed

+90
-83
lines changed

8 files changed

+90
-83
lines changed

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ build-backend = "hatchling.build"
99

1010
[project]
1111
name = "swankit"
12-
version = "0.1.2b1"
12+
version = "0.1.2b2"
1313
dynamic = ["readme", "dependencies"]
1414
description = "Base toolkit for SwanLab"
1515
license = "Apache-2.0"

swankit/callback/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -107,4 +107,4 @@ def __str__(self) -> str:
107107
pass
108108

109109

110-
__all__ = ["SwanKitCallback", "MediaBuffer", "MetricInfo", "ColumnInfo", "OperateErrorInfo", "RuntimeInfo"]
110+
__all__ = ["SwanKitCallback", "models"]

swankit/callback/models/__init__.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
@Description:
88
与回调函数通信时的模型
99
"""
10-
from .key import MediaBuffer, MetricInfo, ColumnInfo, MetricErrorInfo
10+
from .key import MediaBuffer, MetricInfo, ColumnInfo, MetricErrorInfo, ColumnClass, SectionType, ColumnConfig, YRange
1111
from .error import OperateErrorInfo
1212
from .runtime import RuntimeInfo
1313

@@ -18,4 +18,8 @@
1818
"MetricErrorInfo",
1919
"OperateErrorInfo",
2020
"RuntimeInfo",
21+
"ColumnClass",
22+
"SectionType",
23+
"ColumnConfig",
24+
"YRange",
2125
]

swankit/callback/models/key.py

+31-16
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,24 @@
77
@Description:
88
与Key相关的回调函数触发时的模型
99
"""
10-
from typing import Union, Optional, Dict, List, Literal
11-
from swankit.core import ChartType, ParseErrorInfo, MediaBuffer
10+
from typing import Union, Optional, Dict, List, Literal, Tuple, TypedDict
11+
12+
from swankit.core import ChartType, ParseErrorInfo, MediaBuffer, ChartReference
1213
from urllib.parse import quote
1314
import os
1415

16+
ColumnClass = Literal["CUSTOM", "SYSTEM"]
17+
SectionType = Literal["PINNED", "HIDDEN", "PUBLIC", "SYSTEM"]
18+
YRange = Optional[Tuple[Optional[float], Optional[float]]]
19+
20+
21+
class ColumnConfig(TypedDict):
22+
"""
23+
列信息配置
24+
"""
25+
26+
y_range: YRange
27+
1528

1629
class ColumnInfo:
1730
"""
@@ -21,22 +34,23 @@ class ColumnInfo:
2134
def __init__(
2235
self,
2336
key: str,
24-
key_id: str,
25-
key_name: str,
26-
key_class: Literal["CUSTOM", "SYSTEM"],
37+
kid: str,
38+
name: str,
39+
cls: ColumnClass,
2740
chart_type: ChartType,
28-
chart_reference: Literal["step", "time"],
41+
chart_reference: ChartReference,
2942
section_name: Optional[str],
43+
section_type: SectionType,
3044
section_sort: Optional[int] = None,
3145
error: Optional[ParseErrorInfo] = None,
32-
config: Optional[Dict] = None,
46+
config: Optional[ColumnConfig] = None,
3347
):
3448
"""
3549
生成的列信息对象
36-
:param key: 生成的列名称
37-
:param key_id: 当前实验下,列的唯一id,与保存路径等信息有关
38-
:param key_name: key的别名
39-
:param key_class: 列的类型,CUSTOM为自定义列,SYSTEM为系统生成列
50+
:param key: 生成的列名称,作为索引键值
51+
:param kid: 当前实验下,列的唯一id,与保存路径等信息有关,与云端请求无关
52+
:param name: 列的别名
53+
:param cls: 列的类型,CUSTOM为自定义列,SYSTEM为系统生成列
4054
:param chart_type: 列对应的图表类型
4155
:param chart_reference: 这个列对应图表的参考系,step为步数,time为时间
4256
:param section_name: 列的组名
@@ -45,18 +59,19 @@ def __init__(
4559
:param config: 列的额外配置信息
4660
"""
4761
self.key = key
48-
self.key_id = key_id
49-
self.key_name = key_name
50-
self.key_class = key_class
62+
self.kid = kid
63+
self.name = name
64+
self.cls = cls
5165

5266
self.chart_type = chart_type
5367
self.chart_reference = chart_reference
5468

5569
self.section_name = section_name
5670
self.section_sort = section_sort
71+
self.section_type = section_type
5772

5873
self.error = error
59-
self.config = config if config is not None else {}
74+
self.config = config
6075

6176
@property
6277
def got(self):
@@ -124,7 +139,7 @@ def __init__(
124139
self.metric_summary = metric_summary
125140
self.metric_step = metric_step
126141
self.metric_epoch = metric_epoch
127-
_id = self.column_info.key_id
142+
_id = self.column_info.kid
128143
self.metric_file_path = None if self.is_error else os.path.join(swanlab_logdir, _id, metric_file_name)
129144
self.summary_file_path = None if self.is_error else os.path.join(swanlab_logdir, _id, self.__SUMMARY_NAME)
130145
self.swanlab_media_dir = swanlab_media_dir

swankit/callback/models/runtime.py

+7-17
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
运行时信息模型
99
"""
1010
from abc import ABC, abstractmethod
11-
from typing import Optional, Any
11+
from typing import Optional
1212
import json
1313
import yaml
1414
import os
@@ -108,13 +108,7 @@ def to_dict(self):
108108
# 没有在__init__中直接修改是因为可能会有其他地方需要原始数据,并且会丢失一些性能
109109
if self.__data is not None:
110110
return self.__data
111-
self.__data = {
112-
k: {
113-
"value": v,
114-
"sort": i,
115-
"desc": ""
116-
} for i, (k, v) in enumerate(self.info.items())
117-
}
111+
self.__data = {k: {"value": v, "sort": i, "desc": ""} for i, (k, v) in enumerate(self.info.items())}
118112
return self.__data
119113

120114

@@ -130,14 +124,10 @@ def __init__(self, requirements: str = None, metadata: dict = None, config: dict
130124
:param metadata: 系统信息
131125
:param config: 上传的配置信息
132126
"""
133-
self.requirements: Optional[RequirementInfo] = RequirementInfo(
134-
requirements
135-
) if requirements is not None else None
127+
self.requirements: Optional[RequirementInfo] = (
128+
RequirementInfo(requirements) if requirements is not None else None
129+
)
136130

137-
self.metadata: Optional[MetadataInfo] = MetadataInfo(
138-
metadata
139-
) if metadata is not None else None
131+
self.metadata: Optional[MetadataInfo] = MetadataInfo(metadata) if metadata is not None else None
140132

141-
self.config: Optional[ConfigInfo] = ConfigInfo(
142-
config
143-
) if config is not None else None
133+
self.config: Optional[ConfigInfo] = ConfigInfo(config) if config is not None else None

swankit/core/__init__.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
@Description:
88
核心解析模块工具
99
"""
10-
from .data import BaseType, MediaType, DataSuite, MediaBuffer, ParseResult, ParseErrorInfo
10+
from .data import BaseType, MediaType, DataSuite, MediaBuffer, ParseResult, ParseErrorInfo, ChartReference
1111
from .settings import SwanLabSharedSettings
1212

1313
ChartType = BaseType.Chart
@@ -20,5 +20,6 @@
2020
"MediaBuffer",
2121
"ParseResult",
2222
"ParseErrorInfo",
23-
"SwanLabSharedSettings"
23+
"SwanLabSharedSettings",
24+
"ChartReference",
2425
]

swankit/core/data.py

+18-25
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,16 @@
77
@Description:
88
数据处理模型
99
"""
10-
from typing import List, Dict, Optional, ByteString, Union, Tuple
10+
from typing import List, Dict, Optional, ByteString, Union, Tuple, Literal
1111
from abc import ABC, abstractmethod
1212
from enum import Enum
1313
from io import BytesIO
1414
import hashlib
1515
import math
1616
import io
1717

18+
ChartReference = Literal["STEP", "TIME"]
19+
1820

1921
class DataSuite:
2022
"""
@@ -172,14 +174,6 @@ def get_section(self) -> str:
172174
"""
173175
return "default"
174176

175-
# noinspection PyMethodMayBeStatic
176-
def get_config(self) -> Optional[Dict]:
177-
"""
178-
获取图表的config配置信息,应该返回一个字典,或者为None
179-
为None时代表不需要配置
180-
"""
181-
return None
182-
183177
# noinspection PyMethodMayBeStatic
184178
def get_more(self) -> Optional[Dict]:
185179
"""
@@ -193,6 +187,7 @@ class MediaType(BaseType): # noqa
193187
"""
194188
媒体类型,用于区分标量和媒体,不用做实例化,应该由子类继承
195189
"""
190+
196191
pass
197192

198193

@@ -220,30 +215,28 @@ class ParseResult:
220215
"""
221216

222217
def __init__(
223-
self,
224-
section: str = None,
225-
chart: BaseType.Chart = None,
226-
data: Union[List[str], float] = None,
227-
config: Optional[List[Dict]] = None,
228-
more: Optional[List[Dict]] = None,
229-
buffers: Optional[List[MediaBuffer]] = None,
218+
self,
219+
section: str = None,
220+
chart: BaseType.Chart = None,
221+
data: Union[List[str], float] = None,
222+
more: Optional[List[Dict]] = None,
223+
buffers: Optional[List[MediaBuffer]] = None,
224+
reference: ChartReference = "STEP",
230225
):
231226
"""
232227
:param section: 转换后数据对应的section
233228
:param chart: 转换后数据对应的图表类型,枚举类型
234229
:param data: 存储在.log中的数据
235-
:param config: 存储在.log中的配置
236230
:param more: 存储在.log中的更多信息
237231
:param buffers: 存储于media文件夹中的原始数据,比特流,特别的,对于某些字符串即原始数据的情况,此处为None
232+
:param reference: 图表数据的参考类型
238233
"""
239234
self.__data = data
240-
self.config = config
241235
self.more = more
242236
self.buffers = buffers
243237
self.section = section
244238
self.chart = chart
245-
# 默认的reference
246-
self.reference = "step"
239+
self.reference = reference
247240
self.step = None
248241

249242
@property
@@ -291,11 +284,11 @@ class ParseErrorInfo:
291284
"""
292285

293286
def __init__(
294-
self,
295-
expected: Optional[str],
296-
got: Optional[str],
297-
chart: Optional[BaseType.Chart],
298-
duplicated: bool = False
287+
self,
288+
expected: Optional[str],
289+
got: Optional[str],
290+
chart: Optional[BaseType.Chart],
291+
duplicated: bool = False,
299292
):
300293
"""
301294
:param expected: 期望的数据类型

test/unit/callback/models/test_key.py

+24-20
Original file line numberDiff line numberDiff line change
@@ -5,40 +5,43 @@
55
def test_column_info():
66
c = K.ColumnInfo(
77
key="a/1",
8-
key_id="b",
9-
key_name="c",
10-
key_class="SYSTEM",
8+
kid="b",
9+
name="c",
10+
cls="SYSTEM",
1111
section_name="e",
1212
section_sort=1,
13+
section_type="PUBLIC",
1314
chart_type=ChartType.TEXT,
14-
chart_reference="step",
15+
chart_reference="STEP",
1516
error=None,
1617
config=None,
1718
)
1819
assert c.got is None
1920
assert c.key == "a/1"
20-
assert c.key_id == "b"
21-
assert c.key_name == "c"
22-
assert c.key_class == "SYSTEM"
21+
assert c.kid == "b"
22+
assert c.name == "c"
23+
assert c.cls == "SYSTEM"
2324
assert c.section_name == "e"
2425
assert c.section_sort == 1
2526
assert c.chart_type == ChartType.TEXT
26-
assert c.chart_reference == "step"
27+
assert c.chart_reference == "STEP"
28+
assert c.section_type == "PUBLIC"
2729
assert c.error is None
28-
assert c.config == {}
30+
assert c.config is None
2931
assert c.key_encode == "a%2F1"
3032

3133

3234
def test_metric_info():
3335
c = K.ColumnInfo(
3436
key="a/1",
35-
key_id="b",
36-
key_name="c",
37-
key_class="SYSTEM",
37+
kid="b",
38+
name="c",
39+
cls="SYSTEM",
3840
section_name="e",
3941
section_sort=1,
4042
chart_type=ChartType.TEXT,
41-
chart_reference="step",
43+
section_type="PUBLIC",
44+
chart_reference="STEP",
4245
error=None,
4346
config=None,
4447
)
@@ -56,25 +59,26 @@ def test_metric_info():
5659
)
5760
assert m.column_info.got is None
5861
assert m.column_info.key == "a/1"
59-
assert m.column_info.key_id == "b"
60-
assert m.column_info.key_name == "c"
61-
assert m.column_info.key_class == "SYSTEM"
62+
assert m.column_info.kid == "b"
63+
assert m.column_info.name == "c"
64+
assert m.column_info.cls == "SYSTEM"
6265
assert m.column_info.section_name == "e"
6366
assert m.column_info.section_sort == 1
6467
assert m.column_info.chart_type == ChartType.TEXT
65-
assert m.column_info.chart_reference == "step"
68+
assert m.column_info.chart_reference == "STEP"
69+
assert m.column_info.section_type == "PUBLIC"
6670
assert m.column_info.error is None
67-
assert m.column_info.config == {}
71+
assert m.column_info.config is None
6872
assert m.column_info.key_encode == "a%2F1"
6973
assert m.column_info.got is None
7074
assert m.column_info.expected is None
7175
assert m.column_info.key_encode == "a%2F1"
7276
assert m.column_info.key == "a/1"
73-
assert m.column_info.key_id == "b"
77+
assert m.column_info.kid == "b"
7478
assert m.metric == {"data": 1}
7579
assert m.metric_buffers is None
7680
assert m.metric_summary == {"data": 1}
7781
assert m.metric_step == 1
7882
assert m.metric_epoch == 1
7983
assert m.swanlab_media_dir == "."
80-
assert m.metric_file_path == f"./{c.key_id}/1.log"
84+
assert m.metric_file_path == f"./{c.kid}/1.log"

0 commit comments

Comments
 (0)