Skip to content

Commit 1deecd1

Browse files
authored
Merge pull request #354 from torchmd/osx_arm64_build_fix
Linux Aarch64 build fix
2 parents 8374e96 + 35dbb90 commit 1deecd1

File tree

2 files changed

+43
-18
lines changed

2 files changed

+43
-18
lines changed

.github/workflows/CI.yml

+2-2
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ jobs:
7777
pip -vv install .
7878
fi
7979
env:
80-
CPU_ONLY: 1
80+
WITH_CUDA: "0"
8181

8282
- name: Lint with flake8
8383
run: |
@@ -89,7 +89,7 @@ jobs:
8989
- name: Run tests
9090
run: pytest -v -s --durations=10
9191
env:
92-
CPU_ONLY: 1
92+
WITH_CUDA: "0"
9393
SKIP_TORCH_COMPILE: ${{ runner.os == 'Windows' && 'true' || 'false' }}
9494
OMP_PREFIX: ${{ runner.os == 'macOS' && '/Users/runner/miniconda3/envs/test' || '' }}
9595
CPU_TRAIN: ${{ runner.os == 'macOS' && 'true' || 'false' }}

setup.py

+41-16
Original file line numberDiff line numberDiff line change
@@ -5,26 +5,36 @@
55
import subprocess
66
from setuptools import setup, find_packages
77
import torch
8-
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, include_paths, CppExtension
8+
from torch.utils.cpp_extension import (
9+
BuildExtension,
10+
CUDAExtension,
11+
include_paths,
12+
CppExtension,
13+
)
914
import os
1015
import sys
1116

12-
is_windows = sys.platform == 'win32'
17+
is_windows = sys.platform == "win32"
1318

1419
try:
1520
version = (
1621
subprocess.check_output(["git", "describe", "--abbrev=0", "--tags"])
1722
.strip()
1823
.decode("utf-8")
1924
)
20-
except:
25+
except Exception:
2126
print("Failed to retrieve the current version, defaulting to 0")
2227
version = "0"
23-
# If CPU_ONLY is defined
24-
force_cpu_only = os.environ.get("CPU_ONLY", None) is not None
25-
use_cuda = torch.cuda._is_compiled() if not force_cpu_only else False
28+
29+
# If WITH_CUDA is defined
30+
if os.environ.get("WITH_CUDA", "0") == "1":
31+
use_cuda = True
32+
else:
33+
use_cuda = torch.cuda._is_compiled()
34+
35+
2636
def set_torch_cuda_arch_list():
27-
""" Set the CUDA arch list according to the architectures the current torch installation was compiled for.
37+
"""Set the CUDA arch list according to the architectures the current torch installation was compiled for.
2838
This function is a no-op if the environment variable TORCH_CUDA_ARCH_LIST is already set or if torch was not compiled with CUDA support.
2939
"""
3040
if not os.environ.get("TORCH_CUDA_ARCH_LIST"):
@@ -35,20 +45,24 @@ def set_torch_cuda_arch_list():
3545
formatted_versions += "+PTX"
3646
os.environ["TORCH_CUDA_ARCH_LIST"] = formatted_versions
3747

48+
3849
set_torch_cuda_arch_list()
3950

40-
extension_root= os.path.join("torchmdnet", "extensions")
41-
neighbor_sources=["neighbors_cpu.cpp"]
51+
extension_root = os.path.join("torchmdnet", "extensions")
52+
neighbor_sources = ["neighbors_cpu.cpp"]
4253
if use_cuda:
4354
neighbor_sources.append("neighbors_cuda.cu")
44-
neighbor_sources = [os.path.join(extension_root, "neighbors", source) for source in neighbor_sources]
55+
neighbor_sources = [
56+
os.path.join(extension_root, "neighbors", source) for source in neighbor_sources
57+
]
4558

4659
ExtensionType = CppExtension if not use_cuda else CUDAExtension
4760
extensions = ExtensionType(
48-
name='torchmdnet.extensions.torchmdnet_extensions',
49-
sources=[os.path.join(extension_root, "torchmdnet_extensions.cpp")] + neighbor_sources,
61+
name="torchmdnet.extensions.torchmdnet_extensions",
62+
sources=[os.path.join(extension_root, "torchmdnet_extensions.cpp")]
63+
+ neighbor_sources,
5064
include_dirs=include_paths(),
51-
define_macros=[('WITH_CUDA', 1)] if use_cuda else [],
65+
define_macros=[("WITH_CUDA", 1)] if use_cuda else [],
5266
)
5367

5468
if __name__ == "__main__":
@@ -58,8 +72,19 @@ def set_torch_cuda_arch_list():
5872
packages=find_packages(),
5973
ext_modules=[extensions],
6074
cmdclass={
61-
'build_ext': BuildExtension.with_options(no_python_abi_suffix=True, use_ninja=False)},
75+
"build_ext": BuildExtension.with_options(
76+
no_python_abi_suffix=True, use_ninja=False
77+
)
78+
},
6279
include_package_data=True,
63-
entry_points={"console_scripts": ["torchmd-train = torchmdnet.scripts.train:main"]},
64-
package_data={"torchmdnet": ["extensions/torchmdnet_extensions.so"] if not is_windows else ["extensions/torchmdnet_extensions.dll"]},
80+
entry_points={
81+
"console_scripts": ["torchmd-train = torchmdnet.scripts.train:main"]
82+
},
83+
package_data={
84+
"torchmdnet": (
85+
["extensions/torchmdnet_extensions.so"]
86+
if not is_windows
87+
else ["extensions/torchmdnet_extensions.dll"]
88+
)
89+
},
6590
)

0 commit comments

Comments
 (0)