diff --git a/deploy/mlflow-triton-plugin/examples/onnx_float32_int32_int32/config.pbtxt b/deploy/mlflow-triton-plugin/examples/onnx_float32_int32_int32/config.pbtxt
index 74cc2a6c78..75ea016cfa 100755
--- a/deploy/mlflow-triton-plugin/examples/onnx_float32_int32_int32/config.pbtxt
+++ b/deploy/mlflow-triton-plugin/examples/onnx_float32_int32_int32/config.pbtxt
@@ -1,5 +1,5 @@
-# Copyright 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # Redistribution and use in source and binary forms, with or without
 # modification, are permitted provided that the following conditions
 # are met:
@@ -24,7 +24,6 @@
 # OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-name: "onnx_float32_int32_int32"
 platform: "onnxruntime_onnx"
 max_batch_size: 8
 version_policy: { latest { num_versions: 1 }}
diff --git a/qa/L0_mlflow/plugin_test.py b/qa/L0_mlflow/plugin_test.py
index bed892bb54..8dbf9d9146 100644
--- a/qa/L0_mlflow/plugin_test.py
+++ b/qa/L0_mlflow/plugin_test.py
@@ -42,31 +42,20 @@ class PluginTest(tu.TestResultCollector):
     def setUp(self):
         self.client_ = get_deploy_client('triton')

-    def test_onnx_flavor(self):
-        # Log the ONNX model to MLFlow
-        import mlflow.onnx
-        import onnx
-        model = onnx.load(
-            "./mlflow-triton-plugin/examples/onnx_float32_int32_int32/1/model.onnx"
-        )
-        # Use a different name to ensure the plugin operates on correct model
-        mlflow.onnx.log_model(model,
-                              "triton",
-                              registered_model_name="onnx_model")
-
+    def _validate_deployment(self, model_name):
         # create
-        self.client_.create_deployment("onnx_model",
-                                       "models:/onnx_model/1",
+        self.client_.create_deployment(model_name,
+                                       "models:/{}/1".format(model_name),
                                        flavor="onnx")

         # list
         deployment_list = self.client_.list_deployments()
         self.assertEqual(len(deployment_list), 1)
-        self.assertEqual(deployment_list[0]['name'], "onnx_model")
+        self.assertEqual(deployment_list[0]['name'], model_name)

         # get
-        deployment = self.client_.get_deployment("onnx_model")
-        self.assertEqual(deployment['name'], "onnx_model")
+        deployment = self.client_.get_deployment(model_name)
+        self.assertEqual(deployment['name'], model_name)

         # predict
         inputs = {}
@@ -75,7 +64,7 @@ def test_onnx_flavor(self):
         for key, value in input_json['inputs'].items():
             inputs[key] = np.array(value, dtype=np.float32)

-        output = self.client_.predict("onnx_model", inputs)
+        output = self.client_.predict(model_name, inputs)
         with open("./mlflow-triton-plugin/examples/expected_output.json",
                   "r") as f:
             output_json = json.load(f)
@@ -86,7 +75,43 @@ def test_onnx_flavor(self):
                 err_msg='Inference result is not correct')

         # delete
-        self.client_.delete_deployment("onnx_model")
+        self.client_.delete_deployment(model_name)
+
+    def test_onnx_flavor(self):
+        # Log the ONNX model to MLFlow
+        import mlflow.onnx
+        import onnx
+        model = onnx.load(
+            "./mlflow-triton-plugin/examples/onnx_float32_int32_int32/1/model.onnx"
+        )
+        # Use a different name to ensure the plugin operates on correct model
+        mlflow.onnx.log_model(model,
+                              "triton",
+                              registered_model_name="onnx_model")
+
+        self._validate_deployment("onnx_model")
+
+    def test_onnx_flavor_with_files(self):
+        # Log the ONNX model and additional Triton config file to MLFlow
+        import mlflow.onnx
+        import onnx
+        model = onnx.load(
+            "./mlflow-triton-plugin/examples/onnx_float32_int32_int32/1/model.onnx"
+        )
+        config_path = "./mlflow-triton-plugin/examples/onnx_float32_int32_int32/config.pbtxt"
+        # Use a different name to ensure the plugin operates on correct model
+        mlflow.onnx.log_model(model,
+                              "triton",
+                              registered_model_name="onnx_model_with_files")
+        mlflow.log_artifact(config_path, "triton")
+
+        self._validate_deployment("onnx_model_with_files")
+
+        # Check if the additional files are properly copied
+        import filecmp
+        self.assertTrue(
+            filecmp.cmp(config_path,
+                        "./models/onnx_model_with_files/config.pbtxt"))


 if __name__ == '__main__':
diff --git a/qa/L0_mlflow/test.sh b/qa/L0_mlflow/test.sh
index 13f04a8d1a..74c9348f1d 100644
--- a/qa/L0_mlflow/test.sh
+++ b/qa/L0_mlflow/test.sh
@@ -45,6 +45,14 @@ mkdir -p ./mlflow/artifacts

 pip install ./mlflow-triton-plugin/

+# Clear mlflow registered models if any
+python - << EOF
+from mlflow.tracking import MlflowClient
+c = MlflowClient()
+for m in c.list_registered_models():
+    c.delete_registered_model(m.name)
+EOF
+
 rm -rf ./models
 mkdir -p ./models
 SERVER=/opt/tritonserver/bin/tritonserver
@@ -137,6 +145,7 @@ if [ $CLI_RET -ne 0 ]; then
 fi
 set -e

+# ONNX flavor with Python package
 set +e
 PY_LOG=plugin_py.log
 PY_TEST=plugin_test.py
@@ -147,7 +156,7 @@ if [ $? -ne 0 ]; then
     echo -e "\n***\n*** Python Test Failed\n***"
     RET=1
 else
-    check_test_results $TEST_RESULT_FILE 1
+    check_test_results $TEST_RESULT_FILE 2
     if [ $? -ne 0 ]; then
         cat $PY_LOG
         echo -e "\n***\n*** Test Result Verification Failed\n***"