Skip to content

fix: check the number of collected data in post_fp_check_fail #882

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Sep 1, 2022
22 changes: 15 additions & 7 deletions dpgen/generator/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -3165,20 +3165,28 @@ def run_fp (iter_index,
def post_fp_check_fail(iter_index,
jdata,
rfailed = None) :

ratio_failed = rfailed if rfailed else jdata.get('ratio_failed',0.05)
iter_name = make_iter_name(iter_index)
work_path = os.path.join(iter_name, fp_name)
fp_tasks = glob.glob(os.path.join(work_path, 'task.*'))
fp_tasks.sort()
if len(fp_tasks) == 0 :
return
# check fail according to tag_failure
fp_failed_tags = glob.glob(os.path.join(work_path, 'task.*', 'tag_failure*'))
fp_failed_tasks = [os.path.dirname(ii) for ii in fp_failed_tags]
fp_failed_tasks = list(set(fp_failed_tasks))

ntask = len(fp_tasks)
nfail = len(fp_failed_tasks)
nfail = 0

# check fail according to the number of collected data
sys_data = glob.glob(os.path.join(work_path, "data.*"))
sys_data.sort()
nframe = 0
for ii in sys_data :
sys_paths = expand_sys_str(ii)
for single_sys in sys_paths:
sys = dpdata.LabeledSystem(os.path.join(single_sys), fmt = 'deepmd/npy')
nframe += len(sys)
nfail = ntask - nframe

rfail = float(nfail) / float(ntask)
dlog.info("failed tasks: %6d in %6d %6.2f %% " % (nfail, ntask, rfail * 100.))
if rfail > ratio_failed:
Expand Down Expand Up @@ -3604,7 +3612,6 @@ def post_fp_amber_diff(iter_index, jdata):
def post_fp (iter_index,
jdata) :
fp_style = jdata['fp_style']
post_fp_check_fail(iter_index, jdata)
if fp_style == "vasp" :
post_fp_vasp(iter_index, jdata)
elif fp_style == "pwscf" :
Expand All @@ -3623,6 +3630,7 @@ def post_fp (iter_index,
post_fp_amber_diff(iter_index, jdata)
else :
raise RuntimeError ("unsupported fp style")
post_fp_check_fail(iter_index, jdata)
# clean traj
clean_traj = True
if 'model_devi_clean_traj' in jdata :
Expand Down