Skip to content

Commit

Permalink
Move ensemble treelite test from test_ensemble to test_forest
Browse files Browse the repository at this point in the history
  • Loading branch information
oliverholworthy committed Jun 23, 2022
1 parent d1785a8 commit 078e196
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 39 deletions.
34 changes: 34 additions & 0 deletions tests/unit/systems/fil/test_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from merlin.systems.dag.ops.workflow import TransformWorkflow
from nvtabular import Workflow
from nvtabular import ops as wf_ops
from tests.unit.systems.utils.triton import _run_ensemble_on_tritonserver # noqa

tritonclient = pytest.importorskip("tritonclient")
import tritonclient.grpc.model_config_pb2 as model_config # noqa
Expand Down Expand Up @@ -145,3 +146,36 @@ def test_ensemble(tmpdir):
parsed_config = read_config(config_path)
assert parsed_config.name == "ensemble_model"
assert parsed_config.platform == "ensemble"


def test_fil_treelite_ensemble(tmpdir):
rows = 200
num_features = 16
X, y = sklearn.datasets.make_regression(
n_samples=rows,
n_features=num_features,
n_informative=num_features // 3,
random_state=0,
)
feature_names = [str(i) for i in range(num_features)]
df = pd.DataFrame(X, columns=feature_names, dtype=np.float32)

# Fit RF
model = sklearn.ensemble.RandomForestRegressor()
model.fit(X, y)

input_column_schemas = [ColumnSchema(col, dtype=np.float32) for col in feature_names]
input_schema = Schema(input_column_schemas)
selector = ColumnSelector(feature_names)

triton_chain = selector >> PredictForest(model, input_schema)

triton_ens = Ensemble(triton_chain, input_schema)

request_df = df[:5]
triton_ens.export(tmpdir)

response = _run_ensemble_on_tritonserver(
str(tmpdir), ["output__0"], request_df, triton_ens.name
)
assert response.as_numpy("output__0").shape == (5,)
40 changes: 1 addition & 39 deletions tests/unit/systems/test_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,7 @@
import os
from distutils.spawn import find_executable

import numpy as np
import pandas as pd
import pytest
import sklearn.datasets
import sklearn.ensemble

os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"

Expand All @@ -31,8 +27,7 @@
from merlin.dag.node import postorder_iter_nodes # noqa
from merlin.dag.ops.concat_columns import ConcatColumns # noqa
from merlin.dag.ops.selection import SelectionOp # noqa
from merlin.schema import ColumnSchema, Schema, Tags # noqa
from merlin.systems.dag.ops.fil import PredictForest # noqa
from merlin.schema import Tags # noqa
from nvtabular import Workflow # noqa
from nvtabular import ops as wf_ops # noqa

Expand Down Expand Up @@ -181,36 +176,3 @@ def test_graph_traverse_algo():
assert len(ordered_list) == 5
assert isinstance(ordered_list[0].op, SelectionOp)
assert isinstance(ordered_list[-1].op, ConcatColumns)


def test_fil_treelite_ensemble(tmpdir):
rows = 200
num_features = 16
X, y = sklearn.datasets.make_regression(
n_samples=rows,
n_features=num_features,
n_informative=num_features // 3,
random_state=0,
)
feature_names = [str(i) for i in range(num_features)]
df = pd.DataFrame(X, columns=feature_names, dtype=np.float32)

# Fit RF
model = sklearn.ensemble.RandomForestRegressor()
model.fit(X, y)

input_column_schemas = [ColumnSchema(col, dtype=np.float32) for col in feature_names]
input_schema = Schema(input_column_schemas)
selector = ColumnSelector(feature_names)

triton_chain = selector >> PredictForest(model, input_schema)

triton_ens = Ensemble(triton_chain, input_schema)

request_df = df[:5]
triton_ens.export(tmpdir)

response = _run_ensemble_on_tritonserver(
str(tmpdir), ["output__0"], request_df, triton_ens.name
)
assert response.as_numpy("output__0").shape == (5,)

0 comments on commit 078e196

Please sign in to comment.