From 966ad8650551e1984eddfb3dd05ee2855f72fb7a Mon Sep 17 00:00:00 2001
From: liliu-z <105927039+liliu-z@users.noreply.github.com>
Date: Wed, 16 Nov 2022 16:59:04 +0800
Subject: [PATCH] Optimize some low efficient code (#1223)
Signed-off-by: Li Liu
Signed-off-by: Li Liu
---
pymilvus/client/check.py | 140 +++++++++++---------------------
pymilvus/client/grpc_handler.py | 41 ++++------
pymilvus/client/prepare.py | 3 +-
pymilvus/orm/collection.py | 8 +-
pymilvus/orm/partition.py | 2 +-
tests/test_prepare.py | 2 +-
6 files changed, 71 insertions(+), 125 deletions(-)
diff --git a/pymilvus/client/check.py b/pymilvus/client/check.py
index 5c923cace..959318c49 100644
--- a/pymilvus/client/check.py
+++ b/pymilvus/client/check.py
@@ -3,6 +3,7 @@
from typing import Any, Union
from ..exceptions import ParamError
from ..grpc_gen import milvus_pb2 as milvus_types
+from .singleton_utils import Singleton
from .utils import (
valid_index_types,
valid_binary_index_types,
@@ -308,103 +309,56 @@ def is_legal_operate_privilege_type(operate_privilege_type: Any) -> bool:
(milvus_types.OperatePrivilegeType.Grant, milvus_types.OperatePrivilegeType.Revoke)
+class ParamChecker(metaclass=Singleton):
+ def __init__(self) -> None:
+ self.check_dict = {
+ "collection_name": is_legal_table_name,
+ "field_name": is_legal_field_name,
+ "dimension": is_legal_dimension,
+ "index_file_size": is_legal_index_size,
+ "topk": is_legal_topk,
+ "ids": is_legal_ids,
+ "nprobe": is_legal_nprobe,
+ "nlist": is_legal_nlist,
+ "cmd": is_legal_cmd,
+ "partition_name": is_legal_partition_name,
+ "partition_name_array": is_legal_partition_name_array,
+ "limit": is_legal_limit,
+ "anns_field": is_legal_anns_field,
+ "search_data": is_legal_search_data,
+ "output_fields": is_legal_output_fields,
+ "round_decimal": is_legal_round_decimal,
+ "travel_timestamp": is_legal_travel_timestamp,
+ "guarantee_timestamp": is_legal_guarantee_timestamp,
+ "user": is_legal_user,
+ "password": is_legal_password,
+ "role_name": is_legal_role_name,
+ "operate_user_role_type": is_legal_operate_user_role_type,
+ "include_user_info": is_legal_include_user_info,
+ "include_role_info": is_legal_include_role_info,
+ "object": is_legal_object,
+ "object_name": is_legal_object_name,
+ "privilege": is_legal_privilege,
+ "operate_privilege_type": is_legal_operate_privilege_type,
+ "properties": is_legal_collection_properties,
+ }
+
+ def check(self, key, value):
+ if key in self.check_dict:
+ if not self.check_dict[key](value):
+ _raise_param_error(key, value)
+ else:
+ raise ParamError(message=f"unknown param `{key}`")
+
+def _get_param_checker():
+ return ParamChecker()
+
def check_pass_param(*_args: Any, **kwargs: Any) -> None: # pylint: disable=too-many-statements
if kwargs is None:
raise ParamError(message="Param should not be None")
-
+ checker = _get_param_checker()
for key, value in kwargs.items():
- if key in ("collection_name",):
- if not is_legal_table_name(value):
- _raise_param_error(key, value)
- elif key == "field_name":
- if not is_legal_field_name(value):
- _raise_param_error(key, value)
- elif key == "dimension":
- if not is_legal_dimension(value):
- _raise_param_error(key, value)
- elif key == "index_file_size":
- if not is_legal_index_size(value):
- _raise_param_error(key, value)
- elif key in ("topk", "top_k"):
- if not is_legal_topk(value):
- _raise_param_error(key, value)
- elif key in ("ids",):
- if not is_legal_ids(value):
- _raise_param_error(key, value)
- elif key in ("nprobe",):
- if not is_legal_nprobe(value):
- _raise_param_error(key, value)
- elif key in ("nlist",):
- if not is_legal_nlist(value):
- _raise_param_error(key, value)
- elif key in ("cmd",):
- if not is_legal_cmd(value):
- _raise_param_error(key, value)
- elif key in ("partition_name",):
- if not is_legal_partition_name(value):
- _raise_param_error(key, value)
- elif key in ("partition_name_array",):
- if not is_legal_partition_name_array(value):
- _raise_param_error(key, value)
- elif key in ("limit",):
- if not is_legal_limit(value):
- _raise_param_error(key, value)
- elif key in ("anns_field",):
- if not is_legal_anns_field(value):
- _raise_param_error(key, value)
- elif key in ("search_data",):
- if not is_legal_search_data(value):
- _raise_param_error(key, value)
- elif key in ("output_fields",):
- if not is_legal_output_fields(value):
- _raise_param_error(key, value)
- elif key in ("round_decimal",):
- if not is_legal_round_decimal(value):
- _raise_param_error(key, value)
- elif key in ("travel_timestamp",):
- if not is_legal_travel_timestamp(value):
- _raise_param_error(key, value)
- elif key in ("guarantee_timestamp",):
- if not is_legal_guarantee_timestamp(value):
- _raise_param_error(key, value)
- elif key in ("user",):
- if not is_legal_user(value):
- _raise_param_error(key, value)
- elif key in ("password",):
- if not is_legal_password(value):
- _raise_param_error(key, value)
- # elif key in ("records",):
- # if not is_legal_records(value):
- # _raise_param_error(key, value)
- elif key in ("role_name",):
- if not is_legal_role_name(value):
- _raise_param_error(key, value)
- elif key in ("operate_user_role_type",):
- if not is_legal_operate_user_role_type(value):
- _raise_param_error(key, value)
- elif key in ("include_user_info",):
- if not is_legal_include_user_info(value):
- _raise_param_error(key, value)
- elif key in ("include_role_info",):
- if not is_legal_include_role_info(value):
- _raise_param_error(key, value)
- elif key in ("object",):
- if not is_legal_object(value):
- _raise_param_error(key, value)
- elif key in ("object_name",):
- if not is_legal_object_name(value):
- _raise_param_error(key, value)
- elif key in ("privilege",):
- if not is_legal_privilege(value):
- _raise_param_error(key, value)
- elif key in ("operate_privilege_type",):
- if not is_legal_operate_privilege_type(value):
- _raise_param_error(key, value)
- elif key == "properties":
- if not is_legal_collection_properties(value):
- _raise_param_error(key, value)
- else:
- raise ParamError(message=f"unknown param `{key}`")
+ checker.check(key, value)
def check_index_params(params):
diff --git a/pymilvus/client/grpc_handler.py b/pymilvus/client/grpc_handler.py
index 66232daaa..3eedd5a30 100644
--- a/pymilvus/client/grpc_handler.py
+++ b/pymilvus/client/grpc_handler.py
@@ -413,21 +413,17 @@ def _execute_search_requests(self, requests, timeout=None, **kwargs):
auto_id = kwargs.get("auto_id", True)
try:
- raws = []
- futures = []
-
- # step 1: get future object
- for request in requests:
- ft = self._stub.Search.future(request, timeout=timeout)
- futures.append(ft)
-
if kwargs.get("_async", False):
+ futures = []
+ for request in requests:
+ ft = self._stub.Search.future(request, timeout=timeout)
+ futures.append(ft)
func = kwargs.get("_callback", None)
return ChunkedSearchFuture(futures, func, auto_id)
- # step2: get results
- for ft in futures:
- response = ft.result()
+ raws = []
+ for request in requests:
+ response = self._stub.Search(request, timeout=timeout)
if response.status.error_code != 0:
raise MilvusException(response.status.error_code, response.status.reason)
@@ -444,7 +440,7 @@ def _execute_search_requests(self, requests, timeout=None, **kwargs):
@retry_on_rpc_failure(retry_on_deadline=False)
def search(self, collection_name, data, anns_field, param, limit,
expression=None, partition_names=None, output_fields=None,
- round_decimal=-1, timeout=None, **kwargs):
+ round_decimal=-1, timeout=None, collection_schema=None, **kwargs):
check_pass_param(
limit=limit,
round_decimal=round_decimal,
@@ -456,26 +452,21 @@ def search(self, collection_name, data, anns_field, param, limit,
guarantee_timestamp=kwargs.get("guarantee_timestamp", 0)
)
- _kwargs = copy.deepcopy(kwargs)
-
- collection_schema = kwargs.get("schema", None)
if not collection_schema:
collection_schema = self.describe_collection(collection_name, timeout=timeout, **kwargs)
- auto_id = collection_schema["auto_id"]
+
consistency_level = collection_schema["consistency_level"]
# overwrite the consistency level defined when user created the collection
- consistency_level = get_consistency_level(_kwargs.get("consistency_level", consistency_level))
- _kwargs["schema"] = collection_schema
+ consistency_level = get_consistency_level(kwargs.get("consistency_level", consistency_level))
- ts_utils.construct_guarantee_ts(consistency_level, collection_name, _kwargs)
+ ts_utils.construct_guarantee_ts(consistency_level, collection_name, kwargs)
- requests = Prepare.search_requests_with_expr(collection_name, data, anns_field, param, limit, expression,
- partition_names, output_fields, round_decimal, **_kwargs)
- _kwargs.pop("schema")
- _kwargs["auto_id"] = auto_id
- _kwargs["round_decimal"] = round_decimal
+ requests = Prepare.search_requests_with_expr(collection_name, data, anns_field, param, limit, collection_schema,
+ expression, partition_names, output_fields, round_decimal,
+ **kwargs)
- return self._execute_search_requests(requests, timeout, **_kwargs)
+ auto_id = collection_schema["auto_id"]
+ return self._execute_search_requests(requests, timeout, round_decimal=round_decimal, auto_id=auto_id, **kwargs)
@retry_on_rpc_failure()
def get_query_segment_info(self, collection_name, timeout=30, **kwargs):
diff --git a/pymilvus/client/prepare.py b/pymilvus/client/prepare.py
index 880610614..432e4f9f8 100644
--- a/pymilvus/client/prepare.py
+++ b/pymilvus/client/prepare.py
@@ -431,10 +431,9 @@ def extract_vectors_param(param, placeholders, names, round_decimal):
return request
@classmethod
- def search_requests_with_expr(cls, collection_name, data, anns_field, param, limit, expr=None, partition_names=None,
+ def search_requests_with_expr(cls, collection_name, data, anns_field, param, limit, schema, expr=None, partition_names=None,
output_fields=None, round_decimal=-1, **kwargs):
# TODO Move this impl into server side
- schema = kwargs.get("schema", None)
fields_schema = schema.get("fields", None) # list
fields_name_locs = {fields_schema[loc]["name"]: loc for loc in range(len(fields_schema))}
diff --git a/pymilvus/orm/collection.py b/pymilvus/orm/collection.py
index 21e8807cd..192abc996 100644
--- a/pymilvus/orm/collection.py
+++ b/pymilvus/orm/collection.py
@@ -134,6 +134,9 @@ def __init__(self, name, schema=None, using="default", shards_num=2, **kwargs):
else:
raise SchemaNotReadyException(message=ExceptionsMessage.SchemaType)
+ self._schema_dict = self._schema.to_dict()
+ self._schema_dict["consistency_level"] = self._consistency_level
+
def __repr__(self):
_dict = {
'name': self.name,
@@ -653,10 +656,9 @@ def search(self, data, anns_field, param, limit, expr=None, partition_names=None
raise DataTypeNotMatchException(message=ExceptionsMessage.ExprType % type(expr))
conn = self._get_connection()
- schema_dict = self._schema.to_dict()
- schema_dict["consistency_level"] = self._consistency_level
res = conn.search(self._name, data, anns_field, param, limit, expr,
- partition_names, output_fields, round_decimal, timeout=timeout, schema=schema_dict, **kwargs)
+ partition_names, output_fields, round_decimal, timeout=timeout,
+ collection_schema=self._schema_dict, **kwargs)
if kwargs.get("_async", False):
return SearchFuture(res)
return SearchResult(res)
diff --git a/pymilvus/orm/partition.py b/pymilvus/orm/partition.py
index e57ebda87..c39a4ec8e 100644
--- a/pymilvus/orm/partition.py
+++ b/pymilvus/orm/partition.py
@@ -444,7 +444,7 @@ def search(self, data, anns_field, param, limit,
schema_dict = self._schema.to_dict()
schema_dict["consistency_level"] = self._consistency_level
res = conn.search(self._collection.name, data, anns_field, param, limit, expr, [self._name], output_fields,
- round_decimal=round_decimal, timeout=timeout, schema=schema_dict, **kwargs)
+ round_decimal=round_decimal, timeout=timeout, collection_schema=schema_dict, **kwargs)
if kwargs.get("_async", False):
return SearchFuture(res)
return SearchResult(res)
diff --git a/tests/test_prepare.py b/tests/test_prepare.py
index 2d5ba1206..7265ade20 100644
--- a/tests/test_prepare.py
+++ b/tests/test_prepare.py
@@ -26,7 +26,7 @@ def test_search_requests_with_expr_offset(self):
"offset": 10,
}
- ret = Prepare.search_requests_with_expr("name", data, "v", search_params, 100, schema=schema)
+ ret = Prepare.search_requests_with_expr("name", data, "v", search_params, 100, schema)
offset_exists = False
for p in ret[0].search_params: