11
11
from abc import ABC , abstractmethod
12
12
from .models import *
13
13
from swankit .core import SwanLabSharedSettings
14
- from typing import Tuple
14
+ from typing import Tuple , Optional
15
15
16
16
17
17
class SwanKitCallback (ABC ):
@@ -20,7 +20,7 @@ class SwanKitCallback(ABC):
20
20
此处只定义会被调用的函数,用于接口规范
21
21
"""
22
22
23
- def on_init (self , proj_name : str , workspace : str , logdir : str = None , ** kwargs ):
23
+ def on_init (self , proj_name : str , workspace : str , logdir : str = None , * args , * *kwargs ):
24
24
"""
25
25
执行`swanlab.init`时调用,此时运行时环境变量没有被设置,此时修改环境变量还是有效的
26
26
:param logdir: str, 用户设置的日志目录
@@ -30,7 +30,7 @@ def on_init(self, proj_name: str, workspace: str, logdir: str = None, **kwargs):
30
30
"""
31
31
pass
32
32
33
- def before_run (self , settings : SwanLabSharedSettings ):
33
+ def before_run (self , settings : SwanLabSharedSettings , * args , ** kwargs ):
34
34
"""
35
35
在运行实验之前调用
36
36
:param settings: SwanLabSharedSettings, 运行时的共享配置
@@ -44,6 +44,8 @@ def before_init_experiment(
44
44
description : str ,
45
45
num : int ,
46
46
colors : Tuple [str , str ],
47
+ * args ,
48
+ ** kwargs ,
47
49
):
48
50
"""
49
51
在初始化实验之前调用,此时SwanLabRun已经初始化完毕
@@ -55,44 +57,44 @@ def before_init_experiment(
55
57
"""
56
58
pass
57
59
58
- def on_run (self ):
60
+ def on_run (self , * args , ** kwargs ):
59
61
"""
60
62
SwanLabRun初始化完毕时调用
61
63
"""
62
64
pass
63
65
64
- def on_run_error_from_operator (self , e : OperateErrorInfo ):
66
+ def on_run_error_from_operator (self , e : OperateErrorInfo , * args , ** kwargs ):
65
67
"""
66
68
执行`on_run`错误时被操作员调用
67
69
"""
68
70
pass
69
71
70
- def on_runtime_info_update (self , r : RuntimeInfo ):
72
+ def on_runtime_info_update (self , r : RuntimeInfo , * args , ** kwargs ):
71
73
"""
72
74
运行时信息更新时调用
73
75
:param r: RuntimeInfo, 运行时信息
74
76
"""
75
77
pass
76
78
77
- def on_log (self ):
79
+ def on_log (self , data : dict , step : Optional [ int ], * args , ** kwargs ):
78
80
"""
79
81
每次执行swanlab.log时调用
80
82
"""
81
83
pass
82
84
83
- def on_column_create (self , column_info : ColumnInfo ):
85
+ def on_column_create (self , column_info : ColumnInfo , * args , ** kwargs ):
84
86
"""
85
87
列创建回调函数,新增列信息时调用
86
88
"""
87
89
pass
88
90
89
- def on_metric_create (self , metric_info : MetricInfo ):
91
+ def on_metric_create (self , metric_info : MetricInfo , * args , ** kwargs ):
90
92
"""
91
93
指标创建回调函数,新增指标信息时调用
92
94
"""
93
95
pass
94
96
95
- def on_stop (self , error : str = None ):
97
+ def on_stop (self , error : str = None , * args , ** kwargs ):
96
98
"""
97
99
训练结束时的回调函数
98
100
"""
0 commit comments