From 8b062252adf1aca76631c7a3d5465d7532ac66ae Mon Sep 17 00:00:00 2001 From: Li Liu Date: Tue, 15 Nov 2022 17:33:34 +0800 Subject: [PATCH] Optimize some low efficient code Signed-off-by: Li Liu --- pymilvus/client/check.py | 140 +++++++++++--------------------- pymilvus/client/grpc_handler.py | 40 ++++----- pymilvus/client/prepare.py | 3 +- pymilvus/orm/collection.py | 8 +- pymilvus/orm/partition.py | 2 +- tests/test_prepare.py | 2 +- 6 files changed, 70 insertions(+), 125 deletions(-) diff --git a/pymilvus/client/check.py b/pymilvus/client/check.py index 5c923cace..959318c49 100644 --- a/pymilvus/client/check.py +++ b/pymilvus/client/check.py @@ -3,6 +3,7 @@ from typing import Any, Union from ..exceptions import ParamError from ..grpc_gen import milvus_pb2 as milvus_types +from .singleton_utils import Singleton from .utils import ( valid_index_types, valid_binary_index_types, @@ -308,103 +309,56 @@ def is_legal_operate_privilege_type(operate_privilege_type: Any) -> bool: (milvus_types.OperatePrivilegeType.Grant, milvus_types.OperatePrivilegeType.Revoke) +class ParamChecker(metaclass=Singleton): + def __init__(self) -> None: + self.check_dict = { + "collection_name": is_legal_table_name, + "field_name": is_legal_field_name, + "dimension": is_legal_dimension, + "index_file_size": is_legal_index_size, + "topk": is_legal_topk, + "ids": is_legal_ids, + "nprobe": is_legal_nprobe, + "nlist": is_legal_nlist, + "cmd": is_legal_cmd, + "partition_name": is_legal_partition_name, + "partition_name_array": is_legal_partition_name_array, + "limit": is_legal_limit, + "anns_field": is_legal_anns_field, + "search_data": is_legal_search_data, + "output_fields": is_legal_output_fields, + "round_decimal": is_legal_round_decimal, + "travel_timestamp": is_legal_travel_timestamp, + "guarantee_timestamp": is_legal_guarantee_timestamp, + "user": is_legal_user, + "password": is_legal_password, + "role_name": is_legal_role_name, + "operate_user_role_type": is_legal_operate_user_role_type, + "include_user_info": is_legal_include_user_info, + "include_role_info": is_legal_include_role_info, + "object": is_legal_object, + "object_name": is_legal_object_name, + "privilege": is_legal_privilege, + "operate_privilege_type": is_legal_operate_privilege_type, + "properties": is_legal_collection_properties, + } + + def check(self, key, value): + if key in self.check_dict: + if not self.check_dict[key](value): + _raise_param_error(key, value) + else: + raise ParamError(message=f"unknown param `{key}`") + +def _get_param_checker(): + return ParamChecker() + def check_pass_param(*_args: Any, **kwargs: Any) -> None: # pylint: disable=too-many-statements if kwargs is None: raise ParamError(message="Param should not be None") - + checker = _get_param_checker() for key, value in kwargs.items(): - if key in ("collection_name",): - if not is_legal_table_name(value): - _raise_param_error(key, value) - elif key == "field_name": - if not is_legal_field_name(value): - _raise_param_error(key, value) - elif key == "dimension": - if not is_legal_dimension(value): - _raise_param_error(key, value) - elif key == "index_file_size": - if not is_legal_index_size(value): - _raise_param_error(key, value) - elif key in ("topk", "top_k"): - if not is_legal_topk(value): - _raise_param_error(key, value) - elif key in ("ids",): - if not is_legal_ids(value): - _raise_param_error(key, value) - elif key in ("nprobe",): - if not is_legal_nprobe(value): - _raise_param_error(key, value) - elif key in ("nlist",): - if not is_legal_nlist(value): - _raise_param_error(key, value) - elif key in ("cmd",): - if not is_legal_cmd(value): - _raise_param_error(key, value) - elif key in ("partition_name",): - if not is_legal_partition_name(value): - _raise_param_error(key, value) - elif key in ("partition_name_array",): - if not is_legal_partition_name_array(value): - _raise_param_error(key, value) - elif key in ("limit",): - if not is_legal_limit(value): - _raise_param_error(key, value) - elif key in ("anns_field",): - if not is_legal_anns_field(value): - _raise_param_error(key, value) - elif key in ("search_data",): - if not is_legal_search_data(value): - _raise_param_error(key, value) - elif key in ("output_fields",): - if not is_legal_output_fields(value): - _raise_param_error(key, value) - elif key in ("round_decimal",): - if not is_legal_round_decimal(value): - _raise_param_error(key, value) - elif key in ("travel_timestamp",): - if not is_legal_travel_timestamp(value): - _raise_param_error(key, value) - elif key in ("guarantee_timestamp",): - if not is_legal_guarantee_timestamp(value): - _raise_param_error(key, value) - elif key in ("user",): - if not is_legal_user(value): - _raise_param_error(key, value) - elif key in ("password",): - if not is_legal_password(value): - _raise_param_error(key, value) - # elif key in ("records",): - # if not is_legal_records(value): - # _raise_param_error(key, value) - elif key in ("role_name",): - if not is_legal_role_name(value): - _raise_param_error(key, value) - elif key in ("operate_user_role_type",): - if not is_legal_operate_user_role_type(value): - _raise_param_error(key, value) - elif key in ("include_user_info",): - if not is_legal_include_user_info(value): - _raise_param_error(key, value) - elif key in ("include_role_info",): - if not is_legal_include_role_info(value): - _raise_param_error(key, value) - elif key in ("object",): - if not is_legal_object(value): - _raise_param_error(key, value) - elif key in ("object_name",): - if not is_legal_object_name(value): - _raise_param_error(key, value) - elif key in ("privilege",): - if not is_legal_privilege(value): - _raise_param_error(key, value) - elif key in ("operate_privilege_type",): - if not is_legal_operate_privilege_type(value): - _raise_param_error(key, value) - elif key == "properties": - if not is_legal_collection_properties(value): - _raise_param_error(key, value) - else: - raise ParamError(message=f"unknown param `{key}`") + checker.check(key, value) def check_index_params(params): diff --git a/pymilvus/client/grpc_handler.py b/pymilvus/client/grpc_handler.py index 66232daaa..33f5356f4 100644 --- a/pymilvus/client/grpc_handler.py +++ b/pymilvus/client/grpc_handler.py @@ -413,21 +413,17 @@ def _execute_search_requests(self, requests, timeout=None, **kwargs): auto_id = kwargs.get("auto_id", True) try: - raws = [] - futures = [] - - # step 1: get future object - for request in requests: - ft = self._stub.Search.future(request, timeout=timeout) - futures.append(ft) - if kwargs.get("_async", False): + futures = [] + for request in requests: + ft = self._stub.Search.future(request, timeout=timeout) + futures.append(ft) func = kwargs.get("_callback", None) return ChunkedSearchFuture(futures, func, auto_id) - # step2: get results - for ft in futures: - response = ft.result() + raws = [] + for request in requests: + response = self._stub.Search(request, timeout=timeout) if response.status.error_code != 0: raise MilvusException(response.status.error_code, response.status.reason) @@ -444,7 +440,7 @@ def _execute_search_requests(self, requests, timeout=None, **kwargs): @retry_on_rpc_failure(retry_on_deadline=False) def search(self, collection_name, data, anns_field, param, limit, expression=None, partition_names=None, output_fields=None, - round_decimal=-1, timeout=None, **kwargs): + round_decimal=-1, timeout=None, collection_schema=None, **kwargs): check_pass_param( limit=limit, round_decimal=round_decimal, @@ -456,26 +452,20 @@ def search(self, collection_name, data, anns_field, param, limit, guarantee_timestamp=kwargs.get("guarantee_timestamp", 0) ) - _kwargs = copy.deepcopy(kwargs) - - collection_schema = kwargs.get("schema", None) if not collection_schema: collection_schema = self.describe_collection(collection_name, timeout=timeout, **kwargs) - auto_id = collection_schema["auto_id"] + consistency_level = collection_schema["consistency_level"] # overwrite the consistency level defined when user created the collection - consistency_level = get_consistency_level(_kwargs.get("consistency_level", consistency_level)) - _kwargs["schema"] = collection_schema + consistency_level = get_consistency_level(kwargs.get("consistency_level", consistency_level)) - ts_utils.construct_guarantee_ts(consistency_level, collection_name, _kwargs) + ts_utils.construct_guarantee_ts(consistency_level, collection_name, kwargs) - requests = Prepare.search_requests_with_expr(collection_name, data, anns_field, param, limit, expression, - partition_names, output_fields, round_decimal, **_kwargs) - _kwargs.pop("schema") - _kwargs["auto_id"] = auto_id - _kwargs["round_decimal"] = round_decimal + requests = Prepare.search_requests_with_expr(collection_name, data, anns_field, param, limit, collection_schema, + expression, partition_names, output_fields, round_decimal, **kwargs) - return self._execute_search_requests(requests, timeout, **_kwargs) + auto_id = collection_schema["auto_id"] + return self._execute_search_requests(requests, timeout, round_decimal=round_decimal, auto_id=auto_id, **kwargs) @retry_on_rpc_failure() def get_query_segment_info(self, collection_name, timeout=30, **kwargs): diff --git a/pymilvus/client/prepare.py b/pymilvus/client/prepare.py index 880610614..432e4f9f8 100644 --- a/pymilvus/client/prepare.py +++ b/pymilvus/client/prepare.py @@ -431,10 +431,9 @@ def extract_vectors_param(param, placeholders, names, round_decimal): return request @classmethod - def search_requests_with_expr(cls, collection_name, data, anns_field, param, limit, expr=None, partition_names=None, + def search_requests_with_expr(cls, collection_name, data, anns_field, param, limit, schema, expr=None, partition_names=None, output_fields=None, round_decimal=-1, **kwargs): # TODO Move this impl into server side - schema = kwargs.get("schema", None) fields_schema = schema.get("fields", None) # list fields_name_locs = {fields_schema[loc]["name"]: loc for loc in range(len(fields_schema))} diff --git a/pymilvus/orm/collection.py b/pymilvus/orm/collection.py index 21e8807cd..a3901223e 100644 --- a/pymilvus/orm/collection.py +++ b/pymilvus/orm/collection.py @@ -134,6 +134,9 @@ def __init__(self, name, schema=None, using="default", shards_num=2, **kwargs): else: raise SchemaNotReadyException(message=ExceptionsMessage.SchemaType) + self._schema_dict = self._schema.to_dict() + self._schema_dict["consistency_level"] = self._consistency_level + def __repr__(self): _dict = { 'name': self.name, @@ -653,10 +656,9 @@ def search(self, data, anns_field, param, limit, expr=None, partition_names=None raise DataTypeNotMatchException(message=ExceptionsMessage.ExprType % type(expr)) conn = self._get_connection() - schema_dict = self._schema.to_dict() - schema_dict["consistency_level"] = self._consistency_level res = conn.search(self._name, data, anns_field, param, limit, expr, - partition_names, output_fields, round_decimal, timeout=timeout, schema=schema_dict, **kwargs) + partition_names, output_fields, round_decimal, timeout=timeout, + collection_schema=self._schema_dict, **kwargs) if kwargs.get("_async", False): return SearchFuture(res) return SearchResult(res) diff --git a/pymilvus/orm/partition.py b/pymilvus/orm/partition.py index e57ebda87..c39a4ec8e 100644 --- a/pymilvus/orm/partition.py +++ b/pymilvus/orm/partition.py @@ -444,7 +444,7 @@ def search(self, data, anns_field, param, limit, schema_dict = self._schema.to_dict() schema_dict["consistency_level"] = self._consistency_level res = conn.search(self._collection.name, data, anns_field, param, limit, expr, [self._name], output_fields, - round_decimal=round_decimal, timeout=timeout, schema=schema_dict, **kwargs) + round_decimal=round_decimal, timeout=timeout, collection_schema=schema_dict, **kwargs) if kwargs.get("_async", False): return SearchFuture(res) return SearchResult(res) diff --git a/tests/test_prepare.py b/tests/test_prepare.py index 2d5ba1206..7265ade20 100644 --- a/tests/test_prepare.py +++ b/tests/test_prepare.py @@ -26,7 +26,7 @@ def test_search_requests_with_expr_offset(self): "offset": 10, } - ret = Prepare.search_requests_with_expr("name", data, "v", search_params, 100, schema=schema) + ret = Prepare.search_requests_with_expr("name", data, "v", search_params, 100, schema) offset_exists = False for p in ret[0].search_params: