@@ -135,7 +135,7 @@ def __init__(self, task="binary", **config):
135
135
self ._task = task if isinstance (task , Task ) else task_factory (task , None , None )
136
136
self .params = self .config2params (config )
137
137
self .estimator_class = self ._model = None
138
- if "_estimator_type" in config :
138
+ if "_estimator_type" in self . params :
139
139
self ._estimator_type = self .params .pop ("_estimator_type" )
140
140
else :
141
141
self ._estimator_type = "classifier" if self ._task .is_classification () else "regressor"
@@ -1696,7 +1696,7 @@ def config2params(self, config: dict) -> dict:
1696
1696
# use_label_encoder is deprecated in 1.7.
1697
1697
if xgboost_version < "1.7.0" :
1698
1698
params ["use_label_encoder" ] = params .get ("use_label_encoder" , False )
1699
- if "n_jobs" in config :
1699
+ if "n_jobs" in params :
1700
1700
params ["nthread" ] = params .pop ("n_jobs" )
1701
1701
return params
1702
1702
@@ -1896,7 +1896,7 @@ def config2params(self, config: dict) -> dict:
1896
1896
params = super ().config2params (config )
1897
1897
if "max_leaves" in params :
1898
1898
params ["max_leaf_nodes" ] = params .get ("max_leaf_nodes" , params .pop ("max_leaves" ))
1899
- if not self ._task .is_classification () and "criterion" in config :
1899
+ if not self ._task .is_classification () and "criterion" in params :
1900
1900
params .pop ("criterion" )
1901
1901
if "random_state" not in params :
1902
1902
params ["random_state" ] = 12032022
@@ -2349,7 +2349,7 @@ def config2params(self, config: dict) -> dict:
2349
2349
params ["loss" ] = params .get ("loss" , None )
2350
2350
if params ["loss" ] is None and self ._task .is_classification ():
2351
2351
params ["loss" ] = "log_loss" if SKLEARN_VERSION >= "1.1" else "log"
2352
- if not self ._task .is_classification ():
2352
+ if not self ._task .is_classification () and "n_jobs" in params :
2353
2353
params .pop ("n_jobs" )
2354
2354
2355
2355
if params .get ("penalty" ) != "elasticnet" :
@@ -2833,4 +2833,4 @@ def __exit__(self, *_):
2833
2833
os .dup2 (self .save_fds [1 ], 2 )
2834
2834
# Close the null files
2835
2835
os .close (self .null_fds [0 ])
2836
- os .close (self .null_fds [1 ])
2836
+ os .close (self .null_fds [1 ])
0 commit comments