
Commit c8cbc6f

Fix TabPFNRegressor preprocessing failing on bigger datasets (#255)
Co-authored-by: noahho <[email protected]>
1 parent a83b20c

File tree

5 files changed: +116 -33 lines

scripts/get_max_dependencies.py
scripts/get_min_dependencies.py
src/tabpfn/model/loading.py
src/tabpfn/model/preprocessing.py
tests/test_preprocessing.py

scripts/get_max_dependencies.py

Lines changed: 1 addition & 1 deletion

@@ -37,4 +37,4 @@ def main() -> None:
 
 
 if __name__ == "__main__":
-    main()
+    main()

scripts/get_min_dependencies.py

Lines changed: 1 addition & 1 deletion

@@ -31,4 +31,4 @@ def main() -> None:
 
 
 if __name__ == "__main__":
-    main()
+    main()

src/tabpfn/model/loading.py

Lines changed: 29 additions & 29 deletions

@@ -13,7 +13,7 @@
 from dataclasses import dataclass
 from enum import Enum
 from pathlib import Path
-from typing import Literal, overload
+from typing import Literal, cast, overload
 from urllib.error import URLError
 
 import torch
@@ -101,7 +101,7 @@ def _get_model_source(version: ModelVersion, model_type: ModelType) -> ModelSource:
 )
 
 
-def _suppress_hf_token_warning():
+def _suppress_hf_token_warning() -> None:
     """Suppress warning about missing HuggingFace token."""
     import warnings
 
@@ -287,7 +287,7 @@ def download_all_models(to: Path) -> None:
         download_model(
             to=to / ckpt_name,
             version="v2",
-            which=model_type,
+            which=cast(Literal["classifier", "regressor"], model_type),
             model_name=ckpt_name,
         )
 
@@ -370,31 +370,6 @@ def load_model_criterion_config(
 ) -> tuple[PerFeatureTransformer, FullSupportBarDistribution, InferenceConfig]: ...
 
 
-def resolve_model_path(
-    model_path: None | str | Path,
-    which: Literal["regressor", "classifier"],
-    version: Literal["v2"] = "v2",
-) -> tuple[Path, Path, str, str]:
-    if model_path is None:
-        USER_TABPFN_CACHE_DIR_LOCATION = os.environ.get("TABPFN_MODEL_CACHE_DIR", "")
-        if USER_TABPFN_CACHE_DIR_LOCATION.strip() != "":
-            model_dir = Path(USER_TABPFN_CACHE_DIR_LOCATION)
-        else:
-            model_dir = _user_cache_dir(platform=sys.platform, appname="tabpfn")
-
-        model_name = f"tabpfn-{version}-{which}.ckpt"
-        model_path = model_dir / model_name
-    else:
-        if not isinstance(model_path, (str, Path)):
-            raise ValueError(f"Invalid model_path: {model_path}")
-
-        model_path = Path(model_path)
-        model_dir = model_path.parent
-        model_name = model_path.name
-
-    return model_path, model_dir, model_name, which
-
-
 def load_model_criterion_config(
     model_path: None | str | Path,
     *,
@@ -452,7 +427,7 @@ def load_model_criterion_config(
         res = download_model(
             model_path,
             version=version,
-            which=which,
+            which=cast(Literal["classifier", "regressor"], which),
             model_name=model_name,
         )
         if res != "ok":
@@ -478,6 +453,31 @@ def load_model_criterion_config(
     return loaded_model, criterion, config
 
 
+def resolve_model_path(
+    model_path: None | str | Path,
+    which: Literal["regressor", "classifier"],
+    version: Literal["v2"] = "v2",
+) -> tuple[Path, Path, str, str]:
+    if model_path is None:
+        USER_TABPFN_CACHE_DIR_LOCATION = os.environ.get("TABPFN_MODEL_CACHE_DIR", "")
+        if USER_TABPFN_CACHE_DIR_LOCATION.strip() != "":
+            model_dir = Path(USER_TABPFN_CACHE_DIR_LOCATION)
+        else:
+            model_dir = _user_cache_dir(platform=sys.platform, appname="tabpfn")
+
+        model_name = f"tabpfn-{version}-{which}.ckpt"
+        model_path = model_dir / model_name
+    else:
+        if not isinstance(model_path, (str, Path)):
+            raise ValueError(f"Invalid model_path: {model_path}")
+
+        model_path = Path(model_path)
+        model_dir = model_path.parent
+        model_name = model_path.name
+
+    return model_path, model_dir, model_name, which
+
+
 def get_loss_criterion(
     config: InferenceConfig,
 ) -> nn.BCEWithLogitsLoss | nn.CrossEntropyLoss | FullSupportBarDistribution:
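
The two cast(...) edits only narrow the static type of which for the type checker; typing.cast has no runtime effect. The resolve_model_path helper itself is unchanged, only moved below load_model_criterion_config. As a minimal usage sketch (not part of the commit; the paths are illustrative) of how it resolves model locations:

import os
from pathlib import Path

from tabpfn.model.loading import resolve_model_path

# An explicit model_path is split into its directory and file name.
path, model_dir, name, which = resolve_model_path(
    "/tmp/models/tabpfn-v2-regressor.ckpt", which="regressor"
)
assert (model_dir, name) == (Path("/tmp/models"), "tabpfn-v2-regressor.ckpt")

# With model_path=None, a non-empty TABPFN_MODEL_CACHE_DIR takes precedence
# over the platform-specific user cache directory.
os.environ["TABPFN_MODEL_CACHE_DIR"] = "/tmp/tabpfn-cache"
path, model_dir, name, which = resolve_model_path(None, which="classifier")
assert path == Path("/tmp/tabpfn-cache") / "tabpfn-v2-classifier.ckpt"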

src/tabpfn/model/preprocessing.py

Lines changed: 53 additions & 2 deletions

@@ -82,6 +82,57 @@ def transform(self, X: torch.Tensor | np.ndarray) -> np.ndarray:
         return X  # type: ignore
 
 
+class AdaptiveQuantileTransformer(QuantileTransformer):
+    """A QuantileTransformer that automatically adapts the 'n_quantiles' parameter
+    based on the number of samples provided during the 'fit' method.
+
+    This prevents errors that occur when the requested 'n_quantiles' is
+    greater than the number of available samples in the input data (X).
+    This situation can arise because we first initialize the transformer
+    based on total samples and then subsample.
+    """
+
+    def __init__(self, *, n_quantiles: int = 1000, **kwargs: Any) -> None:
+        # Store the user's desired n_quantiles to use as an upper bound
+        self._user_n_quantiles = n_quantiles
+        # Initialize the parent with this value; it will be adapted in fit
+        super().__init__(n_quantiles=n_quantiles, **kwargs)
+
+    def fit(
+        self, X: np.ndarray, y: np.ndarray | None = None
+    ) -> AdaptiveQuantileTransformer:
+        X = self._validate_data(
+            X, copy=self.copy, estimator=self, dtype=float, force_all_finite="allow-nan"
+        )
+        n_samples = X.shape[0]
+
+        # Adapt n_quantiles for this fit: min of the user's preference and the
+        # available samples. Ensure n_quantiles is at least 1.
+        effective_n_quantiles = max(1, min(self._user_n_quantiles, n_samples))
+
+        # Set self.n_quantiles to the effective value BEFORE calling super().fit()
+        # so the parent class uses the adapted value for fitting and
+        # self.n_quantiles reflects the value used for the fit.
+        self.n_quantiles = effective_n_quantiles
+
+        return super().fit(X, y)
+
+    # For scikit-learn compatibility, get_params reports the original user
+    # setting if desired, though self.n_quantiles will show the fitted
+    # effective value.
+    def get_params(self, *, deep: bool = True) -> dict:
+        params = super().get_params(deep)
+        # Report the original user n_quantiles
+        if "_user_n_quantiles" in self.__dict__:  # Check if it was set
+            params["n_quantiles"] = self._user_n_quantiles
+        return params
+
+    def set_params(self, **params: Any) -> AdaptiveQuantileTransformer:
+        if "n_quantiles" in params:
+            self._user_n_quantiles = params["n_quantiles"]
+        return super().set_params(**params)
+
+
 ALPHAS = (
     0.05,
     0.1,
@@ -656,9 +707,9 @@ def get_adaptive_preprocessors(
         ),
         (
             "other",
-            QuantileTransformer(
+            AdaptiveQuantileTransformer(
                 output_distribution="normal",
-                n_quantiles=num_examples // 10,
+                n_quantiles=max(num_examples // 10, 2),
                 random_state=random_state,
             ),
             # "other" or "ordinal"

tests/test_preprocessing.py

Lines changed: 32 additions & 0 deletions

@@ -0,0 +1,32 @@
+from __future__ import annotations
+
+import numpy as np
+
+from tabpfn.model.preprocessing import ReshapeFeatureDistributionsStep
+
+
+def test_preprocessing_large_dataset():
+    # Generate a synthetic dataset with more than 10,000 samples
+    num_samples = 15000
+    num_features = 10
+    rng = np.random.default_rng()
+    X = rng.random((num_samples, num_features))
+
+    # Create an instance of ReshapeFeatureDistributionsStep
+    preprocessing_step = ReshapeFeatureDistributionsStep(
+        transform_name="quantile_norm",
+        apply_to_categorical=False,
+        append_to_original=False,
+        subsample_features=-1,
+        global_transformer_name=None,
+        random_state=42,
+    )
+
+    # Define categorical features (empty in this case)
+    categorical_features = []
+
+    # Run the preprocessing step
+    result = preprocessing_step.fit_transform(X, categorical_features)
+
+    # Assert the result is not None
+    assert result is not None
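
With 15,000 samples, num_examples // 10 requests 1,500 quantiles, so this test exercises exactly the path fixed above whenever the transformer is fitted on a smaller subsample. It can be run directly with, e.g., python -m pytest tests/test_preprocessing.py.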
