diff --git a/examples/Serving-Ranking-Models-With-Merlin-Systems.ipynb b/examples/Serving-Ranking-Models-With-Merlin-Systems.ipynb index 9d5f2394f..3a948a2cd 100644 --- a/examples/Serving-Ranking-Models-With-Merlin-Systems.ipynb +++ b/examples/Serving-Ranking-Models-With-Merlin-Systems.ipynb @@ -1049,9 +1049,11 @@ "\n", "inputs = convert_df_to_triton_input(workflow.input_schema.column_names, batch, grpcclient.InferInput)\n", "\n", + "output_cols = ensemble.graph.output_schema.column_names\n", + "\n", "outputs = [\n", " grpcclient.InferRequestedOutput(col)\n", - " for col in ensemble.graph.output_schema.column_names\n", + " for col in output_cols\n", "]" ] }, diff --git a/tests/unit/examples/test_serving_ranking_models_with_merlin_systems.py b/tests/unit/examples/test_serving_ranking_models_with_merlin_systems.py index 405f04710..aba007229 100644 --- a/tests/unit/examples/test_serving_ranking_models_with_merlin_systems.py +++ b/tests/unit/examples/test_serving_ranking_models_with_merlin_systems.py @@ -84,20 +84,24 @@ def test_example_04_exporting_ranking_models(tb): NUM_OF_CELLS = len(tb.cells) tb.execute_cell(list(range(0, NUM_OF_CELLS - 12))) tb.execute_cell(list(range(NUM_OF_CELLS - 9, NUM_OF_CELLS - 6))) - tb.inject( - """ - import shutil - from merlin.models.loader.tf_utils import configure_tensorflow - configure_tensorflow() - from merlin.systems.triton.utils import run_ensemble_on_tritonserver - outputs = ensemble.graph.output_schema.column_names - response = run_ensemble_on_tritonserver( - "/tmp/data/ensemble/", outputs, batch, "ensemble_model" - ) - response = [x.tolist()[0] for x in response["click/binary_classification_task"]] - #shutil.rmtree("/tmp/data/", ignore_errors=True) - """ + from merlin.core.dispatch import get_lib + + df_lib = get_lib() + + # original_data_path = os.environ.get("INPUT_FOLDER", "/workspace/data/") + + # read in data for request + batch = df_lib.read_parquet( + os.path.join("/tmp/data/", "valid", "part.0.parquet"), + num_rows=3, + columns=workflow.input_schema.column_names, ) - tb.execute_cell(NUM_OF_CELLS - 6) - response = tb.ref("response") - assert len(response) == 3 + batch = batch.drop(columns="click") + outputs = tb.ref("output_cols") + from merlin.models.loader.tf_utils import configure_tensorflow + + configure_tensorflow() + from merlin.systems.triton.utils import run_ensemble_on_tritonserver + + response = run_ensemble_on_tritonserver("/tmp/data/ensemble/", outputs, batch, "ensemble_model") + assert len(response["click/binary_classification_task"]) == 3