5
5
import subprocess
6
6
from setuptools import setup , find_packages
7
7
import torch
8
- from torch .utils .cpp_extension import BuildExtension , CUDAExtension , include_paths , CppExtension
8
+ from torch .utils .cpp_extension import (
9
+ BuildExtension ,
10
+ CUDAExtension ,
11
+ include_paths ,
12
+ CppExtension ,
13
+ )
9
14
import os
10
15
import sys
11
16
12
- is_windows = sys .platform == ' win32'
17
+ is_windows = sys .platform == " win32"
13
18
14
19
try :
15
20
version = (
16
21
subprocess .check_output (["git" , "describe" , "--abbrev=0" , "--tags" ])
17
22
.strip ()
18
23
.decode ("utf-8" )
19
24
)
20
- except :
25
+ except Exception :
21
26
print ("Failed to retrieve the current version, defaulting to 0" )
22
27
version = "0"
23
- # If CPU_ONLY is defined
24
- force_cpu_only = os .environ .get ("CPU_ONLY" , None ) is not None
25
- use_cuda = torch .cuda ._is_compiled () if not force_cpu_only else False
28
+
29
+ # If WITH_CUDA is defined
30
+ if os .environ .get ("WITH_CUDA" , "0" ) == "1" :
31
+ use_cuda = True
32
+ else :
33
+ use_cuda = torch .cuda ._is_compiled ()
34
+
35
+
26
36
def set_torch_cuda_arch_list ():
27
- """ Set the CUDA arch list according to the architectures the current torch installation was compiled for.
37
+ """Set the CUDA arch list according to the architectures the current torch installation was compiled for.
28
38
This function is a no-op if the environment variable TORCH_CUDA_ARCH_LIST is already set or if torch was not compiled with CUDA support.
29
39
"""
30
40
if not os .environ .get ("TORCH_CUDA_ARCH_LIST" ):
@@ -35,20 +45,24 @@ def set_torch_cuda_arch_list():
35
45
formatted_versions += "+PTX"
36
46
os .environ ["TORCH_CUDA_ARCH_LIST" ] = formatted_versions
37
47
48
+
38
49
set_torch_cuda_arch_list ()
39
50
40
- extension_root = os .path .join ("torchmdnet" , "extensions" )
41
- neighbor_sources = ["neighbors_cpu.cpp" ]
51
+ extension_root = os .path .join ("torchmdnet" , "extensions" )
52
+ neighbor_sources = ["neighbors_cpu.cpp" ]
42
53
if use_cuda :
43
54
neighbor_sources .append ("neighbors_cuda.cu" )
44
- neighbor_sources = [os .path .join (extension_root , "neighbors" , source ) for source in neighbor_sources ]
55
+ neighbor_sources = [
56
+ os .path .join (extension_root , "neighbors" , source ) for source in neighbor_sources
57
+ ]
45
58
46
59
ExtensionType = CppExtension if not use_cuda else CUDAExtension
47
60
extensions = ExtensionType (
48
- name = 'torchmdnet.extensions.torchmdnet_extensions' ,
49
- sources = [os .path .join (extension_root , "torchmdnet_extensions.cpp" )] + neighbor_sources ,
61
+ name = "torchmdnet.extensions.torchmdnet_extensions" ,
62
+ sources = [os .path .join (extension_root , "torchmdnet_extensions.cpp" )]
63
+ + neighbor_sources ,
50
64
include_dirs = include_paths (),
51
- define_macros = [(' WITH_CUDA' , 1 )] if use_cuda else [],
65
+ define_macros = [(" WITH_CUDA" , 1 )] if use_cuda else [],
52
66
)
53
67
54
68
if __name__ == "__main__" :
@@ -58,8 +72,19 @@ def set_torch_cuda_arch_list():
58
72
packages = find_packages (),
59
73
ext_modules = [extensions ],
60
74
cmdclass = {
61
- 'build_ext' : BuildExtension .with_options (no_python_abi_suffix = True , use_ninja = False )},
75
+ "build_ext" : BuildExtension .with_options (
76
+ no_python_abi_suffix = True , use_ninja = False
77
+ )
78
+ },
62
79
include_package_data = True ,
63
- entry_points = {"console_scripts" : ["torchmd-train = torchmdnet.scripts.train:main" ]},
64
- package_data = {"torchmdnet" : ["extensions/torchmdnet_extensions.so" ] if not is_windows else ["extensions/torchmdnet_extensions.dll" ]},
80
+ entry_points = {
81
+ "console_scripts" : ["torchmd-train = torchmdnet.scripts.train:main" ]
82
+ },
83
+ package_data = {
84
+ "torchmdnet" : (
85
+ ["extensions/torchmdnet_extensions.so" ]
86
+ if not is_windows
87
+ else ["extensions/torchmdnet_extensions.dll" ]
88
+ )
89
+ },
65
90
)
0 commit comments