Skip to content

Simplify API version validation #1556

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
May 29, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
153 changes: 69 additions & 84 deletions dpgen/data/gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

import dpdata
import numpy as np
from packaging.version import Version
from pymatgen.core import Structure
from pymatgen.io.vasp import Incar

Expand All @@ -28,7 +27,7 @@
make_abacus_scf_stru,
make_supercell_abacus,
)
from dpgen.generator.lib.utils import symlink_user_forward_files
from dpgen.generator.lib.utils import check_api_version, symlink_user_forward_files
from dpgen.generator.lib.vasp import incar_upper
from dpgen.remote.decide_machine import convert_mdata
from dpgen.util import load_file
Expand Down Expand Up @@ -1158,27 +1157,23 @@
# relax_run_tasks.append(ii)
run_tasks = [os.path.basename(ii) for ii in relax_run_tasks]

api_version = mdata.get("api_version", "1.0")
if Version(api_version) < Version("1.0"):
raise RuntimeError(
f"API version {api_version} has been removed. Please upgrade to 1.0."
)

elif Version(api_version) >= Version("1.0"):
submission = make_submission(
mdata["fp_machine"],
mdata["fp_resources"],
commands=[fp_command],
work_path=work_dir,
run_tasks=run_tasks,
group_size=fp_group_size,
forward_common_files=forward_common_files,
forward_files=forward_files,
backward_files=backward_files,
outlog="fp.log",
errlog="fp.log",
)
submission.run_submission()
### Submit the jobs
check_api_version(mdata)

Check warning on line 1161 in dpgen/data/gen.py

View check run for this annotation

Codecov / codecov/patch

dpgen/data/gen.py#L1161

Added line #L1161 was not covered by tests

submission = make_submission(

Check warning on line 1163 in dpgen/data/gen.py

View check run for this annotation

Codecov / codecov/patch

dpgen/data/gen.py#L1163

Added line #L1163 was not covered by tests
mdata["fp_machine"],
mdata["fp_resources"],
commands=[fp_command],
work_path=work_dir,
run_tasks=run_tasks,
group_size=fp_group_size,
forward_common_files=forward_common_files,
forward_files=forward_files,
backward_files=backward_files,
outlog="fp.log",
errlog="fp.log",
)
submission.run_submission()

Check warning on line 1176 in dpgen/data/gen.py

View check run for this annotation

Codecov / codecov/patch

dpgen/data/gen.py#L1176

Added line #L1176 was not covered by tests


def coll_abacus_md(jdata):
Expand Down Expand Up @@ -1298,27 +1293,23 @@
# relax_run_tasks.append(ii)
run_tasks = [os.path.basename(ii) for ii in relax_run_tasks]

api_version = mdata.get("api_version", "1.0")
if Version(api_version) < Version("1.0"):
raise RuntimeError(
f"API version {api_version} has been removed. Please upgrade to 1.0."
)

elif Version(api_version) >= Version("1.0"):
submission = make_submission(
mdata["fp_machine"],
mdata["fp_resources"],
commands=[fp_command],
work_path=work_dir,
run_tasks=run_tasks,
group_size=fp_group_size,
forward_common_files=forward_common_files,
forward_files=forward_files,
backward_files=backward_files,
outlog="fp.log",
errlog="fp.log",
)
submission.run_submission()
### Submit the jobs
check_api_version(mdata)

Check warning on line 1297 in dpgen/data/gen.py

View check run for this annotation

Codecov / codecov/patch

dpgen/data/gen.py#L1297

Added line #L1297 was not covered by tests

submission = make_submission(

Check warning on line 1299 in dpgen/data/gen.py

View check run for this annotation

Codecov / codecov/patch

dpgen/data/gen.py#L1299

Added line #L1299 was not covered by tests
mdata["fp_machine"],
mdata["fp_resources"],
commands=[fp_command],
work_path=work_dir,
run_tasks=run_tasks,
group_size=fp_group_size,
forward_common_files=forward_common_files,
forward_files=forward_files,
backward_files=backward_files,
outlog="fp.log",
errlog="fp.log",
)
submission.run_submission()

Check warning on line 1312 in dpgen/data/gen.py

View check run for this annotation

Codecov / codecov/patch

dpgen/data/gen.py#L1312

Added line #L1312 was not covered by tests


def run_vasp_md(jdata, mdata):
Expand Down Expand Up @@ -1359,27 +1350,24 @@
run_tasks = [ii.replace(work_dir + "/", "") for ii in md_run_tasks]
# dlog.info("md_work_dir", work_dir)
# dlog.info("run_tasks",run_tasks)
api_version = mdata.get("api_version", "1.0")
if Version(api_version) < Version("1.0"):
raise RuntimeError(
f"API version {api_version} has been removed. Please upgrade to 1.0."
)

elif Version(api_version) >= Version("1.0"):
submission = make_submission(
mdata["fp_machine"],
mdata["fp_resources"],
commands=[fp_command],
work_path=work_dir,
run_tasks=run_tasks,
group_size=fp_group_size,
forward_common_files=forward_common_files,
forward_files=forward_files,
backward_files=backward_files,
outlog="fp.log",
errlog="fp.log",
)
submission.run_submission()
### Submit the jobs
check_api_version(mdata)

Check warning on line 1355 in dpgen/data/gen.py

View check run for this annotation

Codecov / codecov/patch

dpgen/data/gen.py#L1355

Added line #L1355 was not covered by tests

submission = make_submission(

Check warning on line 1357 in dpgen/data/gen.py

View check run for this annotation

Codecov / codecov/patch

dpgen/data/gen.py#L1357

Added line #L1357 was not covered by tests
mdata["fp_machine"],
mdata["fp_resources"],
commands=[fp_command],
work_path=work_dir,
run_tasks=run_tasks,
group_size=fp_group_size,
forward_common_files=forward_common_files,
forward_files=forward_files,
backward_files=backward_files,
outlog="fp.log",
errlog="fp.log",
)
submission.run_submission()

Check warning on line 1370 in dpgen/data/gen.py

View check run for this annotation

Codecov / codecov/patch

dpgen/data/gen.py#L1370

Added line #L1370 was not covered by tests


def run_abacus_md(jdata, mdata):
Expand Down Expand Up @@ -1435,27 +1423,24 @@
run_tasks = [ii.replace(work_dir + "/", "") for ii in md_run_tasks]
# dlog.info("md_work_dir", work_dir)
# dlog.info("run_tasks",run_tasks)
api_version = mdata.get("api_version", "1.0")
if Version(api_version) < Version("1.0"):
raise RuntimeError(
f"API version {api_version} has been removed. Please upgrade to 1.0."
)

elif Version(api_version) >= Version("1.0"):
submission = make_submission(
mdata["fp_machine"],
mdata["fp_resources"],
commands=[fp_command],
work_path=work_dir,
run_tasks=run_tasks,
group_size=fp_group_size,
forward_common_files=forward_common_files,
forward_files=forward_files,
backward_files=backward_files,
outlog="fp.log",
errlog="fp.log",
)
submission.run_submission()
### Submit the jobs
check_api_version(mdata)

Check warning on line 1428 in dpgen/data/gen.py

View check run for this annotation

Codecov / codecov/patch

dpgen/data/gen.py#L1428

Added line #L1428 was not covered by tests

submission = make_submission(

Check warning on line 1430 in dpgen/data/gen.py

View check run for this annotation

Codecov / codecov/patch

dpgen/data/gen.py#L1430

Added line #L1430 was not covered by tests
mdata["fp_machine"],
mdata["fp_resources"],
commands=[fp_command],
work_path=work_dir,
run_tasks=run_tasks,
group_size=fp_group_size,
forward_common_files=forward_common_files,
forward_files=forward_files,
backward_files=backward_files,
outlog="fp.log",
errlog="fp.log",
)
submission.run_submission()

Check warning on line 1443 in dpgen/data/gen.py

View check run for this annotation

Codecov / codecov/patch

dpgen/data/gen.py#L1443

Added line #L1443 was not covered by tests


def gen_init_bulk(args):
Expand Down
10 changes: 10 additions & 0 deletions dpgen/generator/lib/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import re
import shutil

from packaging.version import Version

iter_format = "%06d"
task_format = "%02d"
log_iter_head = "iter " + iter_format + " task " + task_format + ": "
Expand Down Expand Up @@ -110,3 +112,11 @@
abs_file = os.path.abspath(file)
os.symlink(abs_file, os.path.join(task, os.path.basename(file)))
return


def check_api_version(mdata):
if Version(mdata.get("api_version", "1.0")) < Version("1.0"):
raise RuntimeError(

Check warning on line 119 in dpgen/generator/lib/utils.py

View check run for this annotation

Codecov / codecov/patch

dpgen/generator/lib/utils.py#L119

Added line #L119 was not covered by tests
"API version below 1.0 is no longer supported. Please upgrade to version 1.0 or newer."
)
return
114 changes: 52 additions & 62 deletions dpgen/generator/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@
)
from dpgen.generator.lib.siesta import make_siesta_input
from dpgen.generator.lib.utils import (
check_api_version,
create_path,
log_iter,
log_task,
Expand Down Expand Up @@ -874,31 +875,27 @@
except Exception:
train_group_size = 1

api_version = mdata.get("api_version", "1.0")

user_forward_files = mdata.get("train" + "_user_forward_files", [])
forward_files += [os.path.basename(file) for file in user_forward_files]
backward_files += mdata.get("train" + "_user_backward_files", [])
if Version(api_version) < Version("1.0"):
raise RuntimeError(
f"API version {api_version} has been removed. Please upgrade to 1.0."
)

elif Version(api_version) >= Version("1.0"):
submission = make_submission(
mdata["train_machine"],
mdata["train_resources"],
commands=commands,
work_path=work_path,
run_tasks=run_tasks,
group_size=train_group_size,
forward_common_files=trans_comm_data,
forward_files=forward_files,
backward_files=backward_files,
outlog="train.log",
errlog="train.log",
)
submission.run_submission()
### Submit the jobs
check_api_version(mdata)

submission = make_submission(
mdata["train_machine"],
mdata["train_resources"],
commands=commands,
work_path=work_path,
run_tasks=run_tasks,
group_size=train_group_size,
forward_common_files=trans_comm_data,
forward_files=forward_files,
backward_files=backward_files,
outlog="train.log",
errlog="train.log",
)
submission.run_submission()


def post_train(iter_index, jdata, mdata):
Expand Down Expand Up @@ -2090,31 +2087,28 @@
user_forward_files = mdata.get("model_devi" + "_user_forward_files", [])
forward_files += [os.path.basename(file) for file in user_forward_files]
backward_files += mdata.get("model_devi" + "_user_backward_files", [])
api_version = mdata.get("api_version", "1.0")
if len(run_tasks) == 0:
raise RuntimeError(
"run_tasks for model_devi should not be empty! Please check your files."
)
if Version(api_version) < Version("1.0"):
raise RuntimeError(
f"API version {api_version} has been removed. Please upgrade to 1.0."
)

elif Version(api_version) >= Version("1.0"):
submission = make_submission(
mdata["model_devi_machine"],
mdata["model_devi_resources"],
commands=commands,
work_path=work_path,
run_tasks=run_tasks,
group_size=model_devi_group_size,
forward_common_files=model_names,
forward_files=forward_files,
backward_files=backward_files,
outlog="model_devi.log",
errlog="model_devi.log",
)
submission.run_submission()
### Submit the jobs
check_api_version(mdata)

submission = make_submission(
mdata["model_devi_machine"],
mdata["model_devi_resources"],
commands=commands,
work_path=work_path,
run_tasks=run_tasks,
group_size=model_devi_group_size,
forward_common_files=model_names,
forward_files=forward_files,
backward_files=backward_files,
outlog="model_devi.log",
errlog="model_devi.log",
)
submission.run_submission()


def run_model_devi(iter_index, jdata, mdata):
Expand Down Expand Up @@ -3964,27 +3958,23 @@
forward_files += [os.path.basename(file) for file in user_forward_files]
backward_files += mdata.get("fp" + "_user_backward_files", [])

api_version = mdata.get("api_version", "1.0")
if Version(api_version) < Version("1.0"):
raise RuntimeError(
f"API version {api_version} has been removed. Please upgrade to 1.0."
)

elif Version(api_version) >= Version("1.0"):
submission = make_submission(
mdata["fp_machine"],
mdata["fp_resources"],
commands=[fp_command],
work_path=work_path,
run_tasks=run_tasks,
group_size=fp_group_size,
forward_common_files=forward_common_files,
forward_files=forward_files,
backward_files=backward_files,
outlog=log_file,
errlog=log_file,
)
submission.run_submission()
### Submit the jobs
check_api_version(mdata)

Check warning on line 3962 in dpgen/generator/run.py

View check run for this annotation

Codecov / codecov/patch

dpgen/generator/run.py#L3962

Added line #L3962 was not covered by tests

submission = make_submission(

Check warning on line 3964 in dpgen/generator/run.py

View check run for this annotation

Codecov / codecov/patch

dpgen/generator/run.py#L3964

Added line #L3964 was not covered by tests
mdata["fp_machine"],
mdata["fp_resources"],
commands=[fp_command],
work_path=work_path,
run_tasks=run_tasks,
group_size=fp_group_size,
forward_common_files=forward_common_files,
forward_files=forward_files,
backward_files=backward_files,
outlog=log_file,
errlog=log_file,
)
submission.run_submission()

Check warning on line 3977 in dpgen/generator/run.py

View check run for this annotation

Codecov / codecov/patch

dpgen/generator/run.py#L3977

Added line #L3977 was not covered by tests


def run_fp(iter_index, jdata, mdata):
Expand Down
Loading