Skip to content

Simplify API version validation #1556

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
May 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
153 changes: 69 additions & 84 deletions dpgen/data/gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

import dpdata
import numpy as np
from packaging.version import Version
from pymatgen.core import Structure
from pymatgen.io.vasp import Incar

Expand All @@ -28,7 +27,7 @@
make_abacus_scf_stru,
make_supercell_abacus,
)
from dpgen.generator.lib.utils import symlink_user_forward_files
from dpgen.generator.lib.utils import check_api_version, symlink_user_forward_files
from dpgen.generator.lib.vasp import incar_upper
from dpgen.remote.decide_machine import convert_mdata
from dpgen.util import load_file
Expand Down Expand Up @@ -1158,27 +1157,23 @@
# relax_run_tasks.append(ii)
run_tasks = [os.path.basename(ii) for ii in relax_run_tasks]

api_version = mdata.get("api_version", "1.0")
if Version(api_version) < Version("1.0"):
raise RuntimeError(
f"API version {api_version} has been removed. Please upgrade to 1.0."
)

elif Version(api_version) >= Version("1.0"):
submission = make_submission(
mdata["fp_machine"],
mdata["fp_resources"],
commands=[fp_command],
work_path=work_dir,
run_tasks=run_tasks,
group_size=fp_group_size,
forward_common_files=forward_common_files,
forward_files=forward_files,
backward_files=backward_files,
outlog="fp.log",
errlog="fp.log",
)
submission.run_submission()
### Submit jobs
check_api_version(mdata)

Check warning on line 1161 in dpgen/data/gen.py

View check run for this annotation

Codecov / codecov/patch

dpgen/data/gen.py#L1161

Added line #L1161 was not covered by tests

submission = make_submission(

Check warning on line 1163 in dpgen/data/gen.py

View check run for this annotation

Codecov / codecov/patch

dpgen/data/gen.py#L1163

Added line #L1163 was not covered by tests
mdata["fp_machine"],
mdata["fp_resources"],
commands=[fp_command],
work_path=work_dir,
run_tasks=run_tasks,
group_size=fp_group_size,
forward_common_files=forward_common_files,
forward_files=forward_files,
backward_files=backward_files,
outlog="fp.log",
errlog="fp.log",
)
submission.run_submission()

Check warning on line 1176 in dpgen/data/gen.py

View check run for this annotation

Codecov / codecov/patch

dpgen/data/gen.py#L1176

Added line #L1176 was not covered by tests


def coll_abacus_md(jdata):
Expand Down Expand Up @@ -1298,27 +1293,23 @@
# relax_run_tasks.append(ii)
run_tasks = [os.path.basename(ii) for ii in relax_run_tasks]

api_version = mdata.get("api_version", "1.0")
if Version(api_version) < Version("1.0"):
raise RuntimeError(
f"API version {api_version} has been removed. Please upgrade to 1.0."
)

elif Version(api_version) >= Version("1.0"):
submission = make_submission(
mdata["fp_machine"],
mdata["fp_resources"],
commands=[fp_command],
work_path=work_dir,
run_tasks=run_tasks,
group_size=fp_group_size,
forward_common_files=forward_common_files,
forward_files=forward_files,
backward_files=backward_files,
outlog="fp.log",
errlog="fp.log",
)
submission.run_submission()
### Submit jobs
check_api_version(mdata)

Check warning on line 1297 in dpgen/data/gen.py

View check run for this annotation

Codecov / codecov/patch

dpgen/data/gen.py#L1297

Added line #L1297 was not covered by tests

submission = make_submission(

Check warning on line 1299 in dpgen/data/gen.py

View check run for this annotation

Codecov / codecov/patch

dpgen/data/gen.py#L1299

Added line #L1299 was not covered by tests
mdata["fp_machine"],
mdata["fp_resources"],
commands=[fp_command],
work_path=work_dir,
run_tasks=run_tasks,
group_size=fp_group_size,
forward_common_files=forward_common_files,
forward_files=forward_files,
backward_files=backward_files,
outlog="fp.log",
errlog="fp.log",
)
submission.run_submission()

Check warning on line 1312 in dpgen/data/gen.py

View check run for this annotation

Codecov / codecov/patch

dpgen/data/gen.py#L1312

Added line #L1312 was not covered by tests


def run_vasp_md(jdata, mdata):
Expand Down Expand Up @@ -1359,27 +1350,24 @@
run_tasks = [ii.replace(work_dir + "/", "") for ii in md_run_tasks]
# dlog.info("md_work_dir", work_dir)
# dlog.info("run_tasks",run_tasks)
api_version = mdata.get("api_version", "1.0")
if Version(api_version) < Version("1.0"):
raise RuntimeError(
f"API version {api_version} has been removed. Please upgrade to 1.0."
)

elif Version(api_version) >= Version("1.0"):
submission = make_submission(
mdata["fp_machine"],
mdata["fp_resources"],
commands=[fp_command],
work_path=work_dir,
run_tasks=run_tasks,
group_size=fp_group_size,
forward_common_files=forward_common_files,
forward_files=forward_files,
backward_files=backward_files,
outlog="fp.log",
errlog="fp.log",
)
submission.run_submission()
### Submit jobs
check_api_version(mdata)

Check warning on line 1355 in dpgen/data/gen.py

View check run for this annotation

Codecov / codecov/patch

dpgen/data/gen.py#L1355

Added line #L1355 was not covered by tests

submission = make_submission(

Check warning on line 1357 in dpgen/data/gen.py

View check run for this annotation

Codecov / codecov/patch

dpgen/data/gen.py#L1357

Added line #L1357 was not covered by tests
mdata["fp_machine"],
mdata["fp_resources"],
commands=[fp_command],
work_path=work_dir,
run_tasks=run_tasks,
group_size=fp_group_size,
forward_common_files=forward_common_files,
forward_files=forward_files,
backward_files=backward_files,
outlog="fp.log",
errlog="fp.log",
)
submission.run_submission()

Check warning on line 1370 in dpgen/data/gen.py

View check run for this annotation

Codecov / codecov/patch

dpgen/data/gen.py#L1370

Added line #L1370 was not covered by tests


def run_abacus_md(jdata, mdata):
Expand Down Expand Up @@ -1435,27 +1423,24 @@
run_tasks = [ii.replace(work_dir + "/", "") for ii in md_run_tasks]
# dlog.info("md_work_dir", work_dir)
# dlog.info("run_tasks",run_tasks)
api_version = mdata.get("api_version", "1.0")
if Version(api_version) < Version("1.0"):
raise RuntimeError(
f"API version {api_version} has been removed. Please upgrade to 1.0."
)

elif Version(api_version) >= Version("1.0"):
submission = make_submission(
mdata["fp_machine"],
mdata["fp_resources"],
commands=[fp_command],
work_path=work_dir,
run_tasks=run_tasks,
group_size=fp_group_size,
forward_common_files=forward_common_files,
forward_files=forward_files,
backward_files=backward_files,
outlog="fp.log",
errlog="fp.log",
)
submission.run_submission()
### Submit jobs
check_api_version(mdata)

Check warning on line 1428 in dpgen/data/gen.py

View check run for this annotation

Codecov / codecov/patch

dpgen/data/gen.py#L1428

Added line #L1428 was not covered by tests

submission = make_submission(

Check warning on line 1430 in dpgen/data/gen.py

View check run for this annotation

Codecov / codecov/patch

dpgen/data/gen.py#L1430

Added line #L1430 was not covered by tests
mdata["fp_machine"],
mdata["fp_resources"],
commands=[fp_command],
work_path=work_dir,
run_tasks=run_tasks,
group_size=fp_group_size,
forward_common_files=forward_common_files,
forward_files=forward_files,
backward_files=backward_files,
outlog="fp.log",
errlog="fp.log",
)
submission.run_submission()

Check warning on line 1443 in dpgen/data/gen.py

View check run for this annotation

Codecov / codecov/patch

dpgen/data/gen.py#L1443

Added line #L1443 was not covered by tests


def gen_init_bulk(args):
Expand Down
31 changes: 15 additions & 16 deletions dpgen/dispatcher/Dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,21 +138,20 @@
"""
if Version(api_version) < Version("1.0"):
raise RuntimeError(
f"API version {api_version} has been removed. Please upgrade to 1.0."
"API version below 1.0 is no longer supported. Please upgrade to version 1.0 or newer."
)

elif Version(api_version) >= Version("1.0"):
submission = make_submission(
machine,
resources,
commands=commands,
work_path=work_path,
run_tasks=run_tasks,
group_size=group_size,
forward_common_files=forward_common_files,
forward_files=forward_files,
backward_files=backward_files,
outlog=outlog,
errlog=errlog,
)
submission.run_submission()
submission = make_submission(

Check warning on line 144 in dpgen/dispatcher/Dispatcher.py

View check run for this annotation

Codecov / codecov/patch

dpgen/dispatcher/Dispatcher.py#L144

Added line #L144 was not covered by tests
machine,
resources,
commands=commands,
work_path=work_path,
run_tasks=run_tasks,
group_size=group_size,
forward_common_files=forward_common_files,
forward_files=forward_files,
backward_files=backward_files,
outlog=outlog,
errlog=errlog,
)
submission.run_submission()

Check warning on line 157 in dpgen/dispatcher/Dispatcher.py

View check run for this annotation

Codecov / codecov/patch

dpgen/dispatcher/Dispatcher.py#L157

Added line #L157 was not covered by tests
11 changes: 11 additions & 0 deletions dpgen/generator/lib/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import re
import shutil

from packaging.version import Version

iter_format = "%06d"
task_format = "%02d"
log_iter_head = "iter " + iter_format + " task " + task_format + ": "
Expand Down Expand Up @@ -110,3 +112,12 @@
abs_file = os.path.abspath(file)
os.symlink(abs_file, os.path.join(task, os.path.basename(file)))
return


def check_api_version(mdata):
"""Check if the API version in mdata is at least 1.0."""
if Version(mdata.get("api_version", "1.0")) < Version("1.0"):
raise RuntimeError(

Check warning on line 120 in dpgen/generator/lib/utils.py

View check run for this annotation

Codecov / codecov/patch

dpgen/generator/lib/utils.py#L120

Added line #L120 was not covered by tests
"API version below 1.0 is no longer supported. Please upgrade to version 1.0 or newer."
)
return
Loading