Skip to content

Commit

Permalink
Optimize some low efficient code
Browse files Browse the repository at this point in the history
Signed-off-by: Li Liu <[email protected]>
  • Loading branch information
liliu-z committed Nov 15, 2022
1 parent d544834 commit 8b06225
Show file tree
Hide file tree
Showing 6 changed files with 70 additions and 125 deletions.
140 changes: 47 additions & 93 deletions pymilvus/client/check.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Any, Union
from ..exceptions import ParamError
from ..grpc_gen import milvus_pb2 as milvus_types
from .singleton_utils import Singleton
from .utils import (
valid_index_types,
valid_binary_index_types,
Expand Down Expand Up @@ -308,103 +309,56 @@ def is_legal_operate_privilege_type(operate_privilege_type: Any) -> bool:
(milvus_types.OperatePrivilegeType.Grant, milvus_types.OperatePrivilegeType.Revoke)


class ParamChecker(metaclass=Singleton):
def __init__(self) -> None:
self.check_dict = {
"collection_name": is_legal_table_name,
"field_name": is_legal_field_name,
"dimension": is_legal_dimension,
"index_file_size": is_legal_index_size,
"topk": is_legal_topk,
"ids": is_legal_ids,
"nprobe": is_legal_nprobe,
"nlist": is_legal_nlist,
"cmd": is_legal_cmd,
"partition_name": is_legal_partition_name,
"partition_name_array": is_legal_partition_name_array,
"limit": is_legal_limit,
"anns_field": is_legal_anns_field,
"search_data": is_legal_search_data,
"output_fields": is_legal_output_fields,
"round_decimal": is_legal_round_decimal,
"travel_timestamp": is_legal_travel_timestamp,
"guarantee_timestamp": is_legal_guarantee_timestamp,
"user": is_legal_user,
"password": is_legal_password,
"role_name": is_legal_role_name,
"operate_user_role_type": is_legal_operate_user_role_type,
"include_user_info": is_legal_include_user_info,
"include_role_info": is_legal_include_role_info,
"object": is_legal_object,
"object_name": is_legal_object_name,
"privilege": is_legal_privilege,
"operate_privilege_type": is_legal_operate_privilege_type,
"properties": is_legal_collection_properties,
}

def check(self, key, value):
if key in self.check_dict:
if not self.check_dict[key](value):
_raise_param_error(key, value)
else:
raise ParamError(message=f"unknown param `{key}`")

def _get_param_checker():
return ParamChecker()

def check_pass_param(*_args: Any, **kwargs: Any) -> None: # pylint: disable=too-many-statements
if kwargs is None:
raise ParamError(message="Param should not be None")

checker = _get_param_checker()
for key, value in kwargs.items():
if key in ("collection_name",):
if not is_legal_table_name(value):
_raise_param_error(key, value)
elif key == "field_name":
if not is_legal_field_name(value):
_raise_param_error(key, value)
elif key == "dimension":
if not is_legal_dimension(value):
_raise_param_error(key, value)
elif key == "index_file_size":
if not is_legal_index_size(value):
_raise_param_error(key, value)
elif key in ("topk", "top_k"):
if not is_legal_topk(value):
_raise_param_error(key, value)
elif key in ("ids",):
if not is_legal_ids(value):
_raise_param_error(key, value)
elif key in ("nprobe",):
if not is_legal_nprobe(value):
_raise_param_error(key, value)
elif key in ("nlist",):
if not is_legal_nlist(value):
_raise_param_error(key, value)
elif key in ("cmd",):
if not is_legal_cmd(value):
_raise_param_error(key, value)
elif key in ("partition_name",):
if not is_legal_partition_name(value):
_raise_param_error(key, value)
elif key in ("partition_name_array",):
if not is_legal_partition_name_array(value):
_raise_param_error(key, value)
elif key in ("limit",):
if not is_legal_limit(value):
_raise_param_error(key, value)
elif key in ("anns_field",):
if not is_legal_anns_field(value):
_raise_param_error(key, value)
elif key in ("search_data",):
if not is_legal_search_data(value):
_raise_param_error(key, value)
elif key in ("output_fields",):
if not is_legal_output_fields(value):
_raise_param_error(key, value)
elif key in ("round_decimal",):
if not is_legal_round_decimal(value):
_raise_param_error(key, value)
elif key in ("travel_timestamp",):
if not is_legal_travel_timestamp(value):
_raise_param_error(key, value)
elif key in ("guarantee_timestamp",):
if not is_legal_guarantee_timestamp(value):
_raise_param_error(key, value)
elif key in ("user",):
if not is_legal_user(value):
_raise_param_error(key, value)
elif key in ("password",):
if not is_legal_password(value):
_raise_param_error(key, value)
# elif key in ("records",):
# if not is_legal_records(value):
# _raise_param_error(key, value)
elif key in ("role_name",):
if not is_legal_role_name(value):
_raise_param_error(key, value)
elif key in ("operate_user_role_type",):
if not is_legal_operate_user_role_type(value):
_raise_param_error(key, value)
elif key in ("include_user_info",):
if not is_legal_include_user_info(value):
_raise_param_error(key, value)
elif key in ("include_role_info",):
if not is_legal_include_role_info(value):
_raise_param_error(key, value)
elif key in ("object",):
if not is_legal_object(value):
_raise_param_error(key, value)
elif key in ("object_name",):
if not is_legal_object_name(value):
_raise_param_error(key, value)
elif key in ("privilege",):
if not is_legal_privilege(value):
_raise_param_error(key, value)
elif key in ("operate_privilege_type",):
if not is_legal_operate_privilege_type(value):
_raise_param_error(key, value)
elif key == "properties":
if not is_legal_collection_properties(value):
_raise_param_error(key, value)
else:
raise ParamError(message=f"unknown param `{key}`")
checker.check(key, value)


def check_index_params(params):
Expand Down
40 changes: 15 additions & 25 deletions pymilvus/client/grpc_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,21 +413,17 @@ def _execute_search_requests(self, requests, timeout=None, **kwargs):
auto_id = kwargs.get("auto_id", True)

try:
raws = []
futures = []

# step 1: get future object
for request in requests:
ft = self._stub.Search.future(request, timeout=timeout)
futures.append(ft)

if kwargs.get("_async", False):
futures = []
for request in requests:
ft = self._stub.Search.future(request, timeout=timeout)
futures.append(ft)
func = kwargs.get("_callback", None)
return ChunkedSearchFuture(futures, func, auto_id)

# step2: get results
for ft in futures:
response = ft.result()
raws = []
for request in requests:
response = self._stub.Search(request, timeout=timeout)

if response.status.error_code != 0:
raise MilvusException(response.status.error_code, response.status.reason)
Expand All @@ -444,7 +440,7 @@ def _execute_search_requests(self, requests, timeout=None, **kwargs):
@retry_on_rpc_failure(retry_on_deadline=False)
def search(self, collection_name, data, anns_field, param, limit,
expression=None, partition_names=None, output_fields=None,
round_decimal=-1, timeout=None, **kwargs):
round_decimal=-1, timeout=None, collection_schema=None, **kwargs):
check_pass_param(
limit=limit,
round_decimal=round_decimal,
Expand All @@ -456,26 +452,20 @@ def search(self, collection_name, data, anns_field, param, limit,
guarantee_timestamp=kwargs.get("guarantee_timestamp", 0)
)

_kwargs = copy.deepcopy(kwargs)

collection_schema = kwargs.get("schema", None)
if not collection_schema:
collection_schema = self.describe_collection(collection_name, timeout=timeout, **kwargs)
auto_id = collection_schema["auto_id"]

consistency_level = collection_schema["consistency_level"]
# overwrite the consistency level defined when user created the collection
consistency_level = get_consistency_level(_kwargs.get("consistency_level", consistency_level))
_kwargs["schema"] = collection_schema
consistency_level = get_consistency_level(kwargs.get("consistency_level", consistency_level))

ts_utils.construct_guarantee_ts(consistency_level, collection_name, _kwargs)
ts_utils.construct_guarantee_ts(consistency_level, collection_name, kwargs)

requests = Prepare.search_requests_with_expr(collection_name, data, anns_field, param, limit, expression,
partition_names, output_fields, round_decimal, **_kwargs)
_kwargs.pop("schema")
_kwargs["auto_id"] = auto_id
_kwargs["round_decimal"] = round_decimal
requests = Prepare.search_requests_with_expr(collection_name, data, anns_field, param, limit, collection_schema,
expression, partition_names, output_fields, round_decimal, **kwargs)

return self._execute_search_requests(requests, timeout, **_kwargs)
auto_id = collection_schema["auto_id"]
return self._execute_search_requests(requests, timeout, round_decimal=round_decimal, auto_id=auto_id, **kwargs)

@retry_on_rpc_failure()
def get_query_segment_info(self, collection_name, timeout=30, **kwargs):
Expand Down
3 changes: 1 addition & 2 deletions pymilvus/client/prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,10 +431,9 @@ def extract_vectors_param(param, placeholders, names, round_decimal):
return request

@classmethod
def search_requests_with_expr(cls, collection_name, data, anns_field, param, limit, expr=None, partition_names=None,
def search_requests_with_expr(cls, collection_name, data, anns_field, param, limit, schema, expr=None, partition_names=None,
output_fields=None, round_decimal=-1, **kwargs):
# TODO Move this impl into server side
schema = kwargs.get("schema", None)
fields_schema = schema.get("fields", None) # list
fields_name_locs = {fields_schema[loc]["name"]: loc for loc in range(len(fields_schema))}

Expand Down
8 changes: 5 additions & 3 deletions pymilvus/orm/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,9 @@ def __init__(self, name, schema=None, using="default", shards_num=2, **kwargs):
else:
raise SchemaNotReadyException(message=ExceptionsMessage.SchemaType)

self._schema_dict = self._schema.to_dict()
self._schema_dict["consistency_level"] = self._consistency_level

def __repr__(self):
_dict = {
'name': self.name,
Expand Down Expand Up @@ -653,10 +656,9 @@ def search(self, data, anns_field, param, limit, expr=None, partition_names=None
raise DataTypeNotMatchException(message=ExceptionsMessage.ExprType % type(expr))

conn = self._get_connection()
schema_dict = self._schema.to_dict()
schema_dict["consistency_level"] = self._consistency_level
res = conn.search(self._name, data, anns_field, param, limit, expr,
partition_names, output_fields, round_decimal, timeout=timeout, schema=schema_dict, **kwargs)
partition_names, output_fields, round_decimal, timeout=timeout,
collection_schema=self._schema_dict, **kwargs)
if kwargs.get("_async", False):
return SearchFuture(res)
return SearchResult(res)
Expand Down
2 changes: 1 addition & 1 deletion pymilvus/orm/partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,7 +444,7 @@ def search(self, data, anns_field, param, limit,
schema_dict = self._schema.to_dict()
schema_dict["consistency_level"] = self._consistency_level
res = conn.search(self._collection.name, data, anns_field, param, limit, expr, [self._name], output_fields,
round_decimal=round_decimal, timeout=timeout, schema=schema_dict, **kwargs)
round_decimal=round_decimal, timeout=timeout, collection_schema=schema_dict, **kwargs)
if kwargs.get("_async", False):
return SearchFuture(res)
return SearchResult(res)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def test_search_requests_with_expr_offset(self):
"offset": 10,
}

ret = Prepare.search_requests_with_expr("name", data, "v", search_params, 100, schema=schema)
ret = Prepare.search_requests_with_expr("name", data, "v", search_params, 100, schema)

offset_exists = False
for p in ret[0].search_params:
Expand Down

0 comments on commit 8b06225

Please sign in to comment.