@@ -552,6 +552,12 @@ def convert_expr(
552
552
self .nvtx_flag : bool = os .getenv ("NVTX" , None ) is not None
553
553
self .forward_helper .nvtx_flag = self .nvtx_flag
554
554
555
+ # for callbacks
556
+ self .callbacks_on_epoch_begin : List [Callable [[Solver ]]] = []
557
+ self .callbacks_on_epoch_end : List [Callable [[Solver ]]] = []
558
+ self .callbacks_on_iter_begin : List [Callable [[Solver ]]] = []
559
+ self .callbacks_on_iter_end : List [Callable [[Solver ]]] = []
560
+
555
561
def train (self ) -> None :
556
562
"""Training."""
557
563
self .global_step = self .best_metric ["epoch" ] * self .iters_per_epoch
@@ -569,7 +575,10 @@ def train(self) -> None:
569
575
core .nvprof_enable_record_event ()
570
576
571
577
for epoch_id in range (start_epoch , self .epochs + 1 ):
578
+ self ._invoke_callbacks_on_epoch_begin () # [optional]
572
579
self .train_epoch_func (self , epoch_id , self .log_freq )
580
+ self ._invoke_callbacks_on_epoch_end () # [optional]
581
+
573
582
self .train_output_info .clear ()
574
583
575
584
# update average model if exist
@@ -1124,3 +1133,87 @@ def _parse_params_from_cfg(self, cfg: DictConfig):
1124
1133
self .pretrained_model_path = cfg .EVAL .pretrained_model_path
1125
1134
elif cfg .mode in ["export" , "infer" ]:
1126
1135
self .pretrained_model_path = cfg .INFER .pretrained_model_path
1136
+
1137
+ def register_callback_on_epoch_begin (
1138
+ self : Solver , callback_fn : Callable [[Solver ]]
1139
+ ) -> None :
1140
+ """
1141
+ Registers a callback function to be executed at the beginning of each training epoch.
1142
+
1143
+ Args:
1144
+ callback_fn : Callable[[Solver]]
1145
+ A function that takes a Solver instance as an argument. This function
1146
+ will be called at the start of every epoch.
1147
+ """
1148
+ self .callbacks_on_epoch_begin .append (callback_fn )
1149
+
1150
+ def register_callback_on_epoch_end (
1151
+ self : Solver , callback_fn : Callable [[Solver ]]
1152
+ ) -> None :
1153
+ """
1154
+ Registers a callback function to be executed at the end of each training epoch.
1155
+
1156
+ Args:
1157
+ callback_fn : Callable[[Solver]]
1158
+ A function that takes a Solver instance as an argument. This function
1159
+ will be called at the end of every epoch.
1160
+ """
1161
+ self .callbacks_on_epoch_end .append (callback_fn )
1162
+
1163
+ def register_callback_on_iter_begin (
1164
+ self : Solver , callback_fn : Callable [[Solver ]]
1165
+ ) -> None :
1166
+ """
1167
+ Registers a callback function to be executed at the beginning of each training iteration.
1168
+
1169
+ Args:
1170
+ callback_fn : Callable[[Solver]]
1171
+ A function that takes a Solver instance as an argument. This function
1172
+ will be called at the start of every iteration.
1173
+ """
1174
+ self .callbacks_on_iter_begin .append (callback_fn )
1175
+
1176
+ def register_callback_on_iter_end (
1177
+ self : Solver , callback_fn : Callable [[Solver ]]
1178
+ ) -> None :
1179
+ """
1180
+ Registers a callback function to be executed at the end of each training iteration.
1181
+
1182
+ Args:
1183
+ callback_fn : Callable[[Solver]]
1184
+ A function that takes a Solver instance as an argument. This function
1185
+ will be called at the end of every iteration.
1186
+
1187
+ Returns:
1188
+ -------
1189
+ None
1190
+ """
1191
+ self .callbacks_on_iter_end .append (callback_fn )
1192
+
1193
+ def _invoke_callbacks_on_epoch_begin (self : Solver ) -> None :
1194
+ """
1195
+ Invokes all registered callbacks at the beginning of an epoch.
1196
+ """
1197
+ for callback in self .callbacks_on_epoch_begin :
1198
+ callback (self )
1199
+
1200
+ def _invoke_callbacks_on_epoch_end (self : Solver ) -> None :
1201
+ """
1202
+ Invokes all registered callbacks at the end of an epoch.
1203
+ """
1204
+ for callback in self .callbacks_on_epoch_end :
1205
+ callback (self )
1206
+
1207
+ def _invoke_callbacks_on_iter_begin (self : Solver ) -> None :
1208
+ """
1209
+ Invokes all registered callbacks at the beginning of an iteration.
1210
+ """
1211
+ for callback in self .callbacks_on_iter_begin :
1212
+ callback (self )
1213
+
1214
+ def _invoke_callbacks_on_iter_end (self : Solver ) -> None :
1215
+ """
1216
+ Invokes all registered callbacks at the end of an iteration.
1217
+ """
1218
+ for callback in self .callbacks_on_iter_end :
1219
+ callback (self )
0 commit comments