Skip to content

Commit fd82f04

Browse files
iProzd and njzjz authored
Add max_ckpt_keep for trainer (deepmodeling#3441)
Signed-off-by: Duo <[email protected]> Co-authored-by: Jinzhe Zeng <[email protected]>
1 parent a9bcf41 commit fd82f04

File tree

3 files changed

+20
-1
lines changed

3 files changed

+20
-1
lines changed

deepmd/pt/train/training.py

+10
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,7 @@ def __init__(
132132
self.disp_freq = training_params.get("disp_freq", 1000)
133133
self.save_ckpt = training_params.get("save_ckpt", "model.ckpt")
134134
self.save_freq = training_params.get("save_freq", 1000)
135+
self.max_ckpt_keep = training_params.get("max_ckpt_keep", 5)
135136
self.lcurve_should_print_header = True
136137

137138
def get_opt_param(params):
@@ -924,6 +925,15 @@ def save_model(self, save_path, lr=0.0, step=0):
924925
{"model": module.state_dict(), "optimizer": self.optimizer.state_dict()},
925926
save_path,
926927
)
928+
checkpoint_dir = save_path.parent
929+
checkpoint_files = [
930+
f
931+
for f in checkpoint_dir.glob("*.pt")
932+
if not f.is_symlink() and f.name.startswith(self.save_ckpt)
933+
]
934+
if len(checkpoint_files) > self.max_ckpt_keep:
935+
checkpoint_files.sort(key=lambda x: x.stat().st_mtime)
936+
checkpoint_files[0].unlink()
927937

928938
def get_data(self, is_train=True, task_key="Default"):
929939
if not self.multi_task:

deepmd/tf/train/trainer.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,7 @@ def get_lr_and_coef(lr_param):
164164
self.disp_freq = tr_data.get("disp_freq", 1000)
165165
self.save_freq = tr_data.get("save_freq", 1000)
166166
self.save_ckpt = tr_data.get("save_ckpt", "model.ckpt")
167+
self.max_ckpt_keep = tr_data.get("max_ckpt_keep", 5)
167168
self.display_in_training = tr_data.get("disp_training", True)
168169
self.timing_in_training = tr_data.get("time_training", True)
169170
self.profiling = self.run_opt.is_chief and tr_data.get("profiling", False)
@@ -498,7 +499,9 @@ def _init_session(self):
498499
# Initializes or restore global variables
499500
init_op = tf.global_variables_initializer()
500501
if self.run_opt.is_chief:
501-
self.saver = tf.train.Saver(save_relative_paths=True)
502+
self.saver = tf.train.Saver(
503+
save_relative_paths=True, max_to_keep=self.max_ckpt_keep
504+
)
502505
if self.run_opt.init_mode == "init_from_scratch":
503506
log.info("initialize model from scratch")
504507
run_sess(self.sess, init_op)

deepmd/utils/argcheck.py

+6
Original file line numberDiff line numberDiff line change
@@ -2134,6 +2134,11 @@ def training_args(): # ! modified by Ziyao: data configuration isolated.
21342134
doc_disp_freq = "The frequency of printing learning curve."
21352135
doc_save_freq = "The frequency of saving check point."
21362136
doc_save_ckpt = "The path prefix of saving check point files."
2137+
doc_max_ckpt_keep = (
2138+
"The maximum number of checkpoints to keep. "
2139+
"The oldest checkpoints will be deleted once the number of checkpoints exceeds max_ckpt_keep. "
2140+
"Defaults to 5."
2141+
)
21372142
doc_disp_training = "Displaying verbose information during training."
21382143
doc_time_training = "Timing durining training."
21392144
doc_profiling = "Profiling during training."
@@ -2192,6 +2197,7 @@ def training_args(): # ! modified by Ziyao: data configuration isolated.
21922197
Argument(
21932198
"save_ckpt", str, optional=True, default="model.ckpt", doc=doc_save_ckpt
21942199
),
2200+
Argument("max_ckpt_keep", int, optional=True, default=5, doc=doc_max_ckpt_keep),
21952201
Argument(
21962202
"disp_training", bool, optional=True, default=True, doc=doc_disp_training
21972203
),

0 commit comments

Comments
 (0)