Skip to content

Commit 2ba5f8b

Browse files
authored
Fix params pop error (#1408)
1 parent d0a1195 commit 2ba5f8b

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

flaml/automl/model.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ def __init__(self, task="binary", **config):
135135
self._task = task if isinstance(task, Task) else task_factory(task, None, None)
136136
self.params = self.config2params(config)
137137
self.estimator_class = self._model = None
138-
if "_estimator_type" in config:
138+
if "_estimator_type" in self.params:
139139
self._estimator_type = self.params.pop("_estimator_type")
140140
else:
141141
self._estimator_type = "classifier" if self._task.is_classification() else "regressor"
@@ -1696,7 +1696,7 @@ def config2params(self, config: dict) -> dict:
16961696
# use_label_encoder is deprecated in 1.7.
16971697
if xgboost_version < "1.7.0":
16981698
params["use_label_encoder"] = params.get("use_label_encoder", False)
1699-
if "n_jobs" in config:
1699+
if "n_jobs" in params:
17001700
params["nthread"] = params.pop("n_jobs")
17011701
return params
17021702

@@ -1896,7 +1896,7 @@ def config2params(self, config: dict) -> dict:
18961896
params = super().config2params(config)
18971897
if "max_leaves" in params:
18981898
params["max_leaf_nodes"] = params.get("max_leaf_nodes", params.pop("max_leaves"))
1899-
if not self._task.is_classification() and "criterion" in config:
1899+
if not self._task.is_classification() and "criterion" in params:
19001900
params.pop("criterion")
19011901
if "random_state" not in params:
19021902
params["random_state"] = 12032022
@@ -2349,7 +2349,7 @@ def config2params(self, config: dict) -> dict:
23492349
params["loss"] = params.get("loss", None)
23502350
if params["loss"] is None and self._task.is_classification():
23512351
params["loss"] = "log_loss" if SKLEARN_VERSION >= "1.1" else "log"
2352-
if not self._task.is_classification():
2352+
if not self._task.is_classification() and "n_jobs" in params:
23532353
params.pop("n_jobs")
23542354

23552355
if params.get("penalty") != "elasticnet":
@@ -2833,4 +2833,4 @@ def __exit__(self, *_):
28332833
os.dup2(self.save_fds[1], 2)
28342834
# Close the null files
28352835
os.close(self.null_fds[0])
2836-
os.close(self.null_fds[1])
2836+
os.close(self.null_fds[1])

0 commit comments

Comments
 (0)