Skip to content

Commit

Permalink
Correct FIL export format for sklearn/cuml to treelite checkpoint (#124)
Browse files Browse the repository at this point in the history
* Save sklearn and cuml forest models in the correct format (treelite)

* Add test for exported fil model filenames

* Add treelite dependencies to fil workflow
  • Loading branch information
oliverholworthy authored Jun 23, 2022
1 parent 712b04d commit efbf7f6
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 7 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/fil.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ jobs:
- name: Install FIL pip dependencies
run: |
python -m pip install xgboost lightgbm sklearn
# The treelite version must match the version used in Triton
python -m pip install treelite==2.3.0 treelite_runtime==2.3.0
- name: Build
run: |
python setup.py develop
Expand Down
4 changes: 4 additions & 0 deletions merlin/systems/dag/ops/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@
import sklearn.ensemble as sklearn_ensemble
except ImportError:
sklearn_ensemble = None
try:
import treelite.sklearn as treelite_sklearn
except ImportError:
treelite_sklearn = None
try:
import lightgbm
except ImportError:
Expand Down
18 changes: 11 additions & 7 deletions merlin/systems/dag/ops/fil.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
#
import json
import pathlib
import pickle
from abc import ABC, abstractmethod

import numpy as np
Expand All @@ -29,6 +28,7 @@
lightgbm,
pb_utils,
sklearn_ensemble,
treelite_sklearn,
xgboost,
)
from merlin.systems.dag.ops.operator import (
Expand Down Expand Up @@ -484,13 +484,18 @@ class SKLearnRandomForest(FILModel):
"""Scikit-Learn RandomForest Wrapper for FIL."""

model_type = "treelite_checkpoint"
model_filename = "model.pkl"
model_filename = "checkpoint.tl"

def save(self, version_path):
    """Save model to version_path as a treelite checkpoint.

    Parameters
    ----------
    version_path : str or path-like
        Directory the model file is written into. The file name is taken
        from ``self.model_filename``.

    Raises
    ------
    RuntimeError
        If ``treelite_sklearn`` is ``None`` (the optional import failed).
    """
    # Check the optional dependency before touching the filesystem so a
    # failed export never leaves a file in the wrong format at model_path.
    if treelite_sklearn is None:
        raise RuntimeError(
            "Both 'treelite' and 'treelite_runtime' "
            "are required to save an sklearn random forest model."
        )
    model_path = pathlib.Path(version_path) / self.model_filename
    treelite_model = treelite_sklearn.import_model(self.model)
    treelite_model.serialize(str(model_path))

@property
def num_features(self):
Expand All @@ -512,13 +517,12 @@ def num_targets(self):
class CUMLRandomForest(FILModel):

model_type = "treelite_checkpoint"
model_filename = "model.pkl"
model_filename = "checkpoint.tl"

def save(self, version_path):
    """Save model to version_path as a treelite checkpoint.

    Parameters
    ----------
    version_path : str or path-like
        Directory the model file is written into. The file name is taken
        from ``self.model_filename``.
    """
    model_path = pathlib.Path(version_path) / self.model_filename
    # Export through the model's own treelite conversion; the file is
    # written once, directly in checkpoint format (no intermediate pickle).
    treelite_model = self.model.convert_to_treelite_model()
    treelite_model.to_treelite_checkpoint(str(model_path))

@property
def num_features(self):
Expand Down
17 changes: 17 additions & 0 deletions tests/unit/systems/fil/test_fil.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,23 @@ def test_regressor(get_model_fn, get_model_params, tmpdir):
assert config.output[0].dims == [1]


@pytest.mark.parametrize(
    ["get_model_fn", "expected_model_filename"],
    [
        (xgboost_regressor, "xgboost.json"),
        (lightgbm_regressor, "model.txt"),
        (sklearn_forest_regressor, "checkpoint.tl"),
    ],
)
def test_model_file(get_model_fn, expected_model_filename, tmpdir):
    """Check that exporting a FIL op creates the expected model file.

    The file is looked up at ``<tmpdir>/fil/1/<expected_model_filename>``.
    """
    features, targets = get_regression_data()
    fitted_model = get_model_fn(features, targets)
    # Only the files written to disk matter here; the export result is unused.
    export_op(tmpdir, fil_op.FIL(fitted_model))
    version_dir = pathlib.Path(tmpdir, "fil", "1")
    assert (version_dir / expected_model_filename).is_file()


def test_fil_op_exports_own_config(tmpdir):
X, y = get_regression_data()
model = xgboost_train(X, y, objective="reg:squarederror")
Expand Down

0 comments on commit efbf7f6

Please sign in to comment.