Skip to content

Commit 0c9ad71

Browse files
committed
feat: add args and kwargs for callback
1 parent 2be4448 commit 0c9ad71

File tree

2 files changed

+13
-11
lines changed

2 files changed

+13
-11
lines changed

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ build-backend = "hatchling.build"
99

1010
[project]
1111
name = "swankit"
12-
version = "0.1.5"
12+
version = "0.1.6"
1313
dynamic = ["readme", "dependencies"]
1414
description = "Base toolkit for SwanLab"
1515
license = "Apache-2.0"

swankit/callback/__init__.py

+12-10
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from abc import ABC, abstractmethod
1212
from .models import *
1313
from swankit.core import SwanLabSharedSettings
14-
from typing import Tuple
14+
from typing import Tuple, Optional
1515

1616

1717
class SwanKitCallback(ABC):
@@ -20,7 +20,7 @@ class SwanKitCallback(ABC):
2020
此处只定义会被调用的函数,用于接口规范
2121
"""
2222

23-
def on_init(self, proj_name: str, workspace: str, logdir: str = None, **kwargs):
23+
def on_init(self, proj_name: str, workspace: str, logdir: str = None, *args, **kwargs):
2424
"""
2525
执行`swanlab.init`时调用,此时运行时环境变量没有被设置,此时修改环境变量还是有效的
2626
:param logdir: str, 用户设置的日志目录
@@ -30,7 +30,7 @@ def on_init(self, proj_name: str, workspace: str, logdir: str = None, **kwargs):
3030
"""
3131
pass
3232

33-
def before_run(self, settings: SwanLabSharedSettings):
33+
def before_run(self, settings: SwanLabSharedSettings, *args, **kwargs):
3434
"""
3535
在运行实验之前调用
3636
:param settings: SwanLabSharedSettings, 运行时的共享配置
@@ -44,6 +44,8 @@ def before_init_experiment(
4444
description: str,
4545
num: int,
4646
colors: Tuple[str, str],
47+
*args,
48+
**kwargs,
4749
):
4850
"""
4951
在初始化实验之前调用,此时SwanLabRun已经初始化完毕
@@ -55,44 +57,44 @@ def before_init_experiment(
5557
"""
5658
pass
5759

58-
def on_run(self):
60+
def on_run(self, *args, **kwargs):
5961
"""
6062
SwanLabRun初始化完毕时调用
6163
"""
6264
pass
6365

64-
def on_run_error_from_operator(self, e: OperateErrorInfo):
66+
def on_run_error_from_operator(self, e: OperateErrorInfo, *args, **kwargs):
6567
"""
6668
执行`on_run`错误时被操作员调用
6769
"""
6870
pass
6971

70-
def on_runtime_info_update(self, r: RuntimeInfo):
72+
def on_runtime_info_update(self, r: RuntimeInfo, *args, **kwargs):
7173
"""
7274
运行时信息更新时调用
7375
:param r: RuntimeInfo, 运行时信息
7476
"""
7577
pass
7678

77-
def on_log(self):
79+
def on_log(self, data: dict, step: Optional[int], *args, **kwargs):
7880
"""
7981
每次执行swanlab.log时调用
8082
"""
8183
pass
8284

83-
def on_column_create(self, column_info: ColumnInfo):
85+
def on_column_create(self, column_info: ColumnInfo, *args, **kwargs):
8486
"""
8587
列创建回调函数,新增列信息时调用
8688
"""
8789
pass
8890

89-
def on_metric_create(self, metric_info: MetricInfo):
91+
def on_metric_create(self, metric_info: MetricInfo, *args, **kwargs):
9092
"""
9193
指标创建回调函数,新增指标信息时调用
9294
"""
9395
pass
9496

95-
def on_stop(self, error: str = None):
97+
def on_stop(self, error: str = None, *args, **kwargs):
9698
"""
9799
训练结束时的回调函数
98100
"""

0 commit comments

Comments
 (0)