Skip to content

Commit b60c059

Browse files
committed
cleanup the duplicate logic of WITH_CUDA and ACCELERATOR
1 parent 204ee7a commit b60c059

File tree

5 files changed

+9
-20
lines changed

5 files changed

+9
-20
lines changed

.github/workflows/docs.yml

+2-2
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,10 @@ jobs:
2626

2727
- name: Install docs dependencies
2828
run: |
29-
pip install -vv .
29+
pip install -vv . --no-deps
3030
pip install -r docs/requirements.txt
3131
env:
32-
WITH_CUDA: "0"
32+
ACCELERATOR: "cpu"
3333

3434
- name: Build Sphinx Documentation
3535
run: |

.github/workflows/publish.yml

-2
Original file line numberDiff line numberDiff line change
@@ -59,15 +59,13 @@ jobs:
5959
env:
6060
ACCELERATOR: ${{ matrix.accelerator }}
6161
CPU_TRAIN: ${{ runner.os == 'macOS' && 'true' || 'false' }}
62-
WITH_CUDA: ${{ matrix.accelerator != 'cpu' && '1' || '0' }}
6362

6463
- name: Build wheels
6564
if: matrix.os == 'windows-2019'
6665
shell: cmd # Use cmd on Windows to avoid bash environment taking priority over MSVC variables
6766
run: python -m cibuildwheel --output-dir wheelhouse
6867
env:
6968
ACCELERATOR: ${{ matrix.accelerator }}
70-
WITH_CUDA: ${{ matrix.accelerator != 'cpu' && '1' || '0' }}
7169
DISTUTILS_USE_SDK: "1" # Windows requires this to use vc for building
7270
SKIP_TORCH_COMPILE: "true"
7371

.github/workflows/test.yml

+2-2
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ jobs:
7373
python -m build
7474
pip install dist/*.whl
7575
env:
76-
WITH_CUDA: "0"
76+
ACCELERATOR: "cpu"
7777

7878
# - name: Install nnpops
7979
# if: matrix.os == 'ubuntu-latest' || matrix.os == 'macos-latest'
@@ -85,7 +85,7 @@ jobs:
8585
- name: Run tests
8686
run: pytest -v -s --durations=10
8787
env:
88-
WITH_CUDA: "0"
88+
ACCELERATOR: "cpu"
8989
SKIP_TORCH_COMPILE: ${{ runner.os == 'Windows' && 'true' || 'false' }}
9090
OMP_PREFIX: ${{ runner.os == 'macOS' && '/Users/runner/miniconda3/envs/test' || '' }}
9191
CPU_TRAIN: ${{ runner.os == 'macOS' && 'true' || 'false' }}

cibuildwheel_support/before_all_linux.sh

-2
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,6 @@ if [ "$ACCELERATOR" == "cu118" ]; then
3030
ln -s /opt/rh/gcc-toolset-11/root/usr/bin/g++ /usr/local/cuda/bin/g++
3131

3232
export CUDA_HOME="/usr/local/cuda"
33-
export WITH_CUDA=1
34-
3533

3634
# Configure pip to use PyTorch extra-index-url for CUDA 11.8
3735
mkdir -p $HOME/.config/pip

setup.py

+5-12
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,9 @@
99
import platform
1010

1111

12-
# If WITH_CUDA is defined
13-
env_with_cuda = os.getenv("WITH_CUDA")
14-
if env_with_cuda is not None:
15-
if env_with_cuda not in ("0", "1"):
16-
raise ValueError(
17-
f"WITH_CUDA environment variable got invalid value {env_with_cuda}. Expected '0' or '1'"
18-
)
19-
use_cuda = env_with_cuda == "1"
20-
else:
21-
use_cuda = torch.cuda._is_compiled()
12+
use_cuda = (
13+
os.environ.get("ACCELERATOR", "").startswith("cu") or torch.cuda._is_compiled()
14+
)
2215

2316

2417
def set_torch_cuda_arch_list():
@@ -56,8 +49,8 @@ def set_torch_cuda_arch_list():
5649
]
5750

5851
extra_deps = []
59-
if use_cuda:
60-
cuda_ver = os.getenv("ACCELERATOR", "")[2:4]
52+
if os.getenv("ACCELERATOR", "").startswith("cu"):
53+
cuda_ver = os.getenv("ACCELERATOR")[2:4]
6154
extra_deps = [f"nvidia-cuda-runtime-cu{cuda_ver}"]
6255

6356
ExtensionType = CppExtension if not use_cuda else CUDAExtension

0 commit comments

Comments
 (0)