Skip to content

Commit c5812fb

Browse files
Simplify API version validation (deepmodeling#1556)
<!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **Refactor** - Simplified and streamlined the submission processes for various job types, improving efficiency and reducing redundancy. - Centralized API version checking with a new utility function, enhancing maintainability and consistency across the application. - **Bug Fixes** - Improved error handling for API versions below 1.0, ensuring clearer and more consistent error messages. - **New Features** - Introduced a new function for API version validation, ensuring compatibility and proper error handling. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 54e48c6 commit c5812fb

File tree

5 files changed

+165
-184
lines changed

5 files changed

+165
-184
lines changed

dpgen/data/gen.py

+69-84
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010

1111
import dpdata
1212
import numpy as np
13-
from packaging.version import Version
1413
from pymatgen.core import Structure
1514
from pymatgen.io.vasp import Incar
1615

@@ -28,7 +27,7 @@
2827
make_abacus_scf_stru,
2928
make_supercell_abacus,
3029
)
31-
from dpgen.generator.lib.utils import symlink_user_forward_files
30+
from dpgen.generator.lib.utils import check_api_version, symlink_user_forward_files
3231
from dpgen.generator.lib.vasp import incar_upper
3332
from dpgen.remote.decide_machine import convert_mdata
3433
from dpgen.util import load_file
@@ -1158,27 +1157,23 @@ def run_vasp_relax(jdata, mdata):
11581157
# relax_run_tasks.append(ii)
11591158
run_tasks = [os.path.basename(ii) for ii in relax_run_tasks]
11601159

1161-
api_version = mdata.get("api_version", "1.0")
1162-
if Version(api_version) < Version("1.0"):
1163-
raise RuntimeError(
1164-
f"API version {api_version} has been removed. Please upgrade to 1.0."
1165-
)
1166-
1167-
elif Version(api_version) >= Version("1.0"):
1168-
submission = make_submission(
1169-
mdata["fp_machine"],
1170-
mdata["fp_resources"],
1171-
commands=[fp_command],
1172-
work_path=work_dir,
1173-
run_tasks=run_tasks,
1174-
group_size=fp_group_size,
1175-
forward_common_files=forward_common_files,
1176-
forward_files=forward_files,
1177-
backward_files=backward_files,
1178-
outlog="fp.log",
1179-
errlog="fp.log",
1180-
)
1181-
submission.run_submission()
1160+
### Submit jobs
1161+
check_api_version(mdata)
1162+
1163+
submission = make_submission(
1164+
mdata["fp_machine"],
1165+
mdata["fp_resources"],
1166+
commands=[fp_command],
1167+
work_path=work_dir,
1168+
run_tasks=run_tasks,
1169+
group_size=fp_group_size,
1170+
forward_common_files=forward_common_files,
1171+
forward_files=forward_files,
1172+
backward_files=backward_files,
1173+
outlog="fp.log",
1174+
errlog="fp.log",
1175+
)
1176+
submission.run_submission()
11821177

11831178

11841179
def coll_abacus_md(jdata):
@@ -1298,27 +1293,23 @@ def run_abacus_relax(jdata, mdata):
12981293
# relax_run_tasks.append(ii)
12991294
run_tasks = [os.path.basename(ii) for ii in relax_run_tasks]
13001295

1301-
api_version = mdata.get("api_version", "1.0")
1302-
if Version(api_version) < Version("1.0"):
1303-
raise RuntimeError(
1304-
f"API version {api_version} has been removed. Please upgrade to 1.0."
1305-
)
1306-
1307-
elif Version(api_version) >= Version("1.0"):
1308-
submission = make_submission(
1309-
mdata["fp_machine"],
1310-
mdata["fp_resources"],
1311-
commands=[fp_command],
1312-
work_path=work_dir,
1313-
run_tasks=run_tasks,
1314-
group_size=fp_group_size,
1315-
forward_common_files=forward_common_files,
1316-
forward_files=forward_files,
1317-
backward_files=backward_files,
1318-
outlog="fp.log",
1319-
errlog="fp.log",
1320-
)
1321-
submission.run_submission()
1296+
### Submit jobs
1297+
check_api_version(mdata)
1298+
1299+
submission = make_submission(
1300+
mdata["fp_machine"],
1301+
mdata["fp_resources"],
1302+
commands=[fp_command],
1303+
work_path=work_dir,
1304+
run_tasks=run_tasks,
1305+
group_size=fp_group_size,
1306+
forward_common_files=forward_common_files,
1307+
forward_files=forward_files,
1308+
backward_files=backward_files,
1309+
outlog="fp.log",
1310+
errlog="fp.log",
1311+
)
1312+
submission.run_submission()
13221313

13231314

13241315
def run_vasp_md(jdata, mdata):
@@ -1359,27 +1350,24 @@ def run_vasp_md(jdata, mdata):
13591350
run_tasks = [ii.replace(work_dir + "/", "") for ii in md_run_tasks]
13601351
# dlog.info("md_work_dir", work_dir)
13611352
# dlog.info("run_tasks",run_tasks)
1362-
api_version = mdata.get("api_version", "1.0")
1363-
if Version(api_version) < Version("1.0"):
1364-
raise RuntimeError(
1365-
f"API version {api_version} has been removed. Please upgrade to 1.0."
1366-
)
13671353

1368-
elif Version(api_version) >= Version("1.0"):
1369-
submission = make_submission(
1370-
mdata["fp_machine"],
1371-
mdata["fp_resources"],
1372-
commands=[fp_command],
1373-
work_path=work_dir,
1374-
run_tasks=run_tasks,
1375-
group_size=fp_group_size,
1376-
forward_common_files=forward_common_files,
1377-
forward_files=forward_files,
1378-
backward_files=backward_files,
1379-
outlog="fp.log",
1380-
errlog="fp.log",
1381-
)
1382-
submission.run_submission()
1354+
### Submit jobs
1355+
check_api_version(mdata)
1356+
1357+
submission = make_submission(
1358+
mdata["fp_machine"],
1359+
mdata["fp_resources"],
1360+
commands=[fp_command],
1361+
work_path=work_dir,
1362+
run_tasks=run_tasks,
1363+
group_size=fp_group_size,
1364+
forward_common_files=forward_common_files,
1365+
forward_files=forward_files,
1366+
backward_files=backward_files,
1367+
outlog="fp.log",
1368+
errlog="fp.log",
1369+
)
1370+
submission.run_submission()
13831371

13841372

13851373
def run_abacus_md(jdata, mdata):
@@ -1435,27 +1423,24 @@ def run_abacus_md(jdata, mdata):
14351423
run_tasks = [ii.replace(work_dir + "/", "") for ii in md_run_tasks]
14361424
# dlog.info("md_work_dir", work_dir)
14371425
# dlog.info("run_tasks",run_tasks)
1438-
api_version = mdata.get("api_version", "1.0")
1439-
if Version(api_version) < Version("1.0"):
1440-
raise RuntimeError(
1441-
f"API version {api_version} has been removed. Please upgrade to 1.0."
1442-
)
14431426

1444-
elif Version(api_version) >= Version("1.0"):
1445-
submission = make_submission(
1446-
mdata["fp_machine"],
1447-
mdata["fp_resources"],
1448-
commands=[fp_command],
1449-
work_path=work_dir,
1450-
run_tasks=run_tasks,
1451-
group_size=fp_group_size,
1452-
forward_common_files=forward_common_files,
1453-
forward_files=forward_files,
1454-
backward_files=backward_files,
1455-
outlog="fp.log",
1456-
errlog="fp.log",
1457-
)
1458-
submission.run_submission()
1427+
### Submit jobs
1428+
check_api_version(mdata)
1429+
1430+
submission = make_submission(
1431+
mdata["fp_machine"],
1432+
mdata["fp_resources"],
1433+
commands=[fp_command],
1434+
work_path=work_dir,
1435+
run_tasks=run_tasks,
1436+
group_size=fp_group_size,
1437+
forward_common_files=forward_common_files,
1438+
forward_files=forward_files,
1439+
backward_files=backward_files,
1440+
outlog="fp.log",
1441+
errlog="fp.log",
1442+
)
1443+
submission.run_submission()
14591444

14601445

14611446
def gen_init_bulk(args):

dpgen/dispatcher/Dispatcher.py

+15-16
Original file line numberDiff line numberDiff line change
@@ -138,21 +138,20 @@ def make_submission_compat(
138138
"""
139139
if Version(api_version) < Version("1.0"):
140140
raise RuntimeError(
141-
f"API version {api_version} has been removed. Please upgrade to 1.0."
141+
"API version below 1.0 is no longer supported. Please upgrade to version 1.0 or newer."
142142
)
143143

144-
elif Version(api_version) >= Version("1.0"):
145-
submission = make_submission(
146-
machine,
147-
resources,
148-
commands=commands,
149-
work_path=work_path,
150-
run_tasks=run_tasks,
151-
group_size=group_size,
152-
forward_common_files=forward_common_files,
153-
forward_files=forward_files,
154-
backward_files=backward_files,
155-
outlog=outlog,
156-
errlog=errlog,
157-
)
158-
submission.run_submission()
144+
submission = make_submission(
145+
machine,
146+
resources,
147+
commands=commands,
148+
work_path=work_path,
149+
run_tasks=run_tasks,
150+
group_size=group_size,
151+
forward_common_files=forward_common_files,
152+
forward_files=forward_files,
153+
backward_files=backward_files,
154+
outlog=outlog,
155+
errlog=errlog,
156+
)
157+
submission.run_submission()

dpgen/generator/lib/utils.py

+11
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
import re
77
import shutil
88

9+
from packaging.version import Version
10+
911
iter_format = "%06d"
1012
task_format = "%02d"
1113
log_iter_head = "iter " + iter_format + " task " + task_format + ": "
@@ -110,3 +112,12 @@ def symlink_user_forward_files(mdata, task_type, work_path, task_format=None):
110112
abs_file = os.path.abspath(file)
111113
os.symlink(abs_file, os.path.join(task, os.path.basename(file)))
112114
return
115+
116+
117+
def check_api_version(mdata):
118+
"""Check if the API version in mdata is at least 1.0."""
119+
if Version(mdata.get("api_version", "1.0")) < Version("1.0"):
120+
raise RuntimeError(
121+
"API version below 1.0 is no longer supported. Please upgrade to version 1.0 or newer."
122+
)
123+
return

0 commit comments

Comments
 (0)