Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support mvcc and break-down-continue for iterator(#2278) #2279

Merged
merged 1 commit into from
Oct 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 9 additions & 8 deletions examples/iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,8 @@
logging.getLogger().addHandler(console_handler) # Attach the handler to the root logger



def re_create_collection(skip_data_period: bool):
if not skip_data_period:
def re_create_collection(prepare_new_data: bool):
if prepare_new_data:
if utility.has_collection(COLLECTION_NAME) and CLEAR_EXIST:
utility.drop_collection(COLLECTION_NAME)
print(f"dropped existed collection{COLLECTION_NAME}")
Expand Down Expand Up @@ -118,7 +117,8 @@ def query_iterate_collection_no_offset(collection):

query_iterator = collection.query_iterator(expr=expr, output_fields=[USER_ID, AGE],
offset=0, batch_size=5, consistency_level=CONSISTENCY_LEVEL,
reduce_stop_for_best="false", print_iterator_cursor=True)
reduce_stop_for_best="false", print_iterator_cursor=False,
iterator_cp_file="/tmp/it_cp")
no_best_ids: set = set({})
page_idx = 0
while True:
Expand All @@ -136,7 +136,8 @@ def query_iterate_collection_no_offset(collection):
print("best---------------------------")
query_iterator = collection.query_iterator(expr=expr, output_fields=[USER_ID, AGE],
offset=0, batch_size=5, consistency_level=CONSISTENCY_LEVEL,
reduce_stop_for_best="true", print_iterator_cursor=True)
reduce_stop_for_best="true", print_iterator_cursor=False, iterator_cp_file="/tmp/it_cp")

best_ids: set = set({})
page_idx = 0
while True:
Expand Down Expand Up @@ -239,10 +240,10 @@ def search_iterator_collection_with_limit(collection):


def main():
skip_data_period = True
prepare_new_data = True
connections.connect("default", host=HOST, port=PORT)
collection = re_create_collection(skip_data_period)
if not skip_data_period:
collection = re_create_collection(prepare_new_data)
if prepare_new_data:
collection = prepare_data(collection)
query_iterate_collection_no_offset(collection)
query_iterate_collection_with_offset(collection)
Expand Down
6 changes: 5 additions & 1 deletion pymilvus/client/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,7 @@ def __init__(
res: schema_pb2.SearchResultData,
round_decimal: Optional[int] = None,
status: Optional[common_pb2.Status] = None,
session_ts: Optional[int] = 0,
):
self._nq = res.num_queries
all_topks = res.topks
Expand Down Expand Up @@ -441,9 +442,12 @@ def __init__(
Hits(topk, all_pks[start:end], all_scores[start:end], nq_th_fields, output_fields)
)
nq_thres += topk

self._session_ts = session_ts
super().__init__(data)

def get_session_ts(self):
return self._session_ts

def get_fields_by_range(
self, start: int, end: int, all_fields_data: List[schema_pb2.FieldData]
) -> Dict[str, Tuple[List[Any], schema_pb2.FieldData]]:
Expand Down
3 changes: 3 additions & 0 deletions pymilvus/client/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@
RANK_GROUP_SCORER = "rank_group_scorer"
GROUP_STRICT_SIZE = "group_strict_size"
ITERATOR_FIELD = "iterator"
ITERATOR_SESSION_TS_FIELD = "iterator_session_ts"
PAGE_RETAIN_ORDER_FIELD = "page_retain_order"

RANKER_TYPE_RRF = "rrf"
RANKER_TYPE_WEIGHTED = "weighted"

GUARANTEE_TIMESTAMP = "guarantee_timestamp"
14 changes: 11 additions & 3 deletions pymilvus/client/grpc_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
is_legal_host,
is_legal_port,
)
from .constants import ITERATOR_SESSION_TS_FIELD
from .prepare import Prepare
from .types import (
BulkInsertState,
Expand Down Expand Up @@ -733,8 +734,12 @@ def _execute_search(
response = self._stub.Search(request, timeout=timeout)
check_status(response.status)
round_decimal = kwargs.get("round_decimal", -1)
return SearchResult(response.results, round_decimal, status=response.status)

return SearchResult(
response.results,
round_decimal,
status=response.status,
session_ts=response.session_ts,
)
except Exception as e:
if kwargs.get("_async", False):
return SearchFuture(None, None, e)
Expand Down Expand Up @@ -1554,7 +1559,10 @@ def query(
response.fields_data, index, dynamic_fields
)
results.append(entity_row_data)
return ExtraList(results, extra=get_cost_extra(response.status))

extra_dict = get_cost_extra(response.status)
extra_dict[ITERATOR_SESSION_TS_FIELD] = response.session_ts
return ExtraList(results, extra=extra_dict)

@retry_on_rpc_failure()
def load_balance(
Expand Down
15 changes: 9 additions & 6 deletions pymilvus/client/ts_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from pymilvus.grpc_gen import common_pb2

from .constants import BOUNDED_TS, EVENTUALLY_TS
from .constants import BOUNDED_TS, EVENTUALLY_TS, GUARANTEE_TIMESTAMP, ITERATOR_FIELD
from .singleton_utils import Singleton
from .types import get_consistency_level
from .utils import hybridts_to_unixtime
Expand Down Expand Up @@ -75,26 +75,29 @@ def get_bounded_ts():


def construct_guarantee_ts(collection_name: str, kwargs: Dict):
if kwargs.get(ITERATOR_FIELD) is not None:
return True

consistency_level = kwargs.get("consistency_level")
use_default = consistency_level is None
if use_default:
# In case the default consistency is Customized or Session,
# we set guarantee_timestamp to the cached mutation ts or 1
kwargs["guarantee_timestamp"] = get_collection_ts(collection_name) or get_eventually_ts()
kwargs[GUARANTEE_TIMESTAMP] = get_collection_ts(collection_name) or get_eventually_ts()
return True
consistency_level = get_consistency_level(consistency_level)
kwargs["consistency_level"] = consistency_level
if consistency_level == ConsistencyLevel.Strong:
# Milvus will assign the newest ts.
kwargs["guarantee_timestamp"] = 0
kwargs[GUARANTEE_TIMESTAMP] = 0
elif consistency_level == ConsistencyLevel.Session:
# Using the last write ts of the collection.
# TODO: get a timestamp from server?
kwargs["guarantee_timestamp"] = get_collection_ts(collection_name) or get_eventually_ts()
kwargs[GUARANTEE_TIMESTAMP] = get_collection_ts(collection_name) or get_eventually_ts()
elif consistency_level == ConsistencyLevel.Bounded:
# Milvus will assign ts according to the server timestamp and a configured time interval
kwargs["guarantee_timestamp"] = get_bounded_ts()
kwargs[GUARANTEE_TIMESTAMP] = get_bounded_ts()
else:
# Users customize the consistency level, no modification on `guarantee_timestamp`.
kwargs.setdefault("guarantee_timestamp", get_eventually_ts())
kwargs.setdefault(GUARANTEE_TIMESTAMP, get_eventually_ts())
return use_default
2 changes: 1 addition & 1 deletion pymilvus/grpc_gen/milvus-proto
528 changes: 264 additions & 264 deletions pymilvus/grpc_gen/milvus_pb2.py

Large diffs are not rendered by default.

12 changes: 8 additions & 4 deletions pymilvus/grpc_gen/milvus_pb2.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -791,14 +791,16 @@ class Hits(_message.Message):
def __init__(self, IDs: _Optional[_Iterable[int]] = ..., row_data: _Optional[_Iterable[bytes]] = ..., scores: _Optional[_Iterable[float]] = ...) -> None: ...

class SearchResults(_message.Message):
__slots__ = ("status", "results", "collection_name")
__slots__ = ("status", "results", "collection_name", "session_ts")
STATUS_FIELD_NUMBER: _ClassVar[int]
RESULTS_FIELD_NUMBER: _ClassVar[int]
COLLECTION_NAME_FIELD_NUMBER: _ClassVar[int]
SESSION_TS_FIELD_NUMBER: _ClassVar[int]
status: _common_pb2.Status
results: _schema_pb2.SearchResultData
collection_name: str
def __init__(self, status: _Optional[_Union[_common_pb2.Status, _Mapping]] = ..., results: _Optional[_Union[_schema_pb2.SearchResultData, _Mapping]] = ..., collection_name: _Optional[str] = ...) -> None: ...
session_ts: int
def __init__(self, status: _Optional[_Union[_common_pb2.Status, _Mapping]] = ..., results: _Optional[_Union[_schema_pb2.SearchResultData, _Mapping]] = ..., collection_name: _Optional[str] = ..., session_ts: _Optional[int] = ...) -> None: ...

class HybridSearchRequest(_message.Message):
__slots__ = ("base", "db_name", "collection_name", "partition_names", "requests", "rank_params", "travel_timestamp", "guarantee_timestamp", "not_return_all_meta", "output_fields", "consistency_level", "use_default_consistency")
Expand Down Expand Up @@ -920,16 +922,18 @@ class QueryRequest(_message.Message):
def __init__(self, base: _Optional[_Union[_common_pb2.MsgBase, _Mapping]] = ..., db_name: _Optional[str] = ..., collection_name: _Optional[str] = ..., expr: _Optional[str] = ..., output_fields: _Optional[_Iterable[str]] = ..., partition_names: _Optional[_Iterable[str]] = ..., travel_timestamp: _Optional[int] = ..., guarantee_timestamp: _Optional[int] = ..., query_params: _Optional[_Iterable[_Union[_common_pb2.KeyValuePair, _Mapping]]] = ..., not_return_all_meta: bool = ..., consistency_level: _Optional[_Union[_common_pb2.ConsistencyLevel, str]] = ..., use_default_consistency: bool = ...) -> None: ...

class QueryResults(_message.Message):
__slots__ = ("status", "fields_data", "collection_name", "output_fields")
__slots__ = ("status", "fields_data", "collection_name", "output_fields", "session_ts")
STATUS_FIELD_NUMBER: _ClassVar[int]
FIELDS_DATA_FIELD_NUMBER: _ClassVar[int]
COLLECTION_NAME_FIELD_NUMBER: _ClassVar[int]
OUTPUT_FIELDS_FIELD_NUMBER: _ClassVar[int]
SESSION_TS_FIELD_NUMBER: _ClassVar[int]
status: _common_pb2.Status
fields_data: _containers.RepeatedCompositeFieldContainer[_schema_pb2.FieldData]
collection_name: str
output_fields: _containers.RepeatedScalarFieldContainer[str]
def __init__(self, status: _Optional[_Union[_common_pb2.Status, _Mapping]] = ..., fields_data: _Optional[_Iterable[_Union[_schema_pb2.FieldData, _Mapping]]] = ..., collection_name: _Optional[str] = ..., output_fields: _Optional[_Iterable[str]] = ...) -> None: ...
session_ts: int
def __init__(self, status: _Optional[_Union[_common_pb2.Status, _Mapping]] = ..., fields_data: _Optional[_Iterable[_Union[_schema_pb2.FieldData, _Mapping]]] = ..., collection_name: _Optional[str] = ..., output_fields: _Optional[_Iterable[str]] = ..., session_ts: _Optional[int] = ...) -> None: ...

class VectorIDs(_message.Message):
__slots__ = ("collection_name", "field_name", "id_array", "partition_names")
Expand Down
Loading