From c233356fa1327eea572a77cacc9e24ab754f1144 Mon Sep 17 00:00:00 2001 From: Yaliang Wu Date: Wed, 13 Mar 2024 11:51:07 -0700 Subject: [PATCH] enhance connector helper notebook to support 2.9 (#2202) Signed-off-by: Yaliang Wu --- docs/tutorials/aws/AIConnectorHelper.ipynb | 56 +++++++++++++++++-- ...ntic_search_with_byte_quantized_vector.md} | 0 2 files changed, 51 insertions(+), 5 deletions(-) rename docs/tutorials/semantic_search/{conversation_search_with_byte_quantized_vector.md => semantic_search_with_byte_quantized_vector.md} (100%) diff --git a/docs/tutorials/aws/AIConnectorHelper.ipynb b/docs/tutorials/aws/AIConnectorHelper.ipynb index 713ba6c331..e215ee2215 100644 --- a/docs/tutorials/aws/AIConnectorHelper.ipynb +++ b/docs/tutorials/aws/AIConnectorHelper.ipynb @@ -15,7 +15,7 @@ "from requests_aws4auth import AWS4Auth\n", "import time\n", "\n", - "# This python code works for AWS OpenSearch 2.11\n", + "# This Python code is compatible with AWS OpenSearch versions 2.9 and higher.\n", "class AIConnectorHelper:\n", " \n", " def __init__(self, region, opensearch_domain_name, opensearch_domain_username, opensearch_domain_password, aws_user_name):\n", @@ -265,24 +265,70 @@ " connector_id = json.loads(r.text)['connector_id']\n", " return connector_id\n", " \n", + " def search_model_group(self, model_group_name):\n", + " payload = {\n", + " \"query\": {\n", + " \"term\": {\n", + " \"name.keyword\": {\n", + " \"value\": model_group_name\n", + " }\n", + " }\n", + " }\n", + " }\n", + " headers = {\"Content-Type\": \"application/json\"}\n", + " r = requests.post(f'{self.opensearch_domain_url}/_plugins/_ml/model_groups/_search',\n", + " auth=HTTPBasicAuth(self.opensearch_domain_username, self.opensearch_domain_opensearch_domain_password),\n", + " json=payload,\n", + " headers=headers)\n", + " #print(r.text)\n", + " response = json.loads(r.text)\n", + " return response\n", + " \n", + " def create_model_group(self, model_group_name, description):\n", + " 
search_model_group_response = self.search_model_group(model_group_name)\n", + " if search_model_group_response['hits']['total']['value'] > 0:\n", + " return search_model_group_response['hits']['hits'][0]['_id']\n", + " payload = {\n", + " \"name\": model_group_name,\n", + " \"description\": description\n", + " }\n", + " headers = {\"Content-Type\": \"application/json\"}\n", + " r = requests.post(f'{self.opensearch_domain_url}/_plugins/_ml/model_groups/_register',\n", + " auth=HTTPBasicAuth(self.opensearch_domain_username, self.opensearch_domain_opensearch_domain_password),\n", + " json=payload,\n", + " headers=headers)\n", + " print(r.text)\n", + " response = json.loads(r.text)\n", + " return response['model_group_id']\n", + " \n", + " def get_task(self, task_id):\n", + " return requests.get(f'{self.opensearch_domain_url}/_plugins/_ml/tasks/{task_id}',\n", + " auth=HTTPBasicAuth(self.opensearch_domain_username, self.opensearch_domain_opensearch_domain_password))\n", + " \n", " def create_model(self, model_name, description, connector_id, deploy=True):\n", + " model_group_id = self.create_model_group(model_name, description)\n", " payload = {\n", " \"name\": model_name,\n", " \"function_name\": \"remote\",\n", " \"description\": description,\n", + " \"model_group_id\": model_group_id,\n", " \"connector_id\": connector_id\n", " }\n", - "\n", " headers = {\"Content-Type\": \"application/json\"}\n", - "\n", " deploy_str = str(deploy).lower()\n", " r = requests.post(f'{self.opensearch_domain_url}/_plugins/_ml/models/_register?deploy={deploy_str}',\n", " auth=HTTPBasicAuth(self.opensearch_domain_username, self.opensearch_domain_opensearch_domain_password),\n", " json=payload,\n", " headers=headers)\n", " print(r.text)\n", - " model_id = json.loads(r.text)['model_id']\n", - " return model_id\n", + " response = json.loads(r.text)\n", + " if 'model_id' in response:\n", + " return response['model_id']\n", + " else:\n", + " time.sleep(2) # sleep two seconds for task 
to complete\n", + " r = self.get_task(response['task_id'])\n", + " print(r.text)\n", + " return json.loads(r.text)['model_id']\n", " \n", " def deploy_model(self, model_id):\n", " return requests.post(f'{self.opensearch_domain_url}/_plugins/_ml/models/{model_id}/_deploy',\n", diff --git a/docs/tutorials/semantic_search/conversation_search_with_byte_quantized_vector.md b/docs/tutorials/semantic_search/semantic_search_with_byte_quantized_vector.md similarity index 100% rename from docs/tutorials/semantic_search/conversation_search_with_byte_quantized_vector.md rename to docs/tutorials/semantic_search/semantic_search_with_byte_quantized_vector.md