Skip to content

Commit

Permalink
Starting the change for XGBoost integration into EVADb.
Browse files Browse the repository at this point in the history
  • Loading branch information
Jineet Desai committed Oct 3, 2023
1 parent e59092d commit 3ff70a9
Show file tree
Hide file tree
Showing 5 changed files with 137 additions and 1 deletion.
4 changes: 3 additions & 1 deletion evadb/binder/statement_binder.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,9 @@ def _bind_create_function_statement(self, node: CreateFunctionStatement):
outputs.append(column)
else:
inputs.append(column)
elif string_comparison_case_insensitive(node.function_type, "sklearn"):
elif string_comparison_case_insensitive(
node.function_type, "sklearn"
) or string_comparison_case_insensitive(node.function_type, "XGBoost"):
assert (
"predict" in arg_map
), f"Creating {node.function_type} functions expects 'predict' metadata."
Expand Down
61 changes: 61 additions & 0 deletions evadb/executor/create_function_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@
from evadb.utils.generic_utils import (
load_function_class_from_file,
string_comparison_case_insensitive,
try_to_import_automl,
try_to_import_forecast,
try_to_import_ludwig,
try_to_import_neuralforecast,
try_to_import_sklearn,
Expand Down Expand Up @@ -163,6 +165,57 @@ def handle_sklearn_function(self):
self.node.metadata,
)

def handle_xgboost_function(self):
"""Handle xgboost functions
We use the Flaml AutoML model for training xgboost models.
"""
try_to_import_automl()

assert (
len(self.children) == 1
), "Create sklearn function expects 1 child, finds {}.".format(
len(self.children)
)

aggregated_batch_list = []
child = self.children[0]
for batch in child.exec():
aggregated_batch_list.append(batch)
aggregated_batch = Batch.concat(aggregated_batch_list, copy=False)
aggregated_batch.drop_column_alias()

arg_map = {arg.key: arg.value for arg in self.node.metadata}
from flaml import AutoML

model = AutoML()
settings = {
"time_budget": 120,
"metric": "r2",
"estimator_list": ["xgboost"],
"task": "regression",
}
model.fit(
dataframe=aggregated_batch.frames, label=arg_map["predict"], **settings
)
model_path = os.path.join(
self.db.config.get_value("storage", "model_dir"), self.node.name
)
pickle.dump(model, open(model_path, "wb"))
self.node.metadata.append(
FunctionMetadataCatalogEntry("model_path", model_path)
)

impl_path = Path(f"{self.function_dir}/xgboost.py").absolute().as_posix()
io_list = self._resolve_function_io(None)
return (
self.node.name,
impl_path,
self.node.function_type,
io_list,
self.node.metadata,
)

def handle_ultralytics_function(self):
"""Handle Ultralytics functions"""
try_to_import_ultralytics()
Expand Down Expand Up @@ -516,6 +569,14 @@ def exec(self, *args, **kwargs):
io_list,
metadata,
) = self.handle_sklearn_function()
elif string_comparison_case_insensitive(self.node.function_type, "XGBoost"):
(
name,
impl_path,
function_type,
io_list,
metadata,
) = self.handle_xgboost_function()
elif string_comparison_case_insensitive(self.node.function_type, "Forecasting"):
(
name,
Expand Down
47 changes: 47 additions & 0 deletions evadb/functions/xgboost.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# coding=utf-8
# Copyright 2018-2023 EvaDB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pickle

import pandas as pd

from evadb.functions.abstract.abstract_function import AbstractFunction
from evadb.utils.generic_utils import try_to_import_automl


class GenericXGBoostModel(AbstractFunction):
@property
def name(self) -> str:
return "GenericXGBoostModel"

def setup(self, model_path: str, **kwargs):
try_to_import_automl()

self.model = pickle.load(open(model_path, "rb"))

def forward(self, frames: pd.DataFrame) -> pd.DataFrame:
# Last column is the value to predict, hence don't pass that to the
# predict method.
predictions = self.model.predict(frames.iloc[:, :-1])
predict_df = pd.DataFrame(predictions)
# We need to rename the column of the output dataframe. For this we
# shall rename it to the column name same as that of the last column of
# frames. This is because the last column of frames corresponds to the
# variable we want to predict.
predict_df.rename(columns={0: frames.columns[-1]}, inplace=True)
return predict_df

def to_device(self, device: str):
# TODO figure out how to control the GPU for ludwig models
return self
10 changes: 10 additions & 0 deletions evadb/utils/generic_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,16 @@ def is_sklearn_available() -> bool:
except ValueError: # noqa: E722
return False

def try_to_import_automl():
try:
import flaml # noqa: F401
from flaml import AutoML # noqa: F401
except ImportError:
raise ValueError(
"""Could not import Flaml AutoML.
Please install it with `pip install "flaml[automl]"`."""
)


##############################
## VISION
Expand Down
16 changes: 16 additions & 0 deletions test/integration_tests/long/test_model_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,22 @@ def test_sklearn_regression(self):
self.assertEqual(len(result.columns), 1)
self.assertEqual(len(result), 10)

def test_xgboost_regression(self):
create_predict_function = """
CREATE FUNCTION IF NOT EXISTS PredictRent FROM
( SELECT number_of_rooms, number_of_bathrooms, days_on_market, rental_price FROM HomeRentals )
TYPE XGBoost
PREDICT 'rental_price';
"""
execute_query_fetch_all(self.evadb, create_predict_function)

predict_query = """
SELECT PredictRent(number_of_rooms, number_of_bathrooms, days_on_market, rental_price) FROM HomeRentals LIMIT 10;
"""
result = execute_query_fetch_all(self.evadb, predict_query)
self.assertEqual(len(result.columns), 1)
self.assertEqual(len(result), 10)


if __name__ == "__main__":
unittest.main()

0 comments on commit 3ff70a9

Please sign in to comment.