Skip to content

Linux Aarch64 build fix #354

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jan 31, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ jobs:
pip -vv install .
fi
env:
CPU_ONLY: 1
WITH_CUDA: "0"

- name: Lint with flake8
run: |
Expand All @@ -89,7 +89,7 @@ jobs:
- name: Run tests
run: pytest -v -s --durations=10
env:
CPU_ONLY: 1
WITH_CUDA: "0"
SKIP_TORCH_COMPILE: ${{ runner.os == 'Windows' && 'true' || 'false' }}
OMP_PREFIX: ${{ runner.os == 'macOS' && '/Users/runner/miniconda3/envs/test' || '' }}
CPU_TRAIN: ${{ runner.os == 'macOS' && 'true' || 'false' }}
Expand Down
57 changes: 41 additions & 16 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,26 +5,36 @@
import subprocess
from setuptools import setup, find_packages
import torch
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, include_paths, CppExtension
from torch.utils.cpp_extension import (
BuildExtension,
CUDAExtension,
include_paths,
CppExtension,
)
import os
import sys

is_windows = sys.platform == 'win32'
is_windows = sys.platform == "win32"

try:
version = (
subprocess.check_output(["git", "describe", "--abbrev=0", "--tags"])
.strip()
.decode("utf-8")
)
except:
except Exception:
print("Failed to retrieve the current version, defaulting to 0")
version = "0"
# If CPU_ONLY is defined
force_cpu_only = os.environ.get("CPU_ONLY", None) is not None
use_cuda = torch.cuda._is_compiled() if not force_cpu_only else False

# If WITH_CUDA is defined
if os.environ.get("WITH_CUDA", "0") == "1":
use_cuda = True
else:
use_cuda = torch.cuda._is_compiled()


def set_torch_cuda_arch_list():
""" Set the CUDA arch list according to the architectures the current torch installation was compiled for.
"""Set the CUDA arch list according to the architectures the current torch installation was compiled for.
This function is a no-op if the environment variable TORCH_CUDA_ARCH_LIST is already set or if torch was not compiled with CUDA support.
"""
if not os.environ.get("TORCH_CUDA_ARCH_LIST"):
Expand All @@ -35,20 +45,24 @@ def set_torch_cuda_arch_list():
formatted_versions += "+PTX"
os.environ["TORCH_CUDA_ARCH_LIST"] = formatted_versions


set_torch_cuda_arch_list()

extension_root= os.path.join("torchmdnet", "extensions")
neighbor_sources=["neighbors_cpu.cpp"]
extension_root = os.path.join("torchmdnet", "extensions")
neighbor_sources = ["neighbors_cpu.cpp"]
if use_cuda:
neighbor_sources.append("neighbors_cuda.cu")
neighbor_sources = [os.path.join(extension_root, "neighbors", source) for source in neighbor_sources]
neighbor_sources = [
os.path.join(extension_root, "neighbors", source) for source in neighbor_sources
]

ExtensionType = CppExtension if not use_cuda else CUDAExtension
extensions = ExtensionType(
name='torchmdnet.extensions.torchmdnet_extensions',
sources=[os.path.join(extension_root, "torchmdnet_extensions.cpp")] + neighbor_sources,
name="torchmdnet.extensions.torchmdnet_extensions",
sources=[os.path.join(extension_root, "torchmdnet_extensions.cpp")]
+ neighbor_sources,
include_dirs=include_paths(),
define_macros=[('WITH_CUDA', 1)] if use_cuda else [],
define_macros=[("WITH_CUDA", 1)] if use_cuda else [],
)

if __name__ == "__main__":
Expand All @@ -58,8 +72,19 @@ def set_torch_cuda_arch_list():
packages=find_packages(),
ext_modules=[extensions],
cmdclass={
'build_ext': BuildExtension.with_options(no_python_abi_suffix=True, use_ninja=False)},
"build_ext": BuildExtension.with_options(
no_python_abi_suffix=True, use_ninja=False
)
},
include_package_data=True,
entry_points={"console_scripts": ["torchmd-train = torchmdnet.scripts.train:main"]},
package_data={"torchmdnet": ["extensions/torchmdnet_extensions.so"] if not is_windows else ["extensions/torchmdnet_extensions.dll"]},
entry_points={
"console_scripts": ["torchmd-train = torchmdnet.scripts.train:main"]
},
package_data={
"torchmdnet": (
["extensions/torchmdnet_extensions.so"]
if not is_windows
else ["extensions/torchmdnet_extensions.dll"]
)
},
)
Loading