Skip to content

Commit d778401

Browse files
authored
fix device bug on GPU (#254)
1 parent a215c5b commit d778401

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

src/tabpfn/base.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
InferenceEngineOnDemand,
2929
)
3030
from tabpfn.model.loading import load_model_criterion_config
31-
from tabpfn.utils import infer_fp16_inference_mode
31+
from tabpfn.utils import infer_device_and_type, infer_fp16_inference_mode
3232

3333
if TYPE_CHECKING:
3434
import numpy as np
@@ -254,14 +254,15 @@ def check_cpu_warning(
254254
X: The input data (NumPy array, Pandas DataFrame, or Torch Tensor)
255255
"""
256256
allow_cpu_override = os.getenv("TABPFN_ALLOW_CPU_LARGE_DATASET", "0") == "1"
257+
device_mapped = infer_device_and_type(device)
257258

258259
# Determine number of samples
259260
try:
260261
num_samples = X.shape[0]
261262
except AttributeError:
262263
return
263264

264-
if device == torch.device("cpu") or device == "cpu" or "cpu" in device:
265+
if torch.device(device_mapped).type == "cpu":
265266
if num_samples > 1000:
266267
if not allow_cpu_override:
267268
raise RuntimeError(

0 commit comments

Comments
 (0)