File tree Expand file tree Collapse file tree 1 file changed +3
-2
lines changed Expand file tree Collapse file tree 1 file changed +3
-2
lines changed Original file line number Diff line number Diff line change 28
28
InferenceEngineOnDemand ,
29
29
)
30
30
from tabpfn .model .loading import load_model_criterion_config
31
- from tabpfn .utils import infer_fp16_inference_mode
31
+ from tabpfn .utils import infer_device_and_type , infer_fp16_inference_mode
32
32
33
33
if TYPE_CHECKING :
34
34
import numpy as np
@@ -254,14 +254,15 @@ def check_cpu_warning(
254
254
X: The input data (NumPy array, Pandas DataFrame, or Torch Tensor)
255
255
"""
256
256
allow_cpu_override = os .getenv ("TABPFN_ALLOW_CPU_LARGE_DATASET" , "0" ) == "1"
257
+ device_mapped = infer_device_and_type (device )
257
258
258
259
# Determine number of samples
259
260
try :
260
261
num_samples = X .shape [0 ]
261
262
except AttributeError :
262
263
return
263
264
264
- if device == torch .device ("cpu" ) or device == "cpu" or "cpu" in device :
265
+ if torch .device (device_mapped ). type == "cpu" :
265
266
if num_samples > 1000 :
266
267
if not allow_cpu_override :
267
268
raise RuntimeError (
You can’t perform that action at this time.
0 commit comments