Skip to content

Commit 0c25a44

Browse files
jineetd (Jineet Desai) and Jineet Desai authored
Adding changes for Flaml Sklearn integration (#1361)
FLAML provides support for scikit-learn models such as Random Forest, K-Nearest Neighbors, Extra Trees Regressor, and Logistic Regression with regularization. We plan to integrate these ML models into EvaDB. FLAML documentation: https://microsoft.github.io/FLAML/docs/Use-Cases/Task-Oriented-AutoML --------- Co-authored-by: Jineet Desai <[email protected]>
1 parent 69b39b8 commit 0c25a44

File tree

7 files changed

+39
-42
lines changed

7 files changed

+39
-42
lines changed

evadb/configuration/constants.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,3 +36,5 @@
3636
DEFAULT_DOCUMENT_CHUNK_OVERLAP = 200
3737
DEFAULT_TRAIN_REGRESSION_METRIC = "rmse"
3838
DEFAULT_XGBOOST_TASK = "regression"
39+
DEFAULT_SKLEARN_TRAIN_MODEL = "rf"
40+
SKLEARN_SUPPORTED_MODELS = ["rf", "extra_tree", "kneighbor"]

evadb/executor/create_function_executor.py

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,11 @@
3030
from evadb.catalog.models.function_io_catalog import FunctionIOCatalogEntry
3131
from evadb.catalog.models.function_metadata_catalog import FunctionMetadataCatalogEntry
3232
from evadb.configuration.constants import (
33+
DEFAULT_SKLEARN_TRAIN_MODEL,
3334
DEFAULT_TRAIN_REGRESSION_METRIC,
3435
DEFAULT_TRAIN_TIME_LIMIT,
3536
DEFAULT_XGBOOST_TASK,
37+
SKLEARN_SUPPORTED_MODELS,
3638
EvaDB_INSTALLATION_DIR,
3739
)
3840
from evadb.database import EvaDBDatabase
@@ -45,13 +47,12 @@
4547
from evadb.utils.generic_utils import (
4648
load_function_class_from_file,
4749
string_comparison_case_insensitive,
50+
try_to_import_flaml_automl,
4851
try_to_import_ludwig,
4952
try_to_import_neuralforecast,
50-
try_to_import_sklearn,
5153
try_to_import_statsforecast,
5254
try_to_import_torch,
5355
try_to_import_ultralytics,
54-
try_to_import_xgboost,
5556
)
5657
from evadb.utils.logging_manager import logger
5758

@@ -169,8 +170,7 @@ def handle_sklearn_function(self):
169170
170171
Use Sklearn's regression to train models.
171172
"""
172-
try_to_import_sklearn()
173-
from sklearn.linear_model import LinearRegression
173+
try_to_import_flaml_automl()
174174

175175
assert (
176176
len(self.children) == 1
@@ -186,13 +186,26 @@ def handle_sklearn_function(self):
186186
aggregated_batch.drop_column_alias()
187187

188188
arg_map = {arg.key: arg.value for arg in self.node.metadata}
189-
model = LinearRegression()
190-
Y = aggregated_batch.frames[arg_map["predict"]]
191-
aggregated_batch.frames.drop([arg_map["predict"]], axis=1, inplace=True)
189+
from flaml import AutoML
190+
191+
model = AutoML()
192+
sklearn_model = arg_map.get("model", DEFAULT_SKLEARN_TRAIN_MODEL)
193+
if sklearn_model not in SKLEARN_SUPPORTED_MODELS:
194+
raise ValueError(
195+
f"Sklearn Model {sklearn_model} provided as input is not supported."
196+
)
197+
settings = {
198+
"time_budget": arg_map.get("time_limit", DEFAULT_TRAIN_TIME_LIMIT),
199+
"metric": arg_map.get("metric", DEFAULT_TRAIN_REGRESSION_METRIC),
200+
"estimator_list": [sklearn_model],
201+
"task": arg_map.get("task", DEFAULT_XGBOOST_TASK),
202+
}
192203
start_time = int(time.time())
193-
model.fit(X=aggregated_batch.frames, y=Y)
204+
model.fit(
205+
dataframe=aggregated_batch.frames, label=arg_map["predict"], **settings
206+
)
194207
train_time = int(time.time()) - start_time
195-
score = model.score(X=aggregated_batch.frames, y=Y)
208+
score = model.best_loss
196209
model_path = os.path.join(
197210
self.db.catalog().get_configuration_catalog_value("model_dir"),
198211
self.node.name,
@@ -232,7 +245,7 @@ def handle_xgboost_function(self):
232245
233246
We use the Flaml AutoML model for training xgboost models.
234247
"""
235-
try_to_import_xgboost()
248+
try_to_import_flaml_automl()
236249

237250
assert (
238251
len(self.children) == 1

evadb/functions/sklearn.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import pandas as pd
1818

1919
from evadb.functions.abstract.abstract_function import AbstractFunction
20-
from evadb.utils.generic_utils import try_to_import_sklearn
20+
from evadb.utils.generic_utils import try_to_import_flaml_automl
2121

2222

2323
class GenericSklearnModel(AbstractFunction):
@@ -26,7 +26,7 @@ def name(self) -> str:
2626
return "GenericSklearnModel"
2727

2828
def setup(self, model_path: str, predict_col: str, **kwargs):
29-
try_to_import_sklearn()
29+
try_to_import_flaml_automl()
3030

3131
self.model = pickle.load(open(model_path, "rb"))
3232
self.predict_col = predict_col

evadb/functions/xgboost.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import pandas as pd
1818

1919
from evadb.functions.abstract.abstract_function import AbstractFunction
20-
from evadb.utils.generic_utils import try_to_import_xgboost
20+
from evadb.utils.generic_utils import try_to_import_flaml_automl
2121

2222

2323
class GenericXGBoostModel(AbstractFunction):
@@ -26,7 +26,7 @@ def name(self) -> str:
2626
return "GenericXGBoostModel"
2727

2828
def setup(self, model_path: str, predict_col: str, **kwargs):
29-
try_to_import_xgboost()
29+
try_to_import_flaml_automl()
3030

3131
self.model = pickle.load(open(model_path, "rb"))
3232
self.predict_col = predict_col

evadb/utils/generic_utils.py

Lines changed: 4 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -369,39 +369,20 @@ def is_forecast_available() -> bool:
369369
return False
370370

371371

372-
def try_to_import_sklearn():
373-
try:
374-
import sklearn # noqa: F401
375-
from sklearn.linear_model import LinearRegression # noqa: F401
376-
except ImportError:
377-
raise ValueError(
378-
"""Could not import sklearn.
379-
Please install it with `pip install scikit-learn`."""
380-
)
381-
382-
383-
def is_sklearn_available() -> bool:
384-
try:
385-
try_to_import_sklearn()
386-
return True
387-
except ValueError: # noqa: E722
388-
return False
389-
390-
391-
def try_to_import_xgboost():
372+
def try_to_import_flaml_automl():
392373
try:
393374
import flaml # noqa: F401
394375
from flaml import AutoML # noqa: F401
395376
except ImportError:
396377
raise ValueError(
397-
"""Could not import Flaml AutoML.
378+
"""Could not import Flaml AutoML.
398379
Please install it with `pip install "flaml[automl]"`."""
399380
)
400381

401382

402-
def is_xgboost_available() -> bool:
383+
def is_flaml_automl_available() -> bool:
403384
try:
404-
try_to_import_xgboost()
385+
try_to_import_flaml_automl()
405386
return True
406387
except ValueError: # noqa: E722
407388
return False

test/integration_tests/long/test_model_train.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,9 @@ def test_sklearn_regression(self):
116116
CREATE OR REPLACE FUNCTION PredictHouseRentSklearn FROM
117117
( SELECT number_of_rooms, number_of_bathrooms, days_on_market, rental_price FROM HomeRentals )
118118
TYPE Sklearn
119-
PREDICT 'rental_price';
119+
PREDICT 'rental_price'
120+
MODEL 'extra_tree'
121+
METRIC 'r2';
120122
"""
121123
execute_query_fetch_all(self.evadb, create_predict_function)
122124

test/markers.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,14 @@
2020

2121
from evadb.utils.generic_utils import (
2222
is_chromadb_available,
23+
is_flaml_automl_available,
2324
is_forecast_available,
2425
is_gpu_available,
2526
is_ludwig_available,
2627
is_milvus_available,
2728
is_pinecone_available,
2829
is_qdrant_available,
2930
is_replicate_available,
30-
is_sklearn_available,
31-
is_xgboost_available,
3231
)
3332

3433
asyncio_skip_marker = pytest.mark.skipif(
@@ -93,11 +92,11 @@
9392
)
9493

9594
sklearn_skip_marker = pytest.mark.skipif(
96-
is_sklearn_available() is False, reason="Run only if sklearn is available"
95+
is_flaml_automl_available() is False, reason="Run only if Flaml AutoML is available"
9796
)
9897

9998
xgboost_skip_marker = pytest.mark.skipif(
100-
is_xgboost_available() is False, reason="Run only if xgboost is available"
99+
is_flaml_automl_available() is False, reason="Run only if Flaml AutoML is available"
101100
)
102101

103102
chatgpt_skip_marker = pytest.mark.skip(

0 commit comments

Comments
 (0)