From 3f3d2da95d2adda2f905bdd6faf424caa2f80dc6 Mon Sep 17 00:00:00 2001 From: yangxuan Date: Wed, 27 Sep 2023 18:05:45 +0800 Subject: [PATCH] Enhance search result Signed-off-by: yangxuan --- pymilvus/__init__.py | 2 +- pymilvus/client/abstract.py | 609 ++++++++++++++------------------ pymilvus/client/asynch.py | 10 +- pymilvus/client/check.py | 3 + pymilvus/client/grpc_handler.py | 37 +- pymilvus/client/prepare.py | 9 +- pymilvus/orm/collection.py | 3 +- pymilvus/orm/future.py | 3 +- pymilvus/orm/iterator.py | 3 +- pymilvus/orm/partition.py | 2 +- pymilvus/orm/search.py | 263 -------------- pyproject.toml | 3 +- tests/test_abstract.py | 193 ++++++++++ 13 files changed, 502 insertions(+), 638 deletions(-) delete mode 100644 pymilvus/orm/search.py create mode 100644 tests/test_abstract.py diff --git a/pymilvus/__init__.py b/pymilvus/__init__.py index 2c160fbd8..9e86bc066 100644 --- a/pymilvus/__init__.py +++ b/pymilvus/__init__.py @@ -27,6 +27,7 @@ RemoteBulkWriter, ) from .client import __version__ +from .client.abstract import Hit, Hits, SearchResult from .client.prepare import Prepare from .client.stub import Milvus from .client.types import ( @@ -53,7 +54,6 @@ from .orm.partition import Partition from .orm.role import Role from .orm.schema import CollectionSchema, FieldSchema -from .orm.search import Hit, Hits, SearchResult from .orm.utility import ( create_resource_group, create_user, diff --git a/pymilvus/client/abstract.py b/pymilvus/client/abstract.py index 3dfb3e751..7e78bad2e 100644 --- a/pymilvus/client/abstract.py +++ b/pymilvus/client/abstract.py @@ -1,62 +1,16 @@ import abc -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union + +import ujson from pymilvus.exceptions import MilvusException from pymilvus.grpc_gen import schema_pb2 from pymilvus.settings import Config -from . import entity_helper from .constants import DEFAULT_CONSISTENCY_LEVEL from .types import DataType -class LoopBase: - def __init__(self): - self.__index = 0 - - def __iter__(self): - return self - - def __getitem__(self, item: Any): - if isinstance(item, slice): - _start = item.start or 0 - _end = min(item.stop, self.__len__()) if item.stop else self.__len__() - _step = item.step or 1 - - return [self.get__item(i) for i in range(_start, _end, _step)] - - if item >= self.__len__(): - msg = "Index out of range" - raise IndexError(msg) - - return self.get__item(item) - - def __next__(self): - while self.__index < self.__len__(): - self.__index += 1 - return self.__getitem__(self.__index - 1) - - # iterate stop, raise Exception - self.__index = 0 - raise StopIteration - - def __str__(self): - return str(list(map(str, self.__getitem__(slice(0, 10))))) - - @abc.abstractmethod - def get__item(self, item: Any): - raise NotImplementedError - - -class LoopCache: - def __init__(self): - self._array = [] - - def fill(self, index: int, obj: Any): - if len(self._array) + 1 < index: - pass - - class FieldSchema: def __init__(self, raw: Any): self._raw = raw @@ -224,125 +178,6 @@ def __str__(self): return self.dict().__str__() -class Entity: - def __init__(self, entity_id: int, entity_row_data: Any, entity_score: float): - self._id = entity_id - self._row_data = entity_row_data - self._score = entity_score - self._distance = entity_score - - def __str__(self): - return f"id: {self._id}, distance: {self._distance}, entity: {self._row_data}" - - def __getattr__(self, item: Any): - return self.value_of_field(item) - - @property - def id(self): - return self._id - - @property - def fields(self): - return [k for k, v in self._row_data.items()] - - def get(self, field: Any): - return self.value_of_field(field) - - def value_of_field(self, field: Any): - if field not in self._row_data: - raise MilvusException(message=f"Field {field} is not in return entity") - return self._row_data[field] - - def type_of_field(self, field: Any): - msg = "TODO: support field in Hits" - raise NotImplementedError(msg) - - def to_dict(self): - return {"id": self._id, "distance": self._distance, "entity": self._row_data} - - -class Hit: - def __init__(self, entity_id: int, entity_row_data: Any, entity_score: float): - self._id = entity_id - self._row_data = entity_row_data - self._score = entity_score - self._distance = entity_score - - def __str__(self): - return str(self.entity) - - __repr__ = __str__ - - @property - def entity(self): - return Entity(self._id, self._row_data, self._score) - - @property - def id(self): - return self._id - - @property - def distance(self): - return self._distance - - @property - def score(self): - return self._score - - def to_dict(self): - return self.entity.to_dict() - - -class Hits(LoopBase): - def __init__(self, raw: Any, round_decimal: int = -1): - super().__init__() - self._raw = raw - if round_decimal != -1: - self._distances = [round(x, round_decimal) for x in self._raw.scores] - else: - self._distances = self._raw.scores - - self._dynamic_field_name = None - self._dynamic_fields = set() - ( - self._dynamic_field_name, - self._dynamic_fields, - ) = entity_helper.extract_dynamic_field_from_result(self._raw) - - def __len__(self): - if self._raw.ids.HasField("int_id"): - return len(self._raw.ids.int_id.data) - if self._raw.ids.HasField("str_id"): - return len(self._raw.ids.str_id.data) - return 0 - - def get__item(self, item: Any): - if self._raw.ids.HasField("int_id"): - entity_id = self._raw.ids.int_id.data[item] - elif self._raw.ids.HasField("str_id"): - entity_id = self._raw.ids.str_id.data[item] - else: - raise MilvusException(message="Unsupported ids type") - - entity_row_data = entity_helper.extract_row_data_from_fields_data( - self._raw.fields_data, item, self._dynamic_fields - ) - entity_score = self._distances[item] - return Hit(entity_id, entity_row_data, entity_score) - - @property - def ids(self): - if self._raw.ids.HasField("int_id"): - return self._raw.ids.int_id.data - if self._raw.ids.HasField("str_id"): - return self._raw.ids.str_id.data - return [] - - @property - def distances(self): - return self._distances - - class MutationResult: def __init__(self, raw: Any): self._raw = raw @@ -422,175 +257,281 @@ def _pack(self, raw: Any): self._err_index = raw.err_index -class QueryResult(LoopBase): - def __init__(self, raw: Any): - super().__init__() - self._raw = raw - self._pack(raw.hits) +class SequenceIterator: + def __init__(self, seq: Sequence[Any]): + self._seq = seq + self._idx = 0 - def __len__(self): - return self._nq + def __next__(self) -> Any: + if self._idx < len(self._seq): + res = self._seq[self._idx] + self._idx += 1 + return res + raise StopIteration - def _pack(self, raw: Any): - self._nq = raw.results.num_queries - self._topk = raw.results.top_k - self._hits = [] - offset = 0 - for i in range(self._nq): - hit = schema_pb2.SearchResultData() - start_pos = offset - end_pos = offset + raw.results.topks[i] - hit.scores.append(raw.results.scores[start_pos:end_pos]) - if raw.results.ids.HasField("int_id"): - hit.ids.append(raw.results.ids.int_id.data[start_pos:end_pos]) - elif raw.results.ids.HasField("str_id"): - hit.ids.append(raw.results.ids.str_id.data[start_pos:end_pos]) - for field_data in raw.result.fields_data: - field = schema_pb2.FieldData() - field.type = field_data.type - field.field_name = field_data.field_name - if field_data.type == DataType.BOOL: - field.scalars.bool_data.data.extend( - field_data.scalars.bool_data.data[start_pos:end_pos] - ) - elif field_data.type in (DataType.INT8, DataType.INT16, DataType.INT32): - field.scalars.int_data.data.extend( - field_data.scalars.int_data.data[start_pos:end_pos] - ) - elif field_data.type == DataType.INT64: - field.scalars.long_data.data.extend( - field_data.scalars.long_data.data[start_pos:end_pos] - ) - elif field_data.type == DataType.FLOAT: - field.scalars.float_data.data.extend( - field_data.scalars.float_data.data[start_pos:end_pos] - ) - elif field_data.type == DataType.DOUBLE: - field.scalars.double_data.data.extend( - field_data.scalars.double_data.data[start_pos:end_pos] - ) - elif field_data.type == DataType.VARCHAR: - field.scalars.string_data.data.extend( - field_data.scalars.string_data.data[start_pos:end_pos] - ) - elif field_data.type == DataType.STRING: - raise MilvusException(message="Not support string yet") - elif field_data.type == DataType.JSON: - field.scalars.json_data.data.extend( - field_data.scalars.json_data.data[start_pos:end_pos] - ) - elif field_data.type == DataType.FLOAT_VECTOR: - dim = field.vectors.dim - field.vectors.dim = dim - field.vectors.float_vector.data.extend( - field_data.vectors.float_data.data[start_pos * dim : end_pos * dim] - ) - elif field_data.type == DataType.BINARY_VECTOR: - dim = field_data.vectors.dim - field.vectors.dim = dim - field.vectors.binary_vector += field_data.vectors.binary_vector[ - start_pos * (dim // 8) : end_pos * (dim // 8) - ] - hit.fields_data.append(field) - self._hits.append(hit) - offset += raw.results.topks[i] - def get__item(self, item: Any): - return Hits(self._hits[item]) - - -class ChunkedQueryResult(LoopBase): - def __init__(self, raw_list: List, round_decimal: int = -1): - super().__init__() - self._raw_list = raw_list - self._nq = 0 - self.round_decimal = round_decimal - - self._pack(self._raw_list) - - def __len__(self): - return self._nq - - def _pack(self, raw_list: List): - self._hits = [] - for raw in raw_list: - nq = raw.results.num_queries - self._nq += nq - self._topk = raw.results.top_k - offset = 0 - - for i in range(nq): - hit = schema_pb2.SearchResultData() - start_pos = offset - end_pos = offset + raw.results.topks[i] - hit.scores.extend(raw.results.scores[start_pos:end_pos]) - if raw.results.ids.HasField("int_id"): - hit.ids.int_id.data.extend(raw.results.ids.int_id.data[start_pos:end_pos]) - elif raw.results.ids.HasField("str_id"): - hit.ids.str_id.data.extend(raw.results.ids.str_id.data[start_pos:end_pos]) - hit.output_fields.extend(raw.results.output_fields) - for field_data in raw.results.fields_data: - field = schema_pb2.FieldData() - field.type = field_data.type - field.field_name = field_data.field_name - field.is_dynamic = field_data.is_dynamic - if field_data.type == DataType.BOOL: - field.scalars.bool_data.data.extend( - field_data.scalars.bool_data.data[start_pos:end_pos] - ) - elif field_data.type in (DataType.INT8, DataType.INT16, DataType.INT32): - field.scalars.int_data.data.extend( - field_data.scalars.int_data.data[start_pos:end_pos] - ) - elif field_data.type == DataType.INT64: - field.scalars.long_data.data.extend( - field_data.scalars.long_data.data[start_pos:end_pos] - ) - elif field_data.type == DataType.FLOAT: - field.scalars.float_data.data.extend( - field_data.scalars.float_data.data[start_pos:end_pos] - ) - elif field_data.type == DataType.DOUBLE: - field.scalars.double_data.data.extend( - field_data.scalars.double_data.data[start_pos:end_pos] - ) - elif field_data.type == DataType.VARCHAR: - field.scalars.string_data.data.extend( - field_data.scalars.string_data.data[start_pos:end_pos] - ) - elif field_data.type == DataType.STRING: - raise MilvusException(message="Not support string yet") - elif field_data.type == DataType.JSON: - field.scalars.json_data.data.extend( - field_data.scalars.json_data.data[start_pos:end_pos] - ) - elif field_data.type == DataType.ARRAY: - field.scalars.array_data.data.extend( - field_data.scalars.array_data.data[start_pos:end_pos] - ) - field.scalars.array_data.element_type = ( - field_data.scalars.array_data.element_type - ) - elif field_data.type == DataType.FLOAT_VECTOR: - dim = field_data.vectors.dim - field.vectors.dim = dim - field.vectors.float_vector.data.extend( - field_data.vectors.float_vector.data[start_pos * dim : end_pos * dim] - ) - elif field_data.type == DataType.BINARY_VECTOR: - dim = field_data.vectors.dim - field.vectors.dim = dim - field.vectors.binary_vector += field_data.vectors.binary_vector[ - start_pos * (dim // 8) : end_pos * (dim // 8) - ] - hit.fields_data.append(field) - self._hits.append(hit) - offset += raw.results.topks[i] +class SearchResult(list): + """ nq results: List[Hits] """ + def __init__(self, res: schema_pb2.SearchResultData, round_decimal: Optional[int] = None): + self._nq = res.num_queries + all_topks = res.topks - def get__item(self, item: Any): - return Hits(self._hits[item], self.round_decimal) + output_fields = res.output_fields + fields_data = res.fields_data + + all_pks: List[Union[str, int]] = [] + all_scores: List[float] = [] + + if res.ids.HasField("int_id"): + all_pks = res.ids.int_id.data + elif res.ids.HasField("str_id"): + all_pks = res.ids.str_id.data + + if isinstance(round_decimal, int) and round_decimal > 0: + all_scores = [round(x, round_decimal) for x in res.scores] + else: + all_scores = res.scores + + data = [] + for i, topk in enumerate(all_topks): + start, end = i * topk, (i + 1) * topk + nq_th_fields = self.get_fields_by_range(start, end, fields_data) + data.append(Hits(topk, all_pks[start: end], all_scores[start: end], nq_th_fields, output_fields)) + + super().__init__(data) + + def get_fields_by_range(self, start: int, end: int, all_fields_data: List[schema_pb2.FieldData]) -> Dict[str, Tuple[List[Any], schema_pb2.FieldData]]: + field2data: Dict[str, Tuple[List[Any], schema_pb2.FieldData]] = {} + + for field in all_fields_data: + name, scalars, dtype = field.field_name, field.scalars, field.type + field_meta = schema_pb2.FieldData( + type=dtype, + field_name=name, + field_id=field.field_id, + is_dynamic=field.is_dynamic, + ) + if dtype == DataType.BOOL: + field2data[name] = scalars.bool_data.data[start: end], field_meta + continue + + if dtype in (DataType.INT8, DataType.INT16, DataType.INT32): + field2data[name] = scalars.int_data.data[start: end], field_meta + continue + + if dtype == DataType.INT64: + field2data[name] = scalars.long_data.data[start: end], field_meta + continue + + if dtype == DataType.FLOAT: + field2data[name] = scalars.float_data.data[start: end], field_meta + continue + + if dtype == DataType.DOUBLE: + field2data[name] = scalars.double_data.data[start: end], field_meta + continue + + if dtype == DataType.VARCHAR: + field2data[name] = scalars.string_data.data[start: end], field_meta + continue + + if dtype == DataType.JSON: + json_dict_list = list(map(ujson.loads, scalars.json_data.data[start: end])) + field2data[name] = json_dict_list, field_meta + continue + + if dtype == DataType.ARRAY: + topk_array_fields = scalars.array_data.data[start: end] + field2data[name] = extract_array_row_data(topk_array_fields, scalars.array_data.element_type), field_meta + continue + + # vectors + dim, vectors = field.vectors.dim, field.vectors + field_meta.vectors.dim = dim + if dtype == DataType.FLOAT_VECTOR: + field2data[name] = vectors.float_vector.data[start * dim: end * dim], field_meta + continue + + if dtype == DataType.BINARY_VECTOR: + field2data[name] = vectors.binary_vector[start * (dim // 8) : end * (dim // 8)], field_meta + continue + + return field2data + + + def __iter__(self) -> SequenceIterator: + return SequenceIterator(self) + + def __str__(self) -> str: + """Only print at most 10 query results""" + return str(list(map(str, self[:10]))) + + __repr__ = __str__ + + +class Hits(list): + ids: List[Union[str, int]] + distances: List[float] + + def __init__( + self, + topk: int, + pks: Union[int, str], + distances: List[float], + fields: Dict[str, Tuple[List[Any], schema_pb2.FieldData]], + output_fields: List[str]): + """ + Args: + fields(Dict[str, Tuple[List[Any], schema_pb2.FieldData]]): + field name to a tuple of topk data and field meta + """ + self.ids = pks + self.distances = distances + + all_fields = list(fields.keys()) + dynamic_fields = list(set(output_fields) - set(all_fields)) + + hits = [] + for i in range(topk): + curr_field = {} + for fname, (data, field_meta) in fields.items(): + # Get vectors + if field_meta.type in (DataType.FLOAT_VECTOR, DataType.BINARY_VECTOR): + dim = field_meta.vectors.dim + dim = dim // 8 if field_meta.type == DataType.BINARY_VECTOR else dim + + curr_field[fname] = data[i*dim: (i+1) *dim] + continue + + # Get dynamic fields + if field_meta.type == DataType.JSON and field_meta.is_dynamic: + if len(dynamic_fields) > 0: + curr_field.update({k: v for k, v in data[i].items() if k in dynamic_fields}) + continue + + if fname in output_fields: + curr_field.update(data[i]) + continue + + # normal fields + curr_field[fname] = data[i] + + hits.append(Hit(pks[i], distances[i], curr_field)) + + super().__init__(hits) + + def __iter__(self) -> SequenceIterator: + return SequenceIterator(self) + + def __str__(self) -> str: + """Only print at most 10 query results""" + return str(list(map(str, self[:10]))) + + __repr__ = __str__ + + +class Hit: + id: Union[int, str] + distance: float + entity: Dict[str, Any] + + def __init__(self, pk: Union[int, str], distance: float, fields: Dict[str, Any]): + self.id = pk + self.distance = distance + self.entity = fields + + def __getattr__(self, item: str): + if item not in self.entity: + raise MilvusException(message=f"Field {item} is not in the hit entity") + return self.entity[item] + + @property + def pk(self) -> Union[str, int]: + return self.id + + @property + def score(self) -> float: + return self.distance + + def get(self, field_name: str) -> Any: + return self.entity.get(field_name) + + def __str__(self) -> str: + return f"" + + __repr__ = __str__ + + def to_dict(self): + return { + "id": self.id, + "distance": self.distance, + "entity": self.entity, + } + + +def extract_array_row_data(scalars: List[schema_pb2.ScalarField], element_type: DataType) -> List[List[Any]]: + row = [] + for ith_array in scalars: + if element_type == DataType.INT64: + row.append(ith_array.long_data.data) + continue + + if element_type == DataType.BOOL: + row.append(ith_array.bool_data.data) + continue + + if element_type in (DataType.INT8, DataType.INT16, DataType.INT32): + row.append(ith_array.int_data.data) + continue + if element_type == DataType.FLOAT: + row.append(ith_array.float_data.data) + continue + + if element_type == DataType.DOUBLE: + row.append(ith_array.double_data.data) + continue + + if element_type in (DataType.STRING, DataType.VARCHAR): + row.append(ith_array.string_data.data) + continue + return row + + +class LoopBase: + def __init__(self): + self.__index = 0 + + def __iter__(self): + return self + + def __getitem__(self, item: Any): + if isinstance(item, slice): + _start = item.start or 0 + _end = min(item.stop, self.__len__()) if item.stop else self.__len__() + _step = item.step or 1 + + return [self.get__item(i) for i in range(_start, _end, _step)] + + if item >= self.__len__(): + msg = "Index out of range" + raise IndexError(msg) + + return self.get__item(item) + + def __next__(self): + while self.__index < self.__len__(): + self.__index += 1 + return self.__getitem__(self.__index - 1) + + # iterate stop, raise Exception + self.__index = 0 + raise StopIteration + + def __str__(self): + return str(list(map(str, self.__getitem__(slice(0, 10))))) + + @abc.abstractmethod + def get__item(self, item: Any): + raise NotImplementedError -def _abstract(): - msg = "You need to override this function" - raise NotImplementedError(msg) diff --git a/pymilvus/client/asynch.py b/pymilvus/client/asynch.py index ec4bc7e95..dc9b3d7e8 100644 --- a/pymilvus/client/asynch.py +++ b/pymilvus/client/asynch.py @@ -3,8 +3,9 @@ from typing import Any, Callable, List, Optional from pymilvus.exceptions import MilvusException +from pymilvus.grpc_gen import milvus_pb2 -from .abstract import ChunkedQueryResult, MutationResult, QueryResult +from .abstract import MutationResult, SearchResult from .types import Status @@ -160,9 +161,9 @@ def exception(self): class SearchFuture(Future): - def on_response(self, response: Any): + def on_response(self, response: milvus_pb2.SearchResults): if response.status.code == 0: - return QueryResult(response) + return SearchResult(response.results) status = response.status raise MilvusException(status.code, status.reason, status.error_code) @@ -171,6 +172,7 @@ def on_response(self, response: Any): # TODO: if ChunkedFuture is more common later, consider using ChunkedFuture as Base Class, # then Future(future, done_cb, pre_exception) equal # to ChunkedFuture([future], done_cb, pre_exception) +# TODO GOOSE class ChunkedSearchFuture(Future): def __init__( self, @@ -249,7 +251,7 @@ def on_response(self, response: Any): if raw.status.code != 0: raise MilvusException(raw.status.code, raw.status.reason, raw.status.error_code) - return ChunkedQueryResult(response) + return SearchResult(response) class MutationFuture(Future): diff --git a/pymilvus/client/check.py b/pymilvus/client/check.py index e6948af26..c8a597d9b 100644 --- a/pymilvus/client/check.py +++ b/pymilvus/client/check.py @@ -175,6 +175,9 @@ def is_legal_search_data(data: Any) -> bool: if not isinstance(data, (list, np.ndarray)): return False + if len(data) == 0: + return False + return all(isinstance(vector, (list, bytes, np.ndarray)) for vector in data) diff --git a/pymilvus/client/grpc_handler.py b/pymilvus/client/grpc_handler.py index 699e510b4..a200b4c64 100644 --- a/pymilvus/client/grpc_handler.py +++ b/pymilvus/client/grpc_handler.py @@ -24,7 +24,7 @@ from pymilvus.settings import Config from . import entity_helper, interceptor, ts_utils -from .abstract import ChunkedQueryResult, CollectionSchema, MutationResult +from .abstract import CollectionSchema, MutationResult, SearchResult from .asynch import ( ChunkedSearchFuture, CreateIndexFuture, @@ -722,31 +722,26 @@ def upsert_rows( response.status.code, response.status.reason, response.status.error_code ) - def _execute_search_requests(self, requests: Any, timeout: Optional[float] = None, **kwargs): + def _execute_search( + self, request: milvus_types.SearchRequest, timeout: Optional[float] = None, **kwargs + ): try: if kwargs.get("_async", False): - futures = [] - for request in requests: - ft = self._stub.Search.future(request, timeout=timeout) - futures.append(ft) + future = self._stub.Search.future(request, timeout=timeout) func = kwargs.get("_callback", None) - return ChunkedSearchFuture(futures, func) - - raws = [] - for request in requests: - response = self._stub.Search(request, timeout=timeout) + return ChunkedSearchFuture(future, func) - if response.status.code != 0: - raise MilvusException(response.status.code, response.status.reason) + response = self._stub.Search(request, timeout=timeout) + if response.status.code != 0: + raise MilvusException(response.status.code, response.status.reason) - raws.append(response) round_decimal = kwargs.get("round_decimal", -1) - return ChunkedQueryResult(raws, round_decimal) + return SearchResult(response.results, round_decimal) - except Exception as pre_err: + except Exception as e: if kwargs.get("_async", False): - return SearchFuture(None, None, pre_err) - raise pre_err from pre_err + return SearchFuture(None, None, e) + raise e from e @retry_on_rpc_failure() def search( @@ -773,7 +768,7 @@ def search( guarantee_timestamp=kwargs.get("guarantee_timestamp", None), ) - requests = Prepare.search_requests_with_expr( + request = Prepare.search_requests_with_expr( collection_name, data, anns_field, @@ -785,9 +780,7 @@ def search( round_decimal, **kwargs, ) - return self._execute_search_requests( - requests, timeout, round_decimal=round_decimal, **kwargs - ) + return self._execute_search(request, timeout, round_decimal=round_decimal, **kwargs) @retry_on_rpc_failure() def get_query_segment_info(self, collection_name: str, timeout: float = 30, **kwargs): diff --git a/pymilvus/client/prepare.py b/pymilvus/client/prepare.py index 31d319b5b..a0683822d 100644 --- a/pymilvus/client/prepare.py +++ b/pymilvus/client/prepare.py @@ -577,11 +577,7 @@ def search_requests_with_expr( output_fields: Optional[List[str]] = None, round_decimal: int = -1, **kwargs, - ): - requests = [] - if len(data) <= 0: - return requests - + ) -> milvus_types.SearchRequest: if isinstance(data[0], bytes): is_binary = True pl_type = PlaceholderType.BinaryVector @@ -651,8 +647,7 @@ def dump(v: Dict): ] ) - requests.append(request) - return requests + return request @classmethod def create_alias_request(cls, collection_name: str, alias: str): diff --git a/pymilvus/orm/collection.py b/pymilvus/orm/collection.py index d904d1123..e01f5676a 100644 --- a/pymilvus/orm/collection.py +++ b/pymilvus/orm/collection.py @@ -52,7 +52,6 @@ check_upsert_schema, construct_fields_from_dataframe, ) -from .search import SearchResult from .types import DataType from .utility import _get_connection @@ -791,7 +790,7 @@ def search( ) if kwargs.get("_async", False): return SearchFuture(res) - return SearchResult(res) + return res def search_iterator( self, diff --git a/pymilvus/orm/future.py b/pymilvus/orm/future.py index 09f43b76c..308e55192 100644 --- a/pymilvus/orm/future.py +++ b/pymilvus/orm/future.py @@ -12,8 +12,9 @@ from typing import Any +from pymilvus.client.abstract import SearchResult + from .mutation import MutationResult -from .search import SearchResult # TODO(dragondriver): how could we inherit the docstring elegantly? diff --git a/pymilvus/orm/iterator.py b/pymilvus/orm/iterator.py index 417fa7121..5d84fe4cb 100644 --- a/pymilvus/orm/iterator.py +++ b/pymilvus/orm/iterator.py @@ -1,12 +1,11 @@ from copy import deepcopy from typing import Any, Dict, List, Optional, TypeVar -from pymilvus.client.abstract import LoopBase +from pymilvus.client.abstract import Hits, LoopBase from pymilvus.exceptions import ( MilvusException, ParamError, ) -from pymilvus.orm.search import Hits from .connections import Connections from .constants import ( diff --git a/pymilvus/orm/partition.py b/pymilvus/orm/partition.py index 0c84bbba4..bc5f7796f 100644 --- a/pymilvus/orm/partition.py +++ b/pymilvus/orm/partition.py @@ -15,6 +15,7 @@ import pandas as pd import ujson +from pymilvus.client.abstract import SearchResult from pymilvus.client.types import Replica from pymilvus.exceptions import ( ExceptionsMessage, @@ -23,7 +24,6 @@ ) from .mutation import MutationResult -from .search import SearchResult Collection = TypeVar("Collection") Partition = TypeVar("Partition") diff --git a/pymilvus/orm/search.py b/pymilvus/orm/search.py deleted file mode 100644 index 5c07ca46a..000000000 --- a/pymilvus/orm/search.py +++ /dev/null @@ -1,263 +0,0 @@ -# Copyright (C) 2019-2021 Zilliz. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except -# in compliance with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software distributed under the License -# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express -# or implied. See the License for the specific language governing permissions and limitations under -# the License. - -import abc -from typing import Any, Iterable - -from pymilvus.client.abstract import Entity - - -class _IterableWrapper: - def __init__(self, iterable_obj: Iterable) -> None: - self._iterable = iterable_obj - - def __iter__(self): - return self - - def __next__(self): - return self.on_result(self._iterable.__next__()) - - def __getitem__(self, item: str): - s = self._iterable.__getitem__(item) - if isinstance(item, slice): - _start = item.start or 0 - i_len = self._iterable.__len__() - _end = min(item.stop, i_len) if item.stop else i_len - - elements = [] - for i in range(_start, _end): - elements.append(self.on_result(s[i])) - return elements - return s - - def __len__(self) -> int: - return self._iterable.__len__() - - @abc.abstractmethod - def on_result(self, res: Any): - raise NotImplementedError - - -# TODO: how to add docstring to method of subclass and don't change the implementation? -# for example like below: -# class Hits(_IterableWrapper): -# -# def on_result(self, res): - - -class DocstringMeta(type): - def __new__(cls, name: str, bases: Any, attrs: Any): - doc_meta = attrs.pop("docstring", None) - new_cls = super().__new__(cls, name, bases, attrs) - if doc_meta: - for member_name, member in attrs.items(): - if member_name in doc_meta: - member.__doc__ = doc_meta[member_name] - return new_cls - - -# for example: -# class Hits(_IterableWrapper, metaclass=DocstringMeta): -# -# def on_result(self, res): - - -class Hit: - def __init__(self, hit: Any) -> None: - """ - Construct a Hit object from response. A hit represent a record corresponding to the query. - """ - self._hit = hit - - @property - def id(self) -> int: - """ - Return the id of the hit record. - - :return int: - The id of the hit record. - """ - return self._hit.id - - @property - def entity(self) -> Entity: - """ - Return the Entity of the hit record. - - :return pymilvus Entity object: - The entity content of the hit record. - """ - return self._hit.entity - - @property - def distance(self) -> float: - """ - Return the distance between the hit record and the query. - - :return float: - The distance of the hit record. - """ - return self._hit.distance - - @property - def score(self) -> float: - """ - Return the calculated score of the hit record, now the score is equal to distance. - - :return float: - The score of the hit record. - """ - return self._hit.score - - def __str__(self) -> str: - """ - Return the information of hit record. - - :return str: - The information of hit record. - """ - return str(self._hit) - - __repr__ = __str__ - - def to_dict(self): - return self._hit.to_dict() - - -class Hits: - def __init__(self, hits: Any) -> None: - """ - Construct a Hits object from response. - """ - self._hits = hits - - def __iter__(self): - """ - Iterate the Hits object. Every iteration returns a Hit which represent a record - corresponding to the query. - """ - return self - - def __next__(self): - """ - Iterate the Hits object. Every iteration returns a Hit which represent a record - corresponding to the query. - """ - return Hit(self._hits.__next__()) - - def __getitem__(self, item: str): - """ - Return the kth Hit corresponding to the query. - - :return Hit: - The kth specified by item Hit corresponding to the query. - """ - s = self._hits.__getitem__(item) - if isinstance(item, slice): - _start = item.start or 0 - i_len = self._hits.__len__() - _end = min(item.stop, i_len) if item.stop else i_len - - elements = [] - for i in range(_start, _end): - elements.append(self.on_result(s[i])) - return elements - return self.on_result(s) - - def __len__(self) -> int: - """ - Return the number of hit record. - - :return int: - The number of hit record. - """ - return self._hits.__len__() - - def __str__(self) -> str: - return str(list(map(str, self.__getitem__(slice(0, 10))))) - - def on_result(self, res: Any): - return Hit(res) - - @property - def ids(self) -> list: - """ - Return the ids of all hit record. - - :return list[int]: - The ids of all hit record. - """ - return self._hits.ids - - @property - def distances(self) -> list: - """ - Return the distances of all hit record. - - :return list[float]: - The distances of all hit record. - """ - return self._hits.distances - - -class SearchResult: - def __init__(self, query_result: Any = None) -> None: - """ - Construct a search result from response. - """ - self._qs = query_result - - def __iter__(self): - """ - Iterate the Search Result. Every iteration returns a Hits corresponding to a query. - """ - return self - - def __next__(self): - """ - Iterate the Search Result. Every iteration returns a Hits corresponding to a query. - """ - return self.on_result(self._qs.__next__()) - - def __getitem__(self, item: Any): - """ - Return the Hits corresponding to the nth query. - - :return Hits: - The hits corresponding to the nth(item) query. - """ - s = self._qs.__getitem__(item) - if isinstance(item, slice): - _start = item.start or 0 - i_len = self._qs.__len__() - _end = min(item.stop, i_len) if item.stop else i_len - - elements = [] - for i in range(_start, _end): - elements.append(self.on_result(s[i])) - return elements - return self.on_result(s) - - def __len__(self) -> int: - """ - Return the number of query of Search Result. - - :return int: - The number of query of search result. - """ - return self._qs.__len__() - - def __str__(self) -> str: - return str(list(map(str, self.__getitem__(slice(0, 10))))) - - def on_result(self, res: Any): - return Hits(res) diff --git a/pyproject.toml b/pyproject.toml index beb02e9fd..21c085d3e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -132,7 +132,8 @@ exclude = [ "venv", "grpc_gen", "__pycache__", - "pymilvus/client/stub.py" + "pymilvus/client/stub.py", + "tests" ] # Same as Black. diff --git a/tests/test_abstract.py b/tests/test_abstract.py new file mode 100644 index 000000000..21dd8da7b --- /dev/null +++ b/tests/test_abstract.py @@ -0,0 +1,193 @@ +from typing import List, Tuple +from pymilvus.client.abstract import Hit, Hits, SearchResult +from pymilvus.client.types import DataType +from pymilvus.grpc_gen import schema_pb2 +import random + +import pytest +import ujson + + +class TestHit: + @pytest.mark.parametrize("pk_dist", [ + (1, 0.1), + (2, 0.3), + ("a", 0.4), + ]) + def test_hit_no_fields(self, pk_dist: List[Tuple]): + pk, dist = pk_dist + h = Hit(pk, dist, {}) + assert h.id == pk + assert h.score == dist + assert h.distance == dist + assert h.entity == {} + + assert h.to_dict() == { + "id": pk, + "distance": dist, + "entity": {}, + } + + assert h.__dict__ == h.to_dict() + + @pytest.mark.parametrize("pk_dist_fields", [ + (1, 0.1, {"vector": [1., 2., 3., 4.], "description": "This is a test", 'd_a': "dynamic a"}), + (2, 0.3, {"vector": [3., 4., 5., 6.], "description": "This is a test too", 'd_b': "dynamic b"}), + ("a", 0.4, {"vector": [4., 4., 4., 4.], "description": "This is a third test", 'd_a': "dynamic a twice"}), + ]) + def test_hit_with_fields(self, pk_dist_fields: List[Tuple]): + h = Hit(*pk_dist_fields) + + # fixed attributes + assert h.id == pk_dist_fields[0] + assert h.score == pk_dist_fields[1] + assert h.distance == h.score + assert h.entity == pk_dist_fields[2] + + # dynamic attributes + assert h.description == pk_dist_fields[2].get("description") + assert h.vector == pk_dist_fields[2].get("vector") + + with pytest.raises(Exception): + h.field_not_exits + + print(h) + + +class TestSearchResult: + @pytest.mark.parametrize("pk", [ + schema_pb2.IDs(int_id=schema_pb2.LongArray(data=[i for i in range(6)])), + schema_pb2.IDs(str_id=schema_pb2.StringArray(data=[str(i*10) for i in range(6)])) + ]) + @pytest.mark.parametrize("round_decimal", [ + None, + -1, + 4, + ]) + def test_search_result_no_fields_data(self, pk, round_decimal): + result = schema_pb2.SearchResultData( + num_queries=2, + top_k=3, + scores=[1.*i for i in range(6)], + ids=pk, + topks=[3, 3], + ) + r = SearchResult(result, round_decimal) + + # Iterable + assert 2 == len(r) + for hits in r: + assert isinstance(hits, Hits) + assert len(hits.ids) == 3 + assert len(hits.distances) == 3 + + # slicable + assert 1 == len(r[1:]) + first_q, second_q = r[0], r[1] + assert 3 == len(first_q) + assert 3 == len(first_q[:]) + assert 2 == len(first_q[1:]) + assert 1 == len(first_q[2:]) + assert 0 == len(first_q[3:]) + print(first_q[:]) + print(first_q[1:]) + print(first_q[2:]) + + first_hit = first_q[0] + print(first_hit) + assert first_hit.distance == 0. + assert first_hit.entity == {} + + @pytest.mark.parametrize("pk", [ + schema_pb2.IDs(int_id=schema_pb2.LongArray(data=[i for i in range(6)])), + schema_pb2.IDs(str_id=schema_pb2.StringArray(data=[str(i*10) for i in range(6)])) + ]) + def test_search_result_with_fields_data(self, pk): + fields_data = [ + schema_pb2.FieldData(type=DataType.BOOL, field_name="bool_field", field_id=100, + scalars=schema_pb2.ScalarField(bool_data=schema_pb2.BoolArray(data=[True for i in range(6)]))), + schema_pb2.FieldData(type=DataType.INT8, field_name="int8_field", field_id=101, + scalars=schema_pb2.ScalarField(int_data=schema_pb2.IntArray(data=[i for i in range(6)]))), + schema_pb2.FieldData(type=DataType.INT16, field_name="int16_field", field_id=102, + scalars=schema_pb2.ScalarField(int_data=schema_pb2.IntArray(data=[i for i in range(6)]))), + schema_pb2.FieldData(type=DataType.INT32, field_name="int32_field", field_id=103, + scalars=schema_pb2.ScalarField(int_data=schema_pb2.IntArray(data=[i for i in range(6)]))), + schema_pb2.FieldData(type=DataType.INT64, field_name="int64_field", field_id=104, + scalars=schema_pb2.ScalarField(long_data=schema_pb2.LongArray(data=[i for i in range(6)]))), + schema_pb2.FieldData(type=DataType.FLOAT, field_name="float_field", field_id=105, + scalars=schema_pb2.ScalarField(float_data=schema_pb2.FloatArray(data=[i*1. for i in range(6)]))), + schema_pb2.FieldData(type=DataType.DOUBLE, field_name="double_field", field_id=106, + scalars=schema_pb2.ScalarField(double_data=schema_pb2.DoubleArray(data=[i*1. for i in range(6)]))), + schema_pb2.FieldData(type=DataType.VARCHAR, field_name="varchar_field", field_id=107, + scalars=schema_pb2.ScalarField(string_data=schema_pb2.StringArray(data=[str(i*10) for i in range(6)]))), + schema_pb2.FieldData(type=DataType.ARRAY, field_name="int16_array_field", field_id=108, + scalars=schema_pb2.ScalarField( + array_data=schema_pb2.ArrayArray( + data=[schema_pb2.ScalarField(int_data=schema_pb2.IntArray(data=[j for j in range(10)])) for i in range(6)], + element_type=DataType.INT16, + ), + )), + schema_pb2.FieldData(type=DataType.ARRAY, field_name="int64_array_field", field_id=109, + scalars=schema_pb2.ScalarField( + array_data=schema_pb2.ArrayArray( + data=[schema_pb2.ScalarField(long_data=schema_pb2.LongArray(data=[j for j in range(10)])) for i in range(6)], + element_type=DataType.INT64, + ), + )), + schema_pb2.FieldData(type=DataType.ARRAY, field_name="float_array_field", field_id=110, + scalars=schema_pb2.ScalarField( + array_data=schema_pb2.ArrayArray( + data=[schema_pb2.ScalarField(float_data=schema_pb2.FloatArray(data=[j*1. for j in range(10)])) for i in range(6)], + element_type=DataType.FLOAT, + ), + )), + schema_pb2.FieldData(type=DataType.ARRAY, field_name="varchar_array_field", field_id=110, + scalars=schema_pb2.ScalarField( + array_data=schema_pb2.ArrayArray( + data=[schema_pb2.ScalarField(string_data=schema_pb2.StringArray(data=[str(j*1.) for j in range(10)])) for i in range(6)], + element_type=DataType.VARCHAR, + ), + )), + + schema_pb2.FieldData(type=DataType.JSON, field_name="normal_json_field", field_id=111, + scalars=schema_pb2.ScalarField(json_data=schema_pb2.JSONArray(data=[ujson.dumps({i: i for i in range(3)}).encode() for i in range(6)])), + ), + schema_pb2.FieldData(type=DataType.JSON, field_name="$meta", field_id=112, + is_dynamic=True, + scalars=schema_pb2.ScalarField(json_data=schema_pb2.JSONArray(data=[ujson.dumps({str(i*100): i}).encode() for i in range(6)])), + ), + + schema_pb2.FieldData(type=DataType.FLOAT_VECTOR, field_name="float_vector_field", field_id=113, + vectors=schema_pb2.VectorField( + dim=4, + float_vector=schema_pb2.FloatArray(data=[random.random() for i in range(24)]), + ), + ), + schema_pb2.FieldData(type=DataType.BINARY_VECTOR, field_name="binary_vector_field", field_id=114, + vectors=schema_pb2.VectorField( + dim=8, + binary_vector=random.randbytes(6), + ), + ), + ] + result = schema_pb2.SearchResultData( + fields_data=fields_data, + num_queries=2, + top_k=3, + scores=[1.*i for i in range(6)], + ids=pk, + topks=[3, 3], + output_fields=['$meta'] + ) + r = SearchResult(result) + print(r[0]) + assert 2 == len(r) + assert 3 == len(r[0]) == len(r[1]) + assert {'0': 0, '1': 1, '2': 2} == r[0][0].normal_json_field + # dynamic field + assert 1 == r[0][1].entity.get('100') + + assert 0 == r[0][0].int32_field + assert 1 == r[0][1].int8_field + assert 2 == r[0][2].int16_field + assert [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] == r[0][1].int64_array_field