Skip to content

Commit 72a3981

Browse files
authored
[python-package][sklearn] Support PyArrow Table as an input in scikit-learn methods (#6910)
1 parent c58320f commit 72a3981

File tree

3 files changed: 87 additions and 76 deletions.

python-package/lightgbm/compat.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -269,6 +269,7 @@ def __init__(self, *args: Any, **kwargs: Any):
269269
from pyarrow import Array as pa_Array
270270
from pyarrow import ChunkedArray as pa_ChunkedArray
271271
from pyarrow import Table as pa_Table
272+
from pyarrow import array as pa_array
272273
from pyarrow import chunked_array as pa_chunked_array
273274
from pyarrow.types import is_boolean as arrow_is_boolean
274275
from pyarrow.types import is_floating as arrow_is_floating
@@ -302,6 +303,7 @@ class pa_compute: # type: ignore
302303
all = None
303304
equal = None
304305

306+
pa_array = None
305307
pa_chunked_array = None
306308
arrow_is_boolean = None
307309
arrow_is_integer = None

python-package/lightgbm/sklearn.py

Lines changed: 5 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -41,6 +41,7 @@
4141
_LGBMRegressorBase,
4242
_LGBMValidateData,
4343
_sklearn_version,
44+
pa_Table,
4445
pd_DataFrame,
4546
)
4647
from .engine import train
@@ -60,6 +61,7 @@
6061
List[Union[List[float], List[int]]],
6162
np.ndarray,
6263
pd_DataFrame,
64+
pa_Table,
6365
scipy.sparse.spmatrix,
6466
]
6567
_LGBM_ScikitCustomObjectiveFunction = Union[
@@ -943,7 +945,7 @@ def fit(
943945
params["metric"] = [e for e in eval_metrics_builtin if e not in params["metric"]] + params["metric"]
944946
params["metric"] = [metric for metric in params["metric"] if metric is not None]
945947

946-
if not isinstance(X, pd_DataFrame):
948+
if not isinstance(X, (pd_DataFrame, pa_Table)):
947949
_X, _y = _LGBMValidateData(
948950
self,
949951
X,
@@ -1075,7 +1077,7 @@ def fit(
10751077

10761078
fit.__doc__ = (
10771079
_lgbmmodel_doc_fit.format(
1078-
X_shape="numpy array, pandas DataFrame, scipy.sparse, list of lists of int or float of shape = [n_samples, n_features]",
1080+
X_shape="numpy array, pandas DataFrame, pyarrow Table, scipy.sparse, list of lists of int or float of shape = [n_samples, n_features]",
10791081
y_shape="numpy array, pandas DataFrame, pandas Series, list of int or float, pyarrow Array, pyarrow ChunkedArray of shape = [n_samples]",
10801082
sample_weight_shape="numpy array, pandas Series, list of int or float, pyarrow Array, pyarrow ChunkedArray of shape = [n_samples] or None, optional (default=None)",
10811083
init_score_shape="numpy array, pandas DataFrame, pandas Series, list of int or float, list of lists, pyarrow Array, pyarrow ChunkedArray, pyarrow Table of shape = [n_samples] or shape = [n_samples * n_classes] (for multi-class task) or shape = [n_samples, n_classes] (for multi-class task) or None, optional (default=None)",
@@ -1102,7 +1104,7 @@ def predict(
11021104
"""Docstring is set after definition, using a template."""
11031105
if not self.__sklearn_is_fitted__():
11041106
raise LGBMNotFittedError("Estimator not fitted, call fit before exploiting the model.")
1105-
if not isinstance(X, pd_DataFrame):
1107+
if not isinstance(X, (pd_DataFrame, pa_Table)):
11061108
X = _LGBMValidateData(
11071109
self,
11081110
X,

tests/python_package_test/test_sklearn.py

Lines changed: 80 additions & 73 deletions
Original file line number | Diff line number | Diff line change
@@ -25,7 +25,11 @@
2525
from lightgbm.compat import (
2626
DASK_INSTALLED,
2727
PANDAS_INSTALLED,
28+
PYARROW_INSTALLED,
2829
_sklearn_version,
30+
pa_array,
31+
pa_chunked_array,
32+
pa_Table,
2933
pd_DataFrame,
3034
pd_Series,
3135
)
@@ -54,6 +58,9 @@
5458
"regression": lgb.LGBMRegressor,
5559
}
5660
all_tasks = tuple(task_to_model_factory.keys())
61+
all_x_types = ("list2d", "numpy", "pd_DataFrame", "pa_Table", "scipy_csc", "scipy_csr")
62+
all_y_types = ("list1d", "numpy", "pd_Series", "pd_DataFrame", "pa_Array", "pa_ChunkedArray")
63+
all_group_types = ("list1d_float", "list1d_int", "numpy", "pd_Series", "pa_Array", "pa_ChunkedArray")
5764

5865

5966
def _create_data(task, n_samples=100, n_features=4):
@@ -1884,16 +1891,11 @@ def test_predict_rejects_inputs_with_incorrect_number_of_features(predict_disabl
18841891
assert preds.shape[0] == y.shape[0]
18851892

18861893

1887-
@pytest.mark.parametrize("X_type", ["list2d", "numpy", "scipy_csc", "scipy_csr", "pd_DataFrame"])
1888-
@pytest.mark.parametrize("y_type", ["list1d", "numpy", "pd_Series", "pd_DataFrame"])
1889-
@pytest.mark.parametrize("task", ["binary-classification", "multiclass-classification", "regression"])
1890-
def test_classification_and_regression_minimally_work_with_all_all_accepted_data_types(X_type, y_type, task, rng):
1891-
if any(t.startswith("pd_") for t in [X_type, y_type]) and not PANDAS_INSTALLED:
1892-
pytest.skip("pandas is not installed")
1894+
def run_minimal_test(X_type, y_type, g_type, task, rng):
18931895
X, y, g = _create_data(task, n_samples=2_000)
18941896
weights = np.abs(rng.standard_normal(size=(y.shape[0],)))
18951897

1896-
if task == "binary-classification" or task == "regression":
1898+
if task in {"binary-classification", "regression", "ranking"}:
18971899
init_score = np.full_like(y, np.mean(y))
18981900
elif task == "multiclass-classification":
18991901
init_score = np.outer(y, np.array([0.1, 0.2, 0.7]))
@@ -1909,6 +1911,8 @@ def test_classification_and_regression_minimally_work_with_all_all_accepted_data
19091911
X = scipy.sparse.csr_matrix(X)
19101912
elif X_type == "pd_DataFrame":
19111913
X = pd_DataFrame(X)
1914+
elif X_type == "pa_Table":
1915+
X = pa_Table.from_pandas(pd_DataFrame(X))
19121916
elif X_type != "numpy":
19131917
raise ValueError(f"Unrecognized X_type: '{X_type}'")
19141918

@@ -1932,19 +1936,50 @@ def test_classification_and_regression_minimally_work_with_all_all_accepted_data
19321936
init_score = pd_DataFrame(init_score)
19331937
else:
19341938
init_score = pd_Series(init_score)
1939+
elif y_type == "pa_Array":
1940+
y = pa_array(y)
1941+
weights = pa_array(weights)
1942+
if task == "multiclass-classification":
1943+
init_score = pa_Table.from_pandas(pd_DataFrame(init_score))
1944+
else:
1945+
init_score = pa_array(init_score)
1946+
elif y_type == "pa_ChunkedArray":
1947+
y = pa_chunked_array([y])
1948+
weights = pa_chunked_array([weights])
1949+
if task == "multiclass-classification":
1950+
init_score = pa_Table.from_pandas(pd_DataFrame(init_score))
1951+
else:
1952+
init_score = pa_chunked_array([init_score])
19351953
elif y_type != "numpy":
19361954
raise ValueError(f"Unrecognized y_type: '{y_type}'")
19371955

1956+
if g_type == "list1d_float":
1957+
g = g.astype("float").tolist()
1958+
elif g_type == "list1d_int":
1959+
g = g.astype("int").tolist()
1960+
elif g_type == "pd_Series":
1961+
g = pd_Series(g)
1962+
elif g_type == "pa_Array":
1963+
g = pa_array(g)
1964+
elif g_type == "pa_ChunkedArray":
1965+
g = pa_chunked_array([g])
1966+
elif g_type != "numpy":
1967+
raise ValueError(f"Unrecognized g_type: '{g_type}'")
1968+
19381969
model = task_to_model_factory[task](n_estimators=10, verbose=-1)
1939-
model.fit(
1940-
X=X,
1941-
y=y,
1942-
sample_weight=weights,
1943-
init_score=init_score,
1944-
eval_set=[(X_valid, y)],
1945-
eval_sample_weight=[weights],
1946-
eval_init_score=[init_score],
1947-
)
1970+
params_fit = {
1971+
"X": X,
1972+
"y": y,
1973+
"sample_weight": weights,
1974+
"init_score": init_score,
1975+
"eval_set": [(X_valid, y)],
1976+
"eval_sample_weight": [weights],
1977+
"eval_init_score": [init_score],
1978+
}
1979+
if task == "ranking":
1980+
params_fit["group"] = g
1981+
params_fit["eval_group"] = [g]
1982+
model.fit(**params_fit)
19481983

19491984
preds = model.predict(X)
19501985
if task == "binary-classification":
@@ -1953,72 +1988,44 @@ def test_classification_and_regression_minimally_work_with_all_all_accepted_data
19531988
assert accuracy_score(y, preds) >= 0.99
19541989
elif task == "regression":
19551990
assert r2_score(y, preds) > 0.86
1991+
elif task == "ranking":
1992+
assert spearmanr(preds, y).correlation >= 0.99
19561993
else:
19571994
raise ValueError(f"Unrecognized task: '{task}'")
19581995

19591996

1960-
@pytest.mark.parametrize("X_type", ["list2d", "numpy", "scipy_csc", "scipy_csr", "pd_DataFrame"])
1961-
@pytest.mark.parametrize("y_type", ["list1d", "numpy", "pd_DataFrame", "pd_Series"])
1962-
@pytest.mark.parametrize("g_type", ["list1d_float", "list1d_int", "numpy", "pd_Series"])
1963-
def test_ranking_minimally_works_with_all_all_accepted_data_types(X_type, y_type, g_type, rng):
1964-
if any(t.startswith("pd_") for t in [X_type, y_type, g_type]) and not PANDAS_INSTALLED:
1997+
@pytest.mark.parametrize("X_type", all_x_types)
1998+
@pytest.mark.parametrize("y_type", all_y_types)
1999+
@pytest.mark.parametrize("task", [t for t in all_tasks if t != "ranking"])
2000+
def test_classification_and_regression_minimally_work_with_all_accepted_data_types(
2001+
X_type,
2002+
y_type,
2003+
task,
2004+
rng,
2005+
):
2006+
if any(t.startswith("pd_") for t in [X_type, y_type]) and not PANDAS_INSTALLED:
19652007
pytest.skip("pandas is not installed")
1966-
X, y, g = _create_data(task="ranking", n_samples=1_000)
1967-
weights = np.abs(rng.standard_normal(size=(y.shape[0],)))
1968-
init_score = np.full_like(y, np.mean(y))
1969-
X_valid = X * 2
2008+
if any(t.startswith("pa_") for t in [X_type, y_type]) and not PYARROW_INSTALLED:
2009+
pytest.skip("pyarrow is not installed")
19702010

1971-
if X_type == "list2d":
1972-
X = X.tolist()
1973-
elif X_type == "scipy_csc":
1974-
X = scipy.sparse.csc_matrix(X)
1975-
elif X_type == "scipy_csr":
1976-
X = scipy.sparse.csr_matrix(X)
1977-
elif X_type == "pd_DataFrame":
1978-
X = pd_DataFrame(X)
1979-
elif X_type != "numpy":
1980-
raise ValueError(f"Unrecognized X_type: '{X_type}'")
2011+
run_minimal_test(X_type=X_type, y_type=y_type, g_type="numpy", task=task, rng=rng)
19812012

1982-
# make weights and init_score same types as y, just to avoid
1983-
# a huge number of combinations and therefore test cases
1984-
if y_type == "list1d":
1985-
y = y.tolist()
1986-
weights = weights.tolist()
1987-
init_score = init_score.tolist()
1988-
elif y_type == "pd_DataFrame":
1989-
y = pd_DataFrame(y)
1990-
weights = pd_Series(weights)
1991-
init_score = pd_Series(init_score)
1992-
elif y_type == "pd_Series":
1993-
y = pd_Series(y)
1994-
weights = pd_Series(weights)
1995-
init_score = pd_Series(init_score)
1996-
elif y_type != "numpy":
1997-
raise ValueError(f"Unrecognized y_type: '{y_type}'")
19982013

1999-
if g_type == "list1d_float":
2000-
g = g.astype("float").tolist()
2001-
elif g_type == "list1d_int":
2002-
g = g.astype("int").tolist()
2003-
elif g_type == "pd_Series":
2004-
g = pd_Series(g)
2005-
elif g_type != "numpy":
2006-
raise ValueError(f"Unrecognized g_type: '{g_type}'")
2014+
@pytest.mark.parametrize("X_type", all_x_types)
2015+
@pytest.mark.parametrize("y_type", all_y_types)
2016+
@pytest.mark.parametrize("g_type", all_group_types)
2017+
def test_ranking_minimally_works_with_all_accepted_data_types(
2018+
X_type,
2019+
y_type,
2020+
g_type,
2021+
rng,
2022+
):
2023+
if any(t.startswith("pd_") for t in [X_type, y_type, g_type]) and not PANDAS_INSTALLED:
2024+
pytest.skip("pandas is not installed")
2025+
if any(t.startswith("pa_") for t in [X_type, y_type, g_type]) and not PYARROW_INSTALLED:
2026+
pytest.skip("pyarrow is not installed")
20072027

2008-
model = task_to_model_factory["ranking"](n_estimators=10, verbose=-1)
2009-
model.fit(
2010-
X=X,
2011-
y=y,
2012-
sample_weight=weights,
2013-
init_score=init_score,
2014-
group=g,
2015-
eval_set=[(X_valid, y)],
2016-
eval_sample_weight=[weights],
2017-
eval_init_score=[init_score],
2018-
eval_group=[g],
2019-
)
2020-
preds = model.predict(X)
2021-
assert spearmanr(preds, y).correlation >= 0.99
2028+
run_minimal_test(X_type=X_type, y_type=y_type, g_type=g_type, task="ranking", rng=rng)
20222029

20232030

20242031
def test_classifier_fit_detects_classes_every_time():

0 commit comments

Comments
 (0)