25
25
from lightgbm .compat import (
26
26
DASK_INSTALLED ,
27
27
PANDAS_INSTALLED ,
28
+ PYARROW_INSTALLED ,
28
29
_sklearn_version ,
30
+ pa_array ,
31
+ pa_chunked_array ,
32
+ pa_Table ,
29
33
pd_DataFrame ,
30
34
pd_Series ,
31
35
)
54
58
"regression" : lgb .LGBMRegressor ,
55
59
}
56
60
all_tasks = tuple (task_to_model_factory .keys ())
61
+ all_x_types = ("list2d" , "numpy" , "pd_DataFrame" , "pa_Table" , "scipy_csc" , "scipy_csr" )
62
+ all_y_types = ("list1d" , "numpy" , "pd_Series" , "pd_DataFrame" , "pa_Array" , "pa_ChunkedArray" )
63
+ all_group_types = ("list1d_float" , "list1d_int" , "numpy" , "pd_Series" , "pa_Array" , "pa_ChunkedArray" )
57
64
58
65
59
66
def _create_data (task , n_samples = 100 , n_features = 4 ):
@@ -1884,16 +1891,11 @@ def test_predict_rejects_inputs_with_incorrect_number_of_features(predict_disabl
1884
1891
assert preds .shape [0 ] == y .shape [0 ]
1885
1892
1886
1893
1887
- @pytest .mark .parametrize ("X_type" , ["list2d" , "numpy" , "scipy_csc" , "scipy_csr" , "pd_DataFrame" ])
1888
- @pytest .mark .parametrize ("y_type" , ["list1d" , "numpy" , "pd_Series" , "pd_DataFrame" ])
1889
- @pytest .mark .parametrize ("task" , ["binary-classification" , "multiclass-classification" , "regression" ])
1890
- def test_classification_and_regression_minimally_work_with_all_all_accepted_data_types (X_type , y_type , task , rng ):
1891
- if any (t .startswith ("pd_" ) for t in [X_type , y_type ]) and not PANDAS_INSTALLED :
1892
- pytest .skip ("pandas is not installed" )
1894
+ def run_minimal_test (X_type , y_type , g_type , task , rng ):
1893
1895
X , y , g = _create_data (task , n_samples = 2_000 )
1894
1896
weights = np .abs (rng .standard_normal (size = (y .shape [0 ],)))
1895
1897
1896
- if task == "binary-classification" or task == "regression" :
1898
+ if task in { "binary-classification" , "regression" , "ranking" } :
1897
1899
init_score = np .full_like (y , np .mean (y ))
1898
1900
elif task == "multiclass-classification" :
1899
1901
init_score = np .outer (y , np .array ([0.1 , 0.2 , 0.7 ]))
@@ -1909,6 +1911,8 @@ def test_classification_and_regression_minimally_work_with_all_all_accepted_data
1909
1911
X = scipy .sparse .csr_matrix (X )
1910
1912
elif X_type == "pd_DataFrame" :
1911
1913
X = pd_DataFrame (X )
1914
+ elif X_type == "pa_Table" :
1915
+ X = pa_Table .from_pandas (pd_DataFrame (X ))
1912
1916
elif X_type != "numpy" :
1913
1917
raise ValueError (f"Unrecognized X_type: '{ X_type } '" )
1914
1918
@@ -1932,19 +1936,50 @@ def test_classification_and_regression_minimally_work_with_all_all_accepted_data
1932
1936
init_score = pd_DataFrame (init_score )
1933
1937
else :
1934
1938
init_score = pd_Series (init_score )
1939
+ elif y_type == "pa_Array" :
1940
+ y = pa_array (y )
1941
+ weights = pa_array (weights )
1942
+ if task == "multiclass-classification" :
1943
+ init_score = pa_Table .from_pandas (pd_DataFrame (init_score ))
1944
+ else :
1945
+ init_score = pa_array (init_score )
1946
+ elif y_type == "pa_ChunkedArray" :
1947
+ y = pa_chunked_array ([y ])
1948
+ weights = pa_chunked_array ([weights ])
1949
+ if task == "multiclass-classification" :
1950
+ init_score = pa_Table .from_pandas (pd_DataFrame (init_score ))
1951
+ else :
1952
+ init_score = pa_chunked_array ([init_score ])
1935
1953
elif y_type != "numpy" :
1936
1954
raise ValueError (f"Unrecognized y_type: '{ y_type } '" )
1937
1955
1956
+ if g_type == "list1d_float" :
1957
+ g = g .astype ("float" ).tolist ()
1958
+ elif g_type == "list1d_int" :
1959
+ g = g .astype ("int" ).tolist ()
1960
+ elif g_type == "pd_Series" :
1961
+ g = pd_Series (g )
1962
+ elif g_type == "pa_Array" :
1963
+ g = pa_array (g )
1964
+ elif g_type == "pa_ChunkedArray" :
1965
+ g = pa_chunked_array ([g ])
1966
+ elif g_type != "numpy" :
1967
+ raise ValueError (f"Unrecognized g_type: '{ g_type } '" )
1968
+
1938
1969
model = task_to_model_factory [task ](n_estimators = 10 , verbose = - 1 )
1939
- model .fit (
1940
- X = X ,
1941
- y = y ,
1942
- sample_weight = weights ,
1943
- init_score = init_score ,
1944
- eval_set = [(X_valid , y )],
1945
- eval_sample_weight = [weights ],
1946
- eval_init_score = [init_score ],
1947
- )
1970
+ params_fit = {
1971
+ "X" : X ,
1972
+ "y" : y ,
1973
+ "sample_weight" : weights ,
1974
+ "init_score" : init_score ,
1975
+ "eval_set" : [(X_valid , y )],
1976
+ "eval_sample_weight" : [weights ],
1977
+ "eval_init_score" : [init_score ],
1978
+ }
1979
+ if task == "ranking" :
1980
+ params_fit ["group" ] = g
1981
+ params_fit ["eval_group" ] = [g ]
1982
+ model .fit (** params_fit )
1948
1983
1949
1984
preds = model .predict (X )
1950
1985
if task == "binary-classification" :
@@ -1953,72 +1988,44 @@ def test_classification_and_regression_minimally_work_with_all_all_accepted_data
1953
1988
assert accuracy_score (y , preds ) >= 0.99
1954
1989
elif task == "regression" :
1955
1990
assert r2_score (y , preds ) > 0.86
1991
+ elif task == "ranking" :
1992
+ assert spearmanr (preds , y ).correlation >= 0.99
1956
1993
else :
1957
1994
raise ValueError (f"Unrecognized task: '{ task } '" )
1958
1995
1959
1996
1960
- @pytest .mark .parametrize ("X_type" , ["list2d" , "numpy" , "scipy_csc" , "scipy_csr" , "pd_DataFrame" ])
1961
- @pytest .mark .parametrize ("y_type" , ["list1d" , "numpy" , "pd_DataFrame" , "pd_Series" ])
1962
- @pytest .mark .parametrize ("g_type" , ["list1d_float" , "list1d_int" , "numpy" , "pd_Series" ])
1963
- def test_ranking_minimally_works_with_all_all_accepted_data_types (X_type , y_type , g_type , rng ):
1964
- if any (t .startswith ("pd_" ) for t in [X_type , y_type , g_type ]) and not PANDAS_INSTALLED :
1997
+ @pytest .mark .parametrize ("X_type" , all_x_types )
1998
+ @pytest .mark .parametrize ("y_type" , all_y_types )
1999
+ @pytest .mark .parametrize ("task" , [t for t in all_tasks if t != "ranking" ])
2000
+ def test_classification_and_regression_minimally_work_with_all_accepted_data_types (
2001
+ X_type ,
2002
+ y_type ,
2003
+ task ,
2004
+ rng ,
2005
+ ):
2006
+ if any (t .startswith ("pd_" ) for t in [X_type , y_type ]) and not PANDAS_INSTALLED :
1965
2007
pytest .skip ("pandas is not installed" )
1966
- X , y , g = _create_data (task = "ranking" , n_samples = 1_000 )
1967
- weights = np .abs (rng .standard_normal (size = (y .shape [0 ],)))
1968
- init_score = np .full_like (y , np .mean (y ))
1969
- X_valid = X * 2
2008
+ if any (t .startswith ("pa_" ) for t in [X_type , y_type ]) and not PYARROW_INSTALLED :
2009
+ pytest .skip ("pyarrow is not installed" )
1970
2010
1971
- if X_type == "list2d" :
1972
- X = X .tolist ()
1973
- elif X_type == "scipy_csc" :
1974
- X = scipy .sparse .csc_matrix (X )
1975
- elif X_type == "scipy_csr" :
1976
- X = scipy .sparse .csr_matrix (X )
1977
- elif X_type == "pd_DataFrame" :
1978
- X = pd_DataFrame (X )
1979
- elif X_type != "numpy" :
1980
- raise ValueError (f"Unrecognized X_type: '{ X_type } '" )
2011
+ run_minimal_test (X_type = X_type , y_type = y_type , g_type = "numpy" , task = task , rng = rng )
1981
2012
1982
- # make weights and init_score same types as y, just to avoid
1983
- # a huge number of combinations and therefore test cases
1984
- if y_type == "list1d" :
1985
- y = y .tolist ()
1986
- weights = weights .tolist ()
1987
- init_score = init_score .tolist ()
1988
- elif y_type == "pd_DataFrame" :
1989
- y = pd_DataFrame (y )
1990
- weights = pd_Series (weights )
1991
- init_score = pd_Series (init_score )
1992
- elif y_type == "pd_Series" :
1993
- y = pd_Series (y )
1994
- weights = pd_Series (weights )
1995
- init_score = pd_Series (init_score )
1996
- elif y_type != "numpy" :
1997
- raise ValueError (f"Unrecognized y_type: '{ y_type } '" )
1998
2013
1999
- if g_type == "list1d_float" :
2000
- g = g .astype ("float" ).tolist ()
2001
- elif g_type == "list1d_int" :
2002
- g = g .astype ("int" ).tolist ()
2003
- elif g_type == "pd_Series" :
2004
- g = pd_Series (g )
2005
- elif g_type != "numpy" :
2006
- raise ValueError (f"Unrecognized g_type: '{ g_type } '" )
2014
+ @pytest .mark .parametrize ("X_type" , all_x_types )
2015
+ @pytest .mark .parametrize ("y_type" , all_y_types )
2016
+ @pytest .mark .parametrize ("g_type" , all_group_types )
2017
+ def test_ranking_minimally_works_with_all_accepted_data_types (
2018
+ X_type ,
2019
+ y_type ,
2020
+ g_type ,
2021
+ rng ,
2022
+ ):
2023
+ if any (t .startswith ("pd_" ) for t in [X_type , y_type , g_type ]) and not PANDAS_INSTALLED :
2024
+ pytest .skip ("pandas is not installed" )
2025
+ if any (t .startswith ("pa_" ) for t in [X_type , y_type , g_type ]) and not PYARROW_INSTALLED :
2026
+ pytest .skip ("pyarrow is not installed" )
2007
2027
2008
- model = task_to_model_factory ["ranking" ](n_estimators = 10 , verbose = - 1 )
2009
- model .fit (
2010
- X = X ,
2011
- y = y ,
2012
- sample_weight = weights ,
2013
- init_score = init_score ,
2014
- group = g ,
2015
- eval_set = [(X_valid , y )],
2016
- eval_sample_weight = [weights ],
2017
- eval_init_score = [init_score ],
2018
- eval_group = [g ],
2019
- )
2020
- preds = model .predict (X )
2021
- assert spearmanr (preds , y ).correlation >= 0.99
2028
+ run_minimal_test (X_type = X_type , y_type = y_type , g_type = g_type , task = "ranking" , rng = rng )
2022
2029
2023
2030
2024
2031
def test_classifier_fit_detects_classes_every_time ():
0 commit comments