
Commit 1d7f223

support callback on epoch/iter begin & end and tutorial (PaddlePaddle#1153)
1 parent ca6ef98 commit 1d7f223

File tree

3 files changed: +145 −0 lines changed


docs/zh/user_guide.md

Lines changed: 47 additions & 0 deletions
@@ -1079,6 +1079,53 @@ PaddleScience has two built-in model averaging methods: [Stochastic weight averaging (S
3. Set the averaging interval to 1 epoch
4. Set the start and end epochs of averaging to 75 and 100

### 2.7 Callback registration and invocation guide

During training of a deep learning model, it is often useful to execute custom logic at particular moments. The `Solver` class in PaddleScience provides a fairly flexible mechanism that lets users register callback functions to be invoked at **different stages of training**.

Specifically, the following four registration interfaces are provided:

``` py
Solver.register_callback_on_epoch_begin  # called at the beginning of every epoch
Solver.register_callback_on_epoch_end    # called at the end of every epoch
Solver.register_callback_on_iter_begin   # called at the beginning of every iteration
Solver.register_callback_on_iter_end     # called at the end of every iteration
```

They are invoked during training at the points shown in the following example:

``` py hl_lines="3 6 8 10"
for epoch_id in range(1, num_epochs + 1):
    # train one epoch...
    _invoke_callbacks_on_epoch_begin()  # callbacks registered via register_callback_on_epoch_begin are invoked here, in registration order

    for iter_id in range(1, num_iters + 1):
        _invoke_callbacks_on_iter_begin()  # callbacks registered via register_callback_on_iter_begin are invoked here, in registration order
        # train one iteration...
        _invoke_callbacks_on_iter_end()  # callbacks registered via register_callback_on_iter_end are invoked here, in registration order

    _invoke_callbacks_on_epoch_end()  # callbacks registered via register_callback_on_epoch_end are invoked here, in registration order
```

Taking `examples/fsi/viv.py` as an example, suppose that during training we want to print the equation's learnable parameters `k1` and `k2` every 100 training steps. A callback can be added as in the following code:

``` py hl_lines="11 12 13 14 15"
# initialize solver
solver = ppsci.solver.Solver(
    model,
    constraint,
    optimizer=optimizer,
    equation=equation,
    validator=validator,
    visualizer=visualizer,
    cfg=cfg,
)
def show_learnable_params(slv):
    if slv.global_step % 100 == 0:
        ppsci.utils.logger.message(f"{equation['VIV'].k1.item():.5f}, {equation['VIV'].k2.item():.5f}")

solver.register_callback_on_iter_begin(show_learnable_params)
```
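
Several callbacks may be registered on the same hook; they are invoked in registration order, and the epoch-level hooks work the same way. As a further sketch (the function name `log_epoch_progress` is purely illustrative; it reuses the `solver` object and the `ppsci.utils.logger` API from the example above), a callback that runs once at the end of every epoch could look like this:

``` py
def log_epoch_progress(slv):
    # runs once at the end of every epoch, after all of its iterations have finished
    ppsci.utils.logger.message(f"epoch finished at global_step = {slv.global_step}")

solver.register_callback_on_epoch_end(log_epoch_progress)
```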

## 3. Profiling with Nsight

Nsight is NVIDIA's developer tool suite, providing in-depth tracing, debugging, profiling, and analysis for optimizing complex computing applications across NVIDIA GPUs and CPUs. For details, see the [Nsight Systems Document](https://docs.nvidia.com/nsight-systems/index.html).

ppsci/solver/solver.py

Lines changed: 93 additions & 0 deletions
@@ -552,6 +552,12 @@ def convert_expr(
        self.nvtx_flag: bool = os.getenv("NVTX", None) is not None
        self.forward_helper.nvtx_flag = self.nvtx_flag

        # lists of user-registered callbacks, invoked by the training loop
        self.callbacks_on_epoch_begin: List[Callable[[Solver], None]] = []
        self.callbacks_on_epoch_end: List[Callable[[Solver], None]] = []
        self.callbacks_on_iter_begin: List[Callable[[Solver], None]] = []
        self.callbacks_on_iter_end: List[Callable[[Solver], None]] = []

    def train(self) -> None:
        """Training."""
        self.global_step = self.best_metric["epoch"] * self.iters_per_epoch
@@ -569,7 +575,10 @@ def train(self) -> None:
            core.nvprof_enable_record_event()

        for epoch_id in range(start_epoch, self.epochs + 1):
            self._invoke_callbacks_on_epoch_begin()  # [optional] user callbacks registered via register_callback_on_epoch_begin
            self.train_epoch_func(self, epoch_id, self.log_freq)
            self._invoke_callbacks_on_epoch_end()  # [optional] user callbacks registered via register_callback_on_epoch_end

            self.train_output_info.clear()

            # update average model if exist
@@ -1124,3 +1133,87 @@ def _parse_params_from_cfg(self, cfg: DictConfig):
            self.pretrained_model_path = cfg.EVAL.pretrained_model_path
        elif cfg.mode in ["export", "infer"]:
            self.pretrained_model_path = cfg.INFER.pretrained_model_path

    def register_callback_on_epoch_begin(
        self: Solver, callback_fn: Callable[[Solver], None]
    ) -> None:
        """
        Registers a callback function to be executed at the beginning of each training epoch.

        Args:
            callback_fn (Callable[[Solver], None]): A function that takes a Solver instance
                as its only argument. It will be called at the start of every epoch.
        """
        self.callbacks_on_epoch_begin.append(callback_fn)

    def register_callback_on_epoch_end(
        self: Solver, callback_fn: Callable[[Solver], None]
    ) -> None:
        """
        Registers a callback function to be executed at the end of each training epoch.

        Args:
            callback_fn (Callable[[Solver], None]): A function that takes a Solver instance
                as its only argument. It will be called at the end of every epoch.
        """
        self.callbacks_on_epoch_end.append(callback_fn)

    def register_callback_on_iter_begin(
        self: Solver, callback_fn: Callable[[Solver], None]
    ) -> None:
        """
        Registers a callback function to be executed at the beginning of each training iteration.

        Args:
            callback_fn (Callable[[Solver], None]): A function that takes a Solver instance
                as its only argument. It will be called at the start of every iteration.
        """
        self.callbacks_on_iter_begin.append(callback_fn)

    def register_callback_on_iter_end(
        self: Solver, callback_fn: Callable[[Solver], None]
    ) -> None:
        """
        Registers a callback function to be executed at the end of each training iteration.

        Args:
            callback_fn (Callable[[Solver], None]): A function that takes a Solver instance
                as its only argument. It will be called at the end of every iteration.
        """
        self.callbacks_on_iter_end.append(callback_fn)

    def _invoke_callbacks_on_epoch_begin(self: Solver) -> None:
        """Invokes all registered callbacks at the beginning of an epoch, in registration order."""
        for callback in self.callbacks_on_epoch_begin:
            callback(self)

    def _invoke_callbacks_on_epoch_end(self: Solver) -> None:
        """Invokes all registered callbacks at the end of an epoch, in registration order."""
        for callback in self.callbacks_on_epoch_end:
            callback(self)

    def _invoke_callbacks_on_iter_begin(self: Solver) -> None:
        """Invokes all registered callbacks at the beginning of an iteration, in registration order."""
        for callback in self.callbacks_on_iter_begin:
            callback(self)

    def _invoke_callbacks_on_iter_end(self: Solver) -> None:
        """Invokes all registered callbacks at the end of an iteration, in registration order."""
        for callback in self.callbacks_on_iter_end:
            callback(self)
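
As a side note for readers of this diff, the registration/invocation contract above can be exercised in isolation. The following sketch uses a hypothetical `TinySolver` stand-in (not part of ppsci) only to illustrate that each callback receives the solver instance and that callbacks run in registration order:

``` py
from typing import Callable, List


class TinySolver:
    """Hypothetical stand-in mirroring the callback hooks added to ppsci's Solver."""

    def __init__(self) -> None:
        self.global_step = 0
        self.callbacks_on_iter_end: List[Callable[["TinySolver"], None]] = []

    def register_callback_on_iter_end(self, callback_fn: Callable[["TinySolver"], None]) -> None:
        self.callbacks_on_iter_end.append(callback_fn)

    def _invoke_callbacks_on_iter_end(self) -> None:
        # each callback receives the solver instance; invocation follows registration order
        for callback in self.callbacks_on_iter_end:
            callback(self)

    def train(self, num_iters: int) -> None:
        for _ in range(num_iters):
            self.global_step += 1  # stand-in for training one iteration
            self._invoke_callbacks_on_iter_end()


solver = TinySolver()
solver.register_callback_on_iter_end(lambda slv: print("step", slv.global_step))
solver.register_callback_on_iter_end(lambda slv: print("second callback"))
solver.train(num_iters=2)  # prints: step 1, second callback, step 2, second callback
```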

ppsci/solver/train.py

Lines changed: 5 additions & 0 deletions
@@ -66,6 +66,7 @@ def train_epoch_func(solver: "solver.Solver", epoch_id: int, log_freq: int):
    batch_tic = time.perf_counter()

    for iter_id in range(1, solver.iters_per_epoch + 1):
        solver._invoke_callbacks_on_iter_begin()
        if solver.nvtx_flag:  # only for nsight analysis
            core.nvprof_nvtx_push(
                f"Training iteration {solver.global_step + 1}"
@@ -212,6 +213,8 @@ def train_epoch_func(solver: "solver.Solver", epoch_id: int, log_freq: int):
                core.nvprof_stop()
                sys.exit(0)

        solver._invoke_callbacks_on_iter_end()

def train_LBFGS_epoch_func(solver: "solver.Solver", epoch_id: int, log_freq: int):
    """Train function for one epoch with L-BFGS optimizer.
@@ -226,6 +229,7 @@ def train_LBFGS_epoch_func(solver: "solver.Solver", epoch_id: int, log_freq: int
    batch_tic = time.perf_counter()

    for iter_id in range(1, solver.iters_per_epoch + 1):
        solver._invoke_callbacks_on_iter_begin()
        loss_dict = misc.Prettydefaultdict(float)
        loss_dict["loss"] = 0.0
        total_batch_size = 0
@@ -317,3 +321,4 @@ def closure() -> paddle.Tensor:
            printer.log_train_info(solver, total_batch_size, epoch_id, iter_id)

        batch_tic = time.perf_counter()
        solver._invoke_callbacks_on_iter_end()
