Skip to content

feat(pt): add more information to summary and error message of loading library #3895

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jun 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 57 additions & 1 deletion deepmd/pt/cxx_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,12 @@
import platform

import torch
from packaging.version import (
Version,
)

from deepmd.env import (
GLOBAL_CONFIG,
SHARED_LIB_DIR,
)

Expand Down Expand Up @@ -31,7 +35,59 @@
module_file = (SHARED_LIB_DIR / (prefix + module_name)).with_suffix(ext).resolve()

if module_file.is_file():
torch.ops.load_library(module_file)
try:
torch.ops.load_library(module_file)
except OSError as e:

Check warning on line 40 in deepmd/pt/cxx_op.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/cxx_op.py#L40

Added line #L40 was not covered by tests
# check: CXX11_ABI_FLAG; version
# from our op
PT_VERSION = GLOBAL_CONFIG["pt_version"]
PT_CXX11_ABI_FLAG = int(GLOBAL_CONFIG["pt_cxx11_abi_flag"])

Check warning on line 44 in deepmd/pt/cxx_op.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/cxx_op.py#L43-L44

Added lines #L43 - L44 were not covered by tests
# from torch
# strip the local version
pt_py_version = Version(torch.__version__).public
pt_cxx11_abi_flag = int(torch.compiled_with_cxx11_abi())

Check warning on line 48 in deepmd/pt/cxx_op.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/cxx_op.py#L47-L48

Added lines #L47 - L48 were not covered by tests

if PT_CXX11_ABI_FLAG != pt_cxx11_abi_flag:
raise RuntimeError(

Check warning on line 51 in deepmd/pt/cxx_op.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/cxx_op.py#L50-L51

Added lines #L50 - L51 were not covered by tests
"This deepmd-kit package was compiled with "
"CXX11_ABI_FLAG=%d, but PyTorch runtime was compiled "
"with CXX11_ABI_FLAG=%d. These two library ABIs are "
"incompatible and thus an error is raised when loading %s. "
"You need to rebuild deepmd-kit against this PyTorch "
"runtime."
% (
PT_CXX11_ABI_FLAG,
pt_cxx11_abi_flag,
module_name,
)
) from e

# different versions may cause incompatibility, see TF
if PT_VERSION != pt_py_version:
raise RuntimeError(

Check warning on line 67 in deepmd/pt/cxx_op.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/cxx_op.py#L66-L67

Added lines #L66 - L67 were not covered by tests
"The version of PyTorch used to compile this "
f"deepmd-kit package is {PT_VERSION}, but the version of PyTorch "
f"runtime you are using is {pt_py_version}. These two versions are "
f"incompatible and thus an error is raised when loading {module_name}. "
f"You need to install PyTorch {PT_VERSION}, or rebuild deepmd-kit "
f"against PyTorch {pt_py_version}.\nIf you are using a wheel from "
"PyPI, you may consider to install deepmd-kit execuating "
"`DP_ENABLE_PYTORCH=1 pip install deepmd-kit --no-binary deepmd-kit` "
"instead."
) from e
error_message = (

Check warning on line 78 in deepmd/pt/cxx_op.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/cxx_op.py#L78

Added line #L78 was not covered by tests
"This deepmd-kit package is inconsitent with PyTorch "
f"Runtime, thus an error is raised when loading {module_name}. "
"You need to rebuild deepmd-kit against this PyTorch "
"runtime."
)
if PT_CXX11_ABI_FLAG == 1:

Check warning on line 84 in deepmd/pt/cxx_op.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/cxx_op.py#L84

Added line #L84 was not covered by tests
# #1791
error_message += (

Check warning on line 86 in deepmd/pt/cxx_op.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/cxx_op.py#L86

Added line #L86 was not covered by tests
"\nWARNING: devtoolset on RHEL6 and RHEL7 does not support _GLIBCXX_USE_CXX11_ABI=1. "
"See https://bugzilla.redhat.com/show_bug.cgi?id=1546704"
)
raise RuntimeError(error_message) from e

Check warning on line 90 in deepmd/pt/cxx_op.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/cxx_op.py#L90

Added line #L90 was not covered by tests
return True
return False

Expand Down
12 changes: 12 additions & 0 deletions deepmd/pt/entrypoints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@
from deepmd import (
__version__,
)
from deepmd.env import (
GLOBAL_CONFIG,
)
from deepmd.loggers.loggers import (
set_log_handles,
)
Expand Down Expand Up @@ -199,10 +202,19 @@

def get_backend_info(self) -> dict:
"""Get backend information."""
if ENABLE_CUSTOMIZED_OP:
op_info = {
"build with PT ver": GLOBAL_CONFIG["pt_version"],
"build with PT inc": GLOBAL_CONFIG["pt_include_dir"].replace(";", "\n"),
"build with PT lib": GLOBAL_CONFIG["pt_libs"].replace(";", "\n"),
}
else:
op_info = None

Check warning on line 212 in deepmd/pt/entrypoints/main.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/entrypoints/main.py#L212

Added line #L212 was not covered by tests
return {
"Backend": "PyTorch",
"PT ver": f"v{torch.__version__}-g{torch.version.git_version[:11]}",
"Enable custom OP": ENABLE_CUSTOMIZED_OP,
**op_info,
}


Expand Down
4 changes: 4 additions & 0 deletions source/config/run_config.ini
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,9 @@ TF_INCLUDE_DIR = @TensorFlow_INCLUDE_DIRS@
TF_LIBS = @TensorFlow_LIBRARY_PATH@
TF_VERSION = @TENSORFLOW_VERSION@
TF_CXX11_ABI_FLAG = @OP_CXX_ABI@
PT_INCLUDE_DIR = @TORCH_INCLUDE_DIRS@
PT_LIBS = @PyTorch_LIBRARY_PATH@
PT_VERSIOn = @Torch_VERSION@
PT_CXX11_ABI_FLAG = @OP_CXX_ABI_PT@
MODEL_VERSION=@MODEL_VERSION@
DP_VARIANT=@DP_VARIANT@