Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize some low efficient code #1223

Merged
merged 1 commit into from
Nov 16, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
140 changes: 47 additions & 93 deletions pymilvus/client/check.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Any, Union
from ..exceptions import ParamError
from ..grpc_gen import milvus_pb2 as milvus_types
from .singleton_utils import Singleton
from .utils import (
valid_index_types,
valid_binary_index_types,
Expand Down Expand Up @@ -308,103 +309,56 @@ def is_legal_operate_privilege_type(operate_privilege_type: Any) -> bool:
(milvus_types.OperatePrivilegeType.Grant, milvus_types.OperatePrivilegeType.Revoke)


class ParamChecker(metaclass=Singleton):
    """Table-driven validator that maps a parameter name to its predicate.

    The class is declared with the ``Singleton`` metaclass, so the lookup
    table is presumably built once and shared by every caller (confirm
    against ``singleton_utils``).
    """

    def __init__(self) -> None:
        # Parameter name -> predicate returning True when the value is legal.
        self.check_dict = {
            "collection_name": is_legal_table_name,
            "field_name": is_legal_field_name,
            "dimension": is_legal_dimension,
            "index_file_size": is_legal_index_size,
            "topk": is_legal_topk,
            # The if/elif chain this table replaces accepted both spellings;
            # keep the alias so callers passing `top_k=` are not rejected as
            # an unknown param.
            "top_k": is_legal_topk,
            "ids": is_legal_ids,
            "nprobe": is_legal_nprobe,
            "nlist": is_legal_nlist,
            "cmd": is_legal_cmd,
            "partition_name": is_legal_partition_name,
            "partition_name_array": is_legal_partition_name_array,
            "limit": is_legal_limit,
            "anns_field": is_legal_anns_field,
            "search_data": is_legal_search_data,
            "output_fields": is_legal_output_fields,
            "round_decimal": is_legal_round_decimal,
            "travel_timestamp": is_legal_travel_timestamp,
            "guarantee_timestamp": is_legal_guarantee_timestamp,
            "user": is_legal_user,
            "password": is_legal_password,
            "role_name": is_legal_role_name,
            "operate_user_role_type": is_legal_operate_user_role_type,
            "include_user_info": is_legal_include_user_info,
            "include_role_info": is_legal_include_role_info,
            "object": is_legal_object,
            "object_name": is_legal_object_name,
            "privilege": is_legal_privilege,
            "operate_privilege_type": is_legal_operate_privilege_type,
            "properties": is_legal_collection_properties,
        }

    def check(self, key, value):
        """Validate ``value`` with the predicate registered for ``key``.

        Args:
            key: Parameter name to look up in ``check_dict``.
            value: Value handed to the registered predicate.

        Raises:
            ParamError: If ``key`` has no registered predicate. A value that
                fails its predicate is reported via ``_raise_param_error``
                (presumably also a ``ParamError`` -- confirm at its
                definition).
        """
        # Single lookup instead of a membership test followed by indexing.
        validator = self.check_dict.get(key)
        if validator is None:
            raise ParamError(message=f"unknown param `{key}`")
        if not validator(value):
            _raise_param_error(key, value)

def _get_param_checker():
    """Return the ``ParamChecker`` used for keyword-argument validation.

    ``ParamChecker`` is declared with the ``Singleton`` metaclass, so repeated
    calls are expected to hand back the same instance (confirm against
    ``singleton_utils``).
    """
    checker = ParamChecker()
    return checker

def check_pass_param(*_args: Any, **kwargs: Any) -> None:
    """Validate every keyword argument by name.

    Each ``name=value`` pair is passed to the checker returned by
    ``_get_param_checker()``, which looks the name up in its validator table.

    Args:
        *_args: Accepted for call compatibility and ignored.
        **kwargs: Parameters to validate, keyed by parameter name.

    Raises:
        ParamError: Raised by the checker when a name is not registered; an
            illegal value is reported through the checker as well.
    """
    # `**kwargs` is always a dict (possibly empty), so no None guard is needed.
    # Validation is delegated entirely to the table-driven checker; the former
    # per-key if/elif chain repeated the same checks and has been dropped.
    checker = _get_param_checker()
    for key, value in kwargs.items():
        checker.check(key, value)


def check_index_params(params):
Expand Down
41 changes: 16 additions & 25 deletions pymilvus/client/grpc_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,21 +413,17 @@ def _execute_search_requests(self, requests, timeout=None, **kwargs):
auto_id = kwargs.get("auto_id", True)

try:
raws = []
futures = []

# step 1: get future object
for request in requests:
ft = self._stub.Search.future(request, timeout=timeout)
futures.append(ft)

if kwargs.get("_async", False):
futures = []
for request in requests:
ft = self._stub.Search.future(request, timeout=timeout)
futures.append(ft)
func = kwargs.get("_callback", None)
return ChunkedSearchFuture(futures, func, auto_id)

# step2: get results
for ft in futures:
response = ft.result()
raws = []
for request in requests:
response = self._stub.Search(request, timeout=timeout)

if response.status.error_code != 0:
raise MilvusException(response.status.error_code, response.status.reason)
Expand All @@ -444,7 +440,7 @@ def _execute_search_requests(self, requests, timeout=None, **kwargs):
@retry_on_rpc_failure(retry_on_deadline=False)
def search(self, collection_name, data, anns_field, param, limit,
expression=None, partition_names=None, output_fields=None,
round_decimal=-1, timeout=None, **kwargs):
round_decimal=-1, timeout=None, collection_schema=None, **kwargs):
check_pass_param(
limit=limit,
round_decimal=round_decimal,
Expand All @@ -456,26 +452,21 @@ def search(self, collection_name, data, anns_field, param, limit,
guarantee_timestamp=kwargs.get("guarantee_timestamp", 0)
)

_kwargs = copy.deepcopy(kwargs)

collection_schema = kwargs.get("schema", None)
if not collection_schema:
collection_schema = self.describe_collection(collection_name, timeout=timeout, **kwargs)
auto_id = collection_schema["auto_id"]

consistency_level = collection_schema["consistency_level"]
# overwrite the consistency level defined when user created the collection
consistency_level = get_consistency_level(_kwargs.get("consistency_level", consistency_level))
_kwargs["schema"] = collection_schema
consistency_level = get_consistency_level(kwargs.get("consistency_level", consistency_level))

ts_utils.construct_guarantee_ts(consistency_level, collection_name, _kwargs)
ts_utils.construct_guarantee_ts(consistency_level, collection_name, kwargs)

requests = Prepare.search_requests_with_expr(collection_name, data, anns_field, param, limit, expression,
partition_names, output_fields, round_decimal, **_kwargs)
_kwargs.pop("schema")
_kwargs["auto_id"] = auto_id
_kwargs["round_decimal"] = round_decimal
requests = Prepare.search_requests_with_expr(collection_name, data, anns_field, param, limit, collection_schema,
expression, partition_names, output_fields, round_decimal,
**kwargs)

return self._execute_search_requests(requests, timeout, **_kwargs)
auto_id = collection_schema["auto_id"]
return self._execute_search_requests(requests, timeout, round_decimal=round_decimal, auto_id=auto_id, **kwargs)

@retry_on_rpc_failure()
def get_query_segment_info(self, collection_name, timeout=30, **kwargs):
Expand Down
3 changes: 1 addition & 2 deletions pymilvus/client/prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,10 +431,9 @@ def extract_vectors_param(param, placeholders, names, round_decimal):
return request

@classmethod
def search_requests_with_expr(cls, collection_name, data, anns_field, param, limit, expr=None, partition_names=None,
def search_requests_with_expr(cls, collection_name, data, anns_field, param, limit, schema, expr=None, partition_names=None,
output_fields=None, round_decimal=-1, **kwargs):
# TODO Move this impl into server side
schema = kwargs.get("schema", None)
fields_schema = schema.get("fields", None) # list
fields_name_locs = {fields_schema[loc]["name"]: loc for loc in range(len(fields_schema))}

Expand Down
8 changes: 5 additions & 3 deletions pymilvus/orm/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,9 @@ def __init__(self, name, schema=None, using="default", shards_num=2, **kwargs):
else:
raise SchemaNotReadyException(message=ExceptionsMessage.SchemaType)

self._schema_dict = self._schema.to_dict()
self._schema_dict["consistency_level"] = self._consistency_level

def __repr__(self):
_dict = {
'name': self.name,
Expand Down Expand Up @@ -653,10 +656,9 @@ def search(self, data, anns_field, param, limit, expr=None, partition_names=None
raise DataTypeNotMatchException(message=ExceptionsMessage.ExprType % type(expr))

conn = self._get_connection()
schema_dict = self._schema.to_dict()
schema_dict["consistency_level"] = self._consistency_level
res = conn.search(self._name, data, anns_field, param, limit, expr,
partition_names, output_fields, round_decimal, timeout=timeout, schema=schema_dict, **kwargs)
partition_names, output_fields, round_decimal, timeout=timeout,
collection_schema=self._schema_dict, **kwargs)
if kwargs.get("_async", False):
return SearchFuture(res)
return SearchResult(res)
Expand Down
2 changes: 1 addition & 1 deletion pymilvus/orm/partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,7 +444,7 @@ def search(self, data, anns_field, param, limit,
schema_dict = self._schema.to_dict()
schema_dict["consistency_level"] = self._consistency_level
res = conn.search(self._collection.name, data, anns_field, param, limit, expr, [self._name], output_fields,
round_decimal=round_decimal, timeout=timeout, schema=schema_dict, **kwargs)
round_decimal=round_decimal, timeout=timeout, collection_schema=schema_dict, **kwargs)
if kwargs.get("_async", False):
return SearchFuture(res)
return SearchResult(res)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def test_search_requests_with_expr_offset(self):
"offset": 10,
}

ret = Prepare.search_requests_with_expr("name", data, "v", search_params, 100, schema=schema)
ret = Prepare.search_requests_with_expr("name", data, "v", search_params, 100, schema)

offset_exists = False
for p in ret[0].search_params:
Expand Down