Skip to content

Commit a0684ca

Browse files
committed
Merge branch 'PR' of https://github.com/thangckt/dpgen into PR
2 parents 11dca54 + a1b3ff8 commit a0684ca

File tree

2 files changed

+20
-8
lines changed

2 files changed

+20
-8
lines changed

dpgen/generator/run.py

+19-7
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@
126126

127127

128128
def _get_model_suffix(jdata) -> str:
129-
"""return the model suffix based on the backend"""
129+
"""Return the model suffix based on the backend"""
130130
backend = jdata.get("train_backend", "tensorflow")
131131
if backend == "tensorflow":
132132
suffix = ".pb"
@@ -193,7 +193,10 @@ def copy_model(numb_model, prv_iter_index, cur_iter_index, suffix=".pb"):
193193
prv_train_task = os.path.join(prv_train_path, train_task_fmt % ii)
194194
os.chdir(cur_train_path)
195195
os.symlink(os.path.relpath(prv_train_task), train_task_fmt % ii)
196-
os.symlink(os.path.join(train_task_fmt % ii, "frozen_model%s" % suffix), "graph.%03d%s" % (ii, suffix))
196+
os.symlink(
197+
os.path.join(train_task_fmt % ii, "frozen_model%s" % suffix),
198+
"graph.%03d%s" % (ii, suffix),
199+
)
197200
os.chdir(cwd)
198201
with open(os.path.join(cur_train_path, "copied"), "w") as fp:
199202
None
@@ -657,7 +660,9 @@ def make_train(iter_index, jdata, mdata):
657660
)
658661
if copied_models is not None:
659662
for ii in range(len(copied_models)):
660-
_link_old_models(work_path, [copied_models[ii]], ii, basename="init%s" % suffix)
663+
_link_old_models(
664+
work_path, [copied_models[ii]], ii, basename="init%s" % suffix
665+
)
661666
# Copy user defined forward files
662667
symlink_user_forward_files(mdata=mdata, task_type="train", work_path=work_path)
663668
# HDF5 format for training data
@@ -811,11 +816,18 @@ def run_train(iter_index, jdata, mdata):
811816
elif training_init_frozen_model is not None or training_finetune_model is not None:
812817
forward_files.append(os.path.join("old", "init%s" % suffix))
813818

814-
backward_files = ["frozen_model%s" % suffix, "lcurve.out", "train.log", "checkpoint"]
819+
backward_files = [
820+
"frozen_model%s" % suffix,
821+
"lcurve.out",
822+
"train.log",
823+
"checkpoint",
824+
]
815825
if suffix == ".pb":
816-
backward_files += ["model.ckpt.meta",
817-
"model.ckpt.index",
818-
"model.ckpt.data-00000-of-00001"]
826+
backward_files += [
827+
"model.ckpt.meta",
828+
"model.ckpt.index",
829+
"model.ckpt.data-00000-of-00001",
830+
]
819831
if jdata.get("dp_compress", False):
820832
backward_files.append("frozen_model_compressed%s" % suffix)
821833

dpgen/simplify/simplify.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
record_iter,
3232
)
3333
from dpgen.generator.run import (
34+
_get_model_suffix,
3435
data_system_fmt,
3536
fp_name,
3637
fp_task_fmt,
@@ -43,7 +44,6 @@
4344
run_train,
4445
train_name,
4546
train_task_fmt,
46-
_get_model_suffix,
4747
)
4848
from dpgen.remote.decide_machine import convert_mdata
4949
from dpgen.util import expand_sys_str, load_file, normalize, sepline, setup_ele_temp

0 commit comments

Comments
 (0)