We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 2e91ce8 commit ced2562Copy full SHA for ced2562
deepmd/tf/lmp.py
@@ -6,6 +6,9 @@
6
from importlib import (
7
import_module,
8
)
9
+from importlib.util import (
10
+ find_spec,
11
+)
12
from pathlib import (
13
Path,
14
@@ -77,6 +80,11 @@ def get_library_path(module: str, filename: str) -> List[str]:
77
80
78
81
tf_dir = tf.sysconfig.get_lib()
79
82
op_dir = str(SHARED_LIB_DIR)
83
+pt_spec = find_spec("torch")
84
+if pt_spec is not None:
85
+ pt_dir = pt_spec.submodule_search_locations[0]
86
+else:
87
+ pt_dir = None
88
89
90
cuda_library_paths = []
@@ -106,6 +114,7 @@ def get_library_path(module: str, filename: str) -> List[str]:
106
114
[
107
115
os.environ.get(lib_env),
108
116
tf_dir,
117
+ pt_dir,
109
118
os.path.join(tf_dir, "python"),
110
119
op_dir,
111
120
]
0 commit comments