Skip to content

Commit 0a3ba43

Browse files
Rework allow_cpu_override to be usable without environment variables (#275)
Co-authored-by: noahho <[email protected]>
1 parent d778401 commit 0a3ba43

File tree

4 files changed

+61
-23
lines changed

4 files changed

+61
-23
lines changed

src/tabpfn/base.py

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -245,15 +245,25 @@ def create_inference_engine( # noqa: PLR0913
245245

246246

247247
def check_cpu_warning(
248-
device: str | torch.device, X: np.ndarray | torch.Tensor | pd.DataFrame
248+
device: str | torch.device,
249+
X: np.ndarray | torch.Tensor | pd.DataFrame,
250+
*,
251+
allow_cpu_override: bool = False,
249252
) -> None:
250253
"""Check if using CPU with large datasets and warn or error appropriately.
251254
252255
Args:
253256
device: The torch device being used
254257
X: The input data (NumPy array, Pandas DataFrame, or Torch Tensor)
258+
allow_cpu_override: If True, allow CPU usage with large datasets.
255259
"""
256-
allow_cpu_override = os.getenv("TABPFN_ALLOW_CPU_LARGE_DATASET", "0") == "1"
260+
allow_cpu_override = allow_cpu_override or (
261+
os.getenv("TABPFN_ALLOW_CPU_LARGE_DATASET", "0") == "1"
262+
)
263+
264+
if allow_cpu_override:
265+
return
266+
257267
device_mapped = infer_device_and_type(device)
258268

259269
# Determine number of samples
@@ -264,16 +274,16 @@ def check_cpu_warning(
264274

265275
if torch.device(device_mapped).type == "cpu":
266276
if num_samples > 1000:
267-
if not allow_cpu_override:
268-
raise RuntimeError(
269-
"Running on CPU with more than 1000 samples is not allowed "
270-
"by default due to slow performance.\n"
271-
"To override this behavior, set the environment variable "
272-
"TABPFN_ALLOW_CPU_LARGE_DATASET=1.\n"
273-
"Alternatively, consider using a GPU or the tabpfn-client API: "
274-
"https://github.com/PriorLabs/tabpfn-client"
275-
)
276-
elif num_samples > 200:
277+
raise RuntimeError(
278+
"Running on CPU with more than 1000 samples is not allowed "
279+
"by default due to slow performance.\n"
280+
"To override this behavior, set the environment variable "
281+
"TABPFN_ALLOW_CPU_LARGE_DATASET=1 or "
282+
"set ignore_pretraining_limits=True.\n"
283+
"Alternatively, consider using a GPU or the tabpfn-client API: "
284+
"https://github.com/PriorLabs/tabpfn-client"
285+
)
286+
if num_samples > 200:
277287
warnings.warn(
278288
"Running on CPU with more than 200 samples may be slow.\n"
279289
"Consider using a GPU or the tabpfn-client API: "

src/tabpfn/classifier.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,8 @@ def __init__( # noqa: PLR0913
228228
pre-training range.
229229
230230
- If `True`, the model will not raise an error if the input data is
231-
outside the pre-training range.
231+
outside the pre-training range. Also supresses error when using
232+
the model with more than 1000 samples on CPU.
232233
- If `False`, you can use the model outside the pre-training range, but
233234
the model could perform worse.
234235
@@ -428,7 +429,9 @@ def fit(self, X: XType, y: YType) -> Self:
428429
ignore_pretraining_limits=self.ignore_pretraining_limits,
429430
)
430431

431-
check_cpu_warning(self.device, X)
432+
check_cpu_warning(
433+
self.device, X, allow_cpu_override=self.ignore_pretraining_limits
434+
)
432435

433436
if feature_names_in is not None:
434437
self.feature_names_in_ = feature_names_in

src/tabpfn/regressor.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,8 @@ def __init__( # noqa: PLR0913
251251
pre-training range.
252252
253253
- If `True`, the model will not raise an error if the input data is
254-
outside the pre-training range.
254+
outside the pre-training range. Also supresses error when using
255+
the model with more than 1000 samples on CPU.
255256
- If `False`, you can use the model outside the pre-training range, but
256257
the model could perform worse.
257258
@@ -456,7 +457,9 @@ def fit(self, X: XType, y: YType) -> Self:
456457
ignore_pretraining_limits=self.ignore_pretraining_limits,
457458
)
458459
assert isinstance(X, np.ndarray)
459-
check_cpu_warning(self.device, X)
460+
check_cpu_warning(
461+
self.device, X, allow_cpu_override=self.ignore_pretraining_limits
462+
)
460463

461464
if feature_names_in is not None:
462465
self.feature_names_in_ = feature_names_in

tests/test_regressor_interface.py

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -355,13 +355,35 @@ def test_cpu_large_dataset_warning():
355355
with pytest.warns(
356356
UserWarning, match="Running on CPU with more than 200 samples may be slow"
357357
):
358-
# Set environment variable to allow large datasets to avoid RuntimeError
359-
os.environ["TABPFN_ALLOW_CPU_LARGE_DATASET"] = "1"
360-
try:
361-
model.fit(X_large, y_large)
362-
finally:
363-
# Clean up environment variable
364-
os.environ.pop("TABPFN_ALLOW_CPU_LARGE_DATASET")
358+
model.fit(X_large, y_large)
359+
360+
361+
def test_cpu_large_dataset_warning_override():
362+
"""Test that runtime error is raised when using CPU with large datasets
363+
and that we can disable the error with ignore_pretraining_limits.
364+
"""
365+
rng = np.random.default_rng(seed=42)
366+
X_large = rng.random((1001, 10))
367+
y_large = rng.random(1001)
368+
369+
model = TabPFNRegressor(device="cpu")
370+
with pytest.raises(
371+
RuntimeError, match="Running on CPU with more than 1000 samples is not"
372+
):
373+
model.fit(X_large, y_large)
374+
375+
# -- Test overrides
376+
model = TabPFNRegressor(device="cpu", ignore_pretraining_limits=True)
377+
model.fit(X_large, y_large)
378+
379+
# Set environment variable to allow large datasets to avoid RuntimeError
380+
os.environ["TABPFN_ALLOW_CPU_LARGE_DATASET"] = "1"
381+
try:
382+
model = TabPFNRegressor(device="cpu", ignore_pretraining_limits=False)
383+
model.fit(X_large, y_large)
384+
finally:
385+
# Clean up environment variable
386+
os.environ.pop("TABPFN_ALLOW_CPU_LARGE_DATASET")
365387

366388

367389
def test_cpu_large_dataset_error():

0 commit comments

Comments
 (0)