From d65e3e1ef4d74f49a48f593a255bb95cf88af40c Mon Sep 17 00:00:00 2001 From: yimingl9 Date: Mon, 16 Oct 2023 11:22:18 -0400 Subject: [PATCH 1/2] add predict functionality Signed-off-by: yimingl9 --- .../ml_commons/ml_commons_client.py | 28 +++++++++ tests/ml_commons/test_ml_commons_client.py | 58 +++++++++++++++++++ 2 files changed, 86 insertions(+) diff --git a/opensearch_py_ml/ml_commons/ml_commons_client.py b/opensearch_py_ml/ml_commons/ml_commons_client.py index b4198d10d..0626f9b5b 100644 --- a/opensearch_py_ml/ml_commons/ml_commons_client.py +++ b/opensearch_py_ml/ml_commons/ml_commons_client.py @@ -494,6 +494,34 @@ def generate_embedding(self, model_id: str, sentences: List[str]) -> object: body=API_BODY, ) + def predict(self, model_id: str, algo_name: str, input_json): + + API_URL = f"{ML_BASE_URI}/_predict/{algo_name}/{model_id}/_deploy" + + if isinstance(input_json, str): + try: + json_obj = json.loads(input_json) + if not isinstance(json_obj, dict): + return "Invalid JSON object passed as argument." + API_BODY = json.dumps(json_obj) + except json.JSONDecodeError: + return "Invalid JSON string passed as argument." + elif isinstance(input_json, dict): + API_BODY = json.dumps(input_json) + else: + return "Invalid JSON object passed as argument." + + return self._client.transport.perform_request( + method="POST", + url=API_URL, + body=API_BODY, + ) + + + + + + @deprecated( reason="Since OpenSearch 2.7.0, you can use undeploy_model instead", version="2.7.0", diff --git a/tests/ml_commons/test_ml_commons_client.py b/tests/ml_commons/test_ml_commons_client.py index 10be2c164..2259dc042 100644 --- a/tests/ml_commons/test_ml_commons_client.py +++ b/tests/ml_commons/test_ml_commons_client.py @@ -241,6 +241,64 @@ def test_DEPRECATED_integration_pretrained_model_upload_unload_delete(): raised = True assert raised == False, "Raised Exception in deleting pretrained model" +def test_predict(): + input_json = { + { + "input_query": { + "_source": ["petal_length_in_cm", "petal_width_in_cm"], + "size": 10000 + }, + "input_index": [ + "iris_data" + ] + } + } + + raised = False + model_id = ml_client.register_pretrained_model( + model_name=PRETRAINED_MODEL_NAME, + model_version=PRETRAINED_MODEL_VERSION, + model_format=PRETRAINED_MODEL_FORMAT, + deploy_model=True, + wait_until_deployed=True, + ) + + try: + predict_obj = ml_client.predict( + model_id=model_id, algo_name="kmeans",input_json=input_json + ) + assert predict_obj["status"] == "COMPLETED" + except: # noqa: E722 + raised = True + assert raised == False, "Raised Exception in training and predicting task" + + raised = False + try: + predict_obj = ml_client.predict( + model_id=model_id, algo_name="something else",input_json=input_json + ) + assert predict_obj == "Invalid algorithm name passed as argument." + except: # noqa: E722 + raised = True + assert raised == False, "Raised Exception in training and predicting task" + + try: + predict_obj = ml_client.predict( + model_id=model_id, algo_name="something else",input_json="15" + ) + assert predict_obj == "Invalid JSON object passed as argument." + except: # noqa: E722 + raised = True + assert raised == False, "Raised Exception in training and predicting task" + + try: + predict_obj = ml_client.predict( + model_id=model_id, algo_name="something else",input_json=15 + ) + assert predict_obj == "Invalid JSON object passed as argument." + except: # noqa: E722 + raised = True + assert raised == False, "Raised Exception in training and predicting task" def test_integration_pretrained_model_register_undeploy_delete(): raised = False From 133fbbd004ac5a8e7bd46287df9e6b5ef7b07747 Mon Sep 17 00:00:00 2001 From: yimingl9 Date: Mon, 23 Oct 2023 21:43:29 -0400 Subject: [PATCH 2/2] Signed-off-by: yimingl9 Modify request Signed-off-by: yimingl9 --- opensearch_py_ml/ml_commons/ml_commons_client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/opensearch_py_ml/ml_commons/ml_commons_client.py b/opensearch_py_ml/ml_commons/ml_commons_client.py index 0626f9b5b..ab7bc6825 100644 --- a/opensearch_py_ml/ml_commons/ml_commons_client.py +++ b/opensearch_py_ml/ml_commons/ml_commons_client.py @@ -496,7 +496,7 @@ def generate_embedding(self, model_id: str, sentences: List[str]) -> object: def predict(self, model_id: str, algo_name: str, input_json): - API_URL = f"{ML_BASE_URI}/_predict/{algo_name}/{model_id}/_deploy" + API_URL = f"{ML_BASE_URI}/_predict/{algo_name}/{model_id}" if isinstance(input_json, str): try: