Skip to content

Commit

Permalink
remove async keyword from changeFeed query in aio package
Browse files Browse the repository at this point in the history
  • Loading branch information
annie-mac committed Aug 18, 2024
1 parent 2950e20 commit 7a1a1eb
Show file tree
Hide file tree
Showing 8 changed files with 281 additions and 143 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,13 @@

"""Iterable change feed results in the Azure Cosmos database service.
"""

from azure.core.async_paging import AsyncPageIterator

from azure.cosmos import PartitionKey
from azure.cosmos._change_feed.aio.change_feed_fetcher import ChangeFeedFetcherV1, ChangeFeedFetcherV2
from azure.cosmos._change_feed.aio.change_feed_state import ChangeFeedStateV1, ChangeFeedState
from azure.cosmos._utils import is_base64_encoded
from azure.cosmos._utils import is_base64_encoded, is_key_exists_and_not_none


class ChangeFeedIterable(AsyncPageIterator):
Expand Down Expand Up @@ -57,40 +59,30 @@ def __init__(
self._options = options
self._fetch_function = fetch_function
self._collection_link = collection_link
self._change_feed_fetcher = None

change_feed_state = self._options.get("changeFeedState")
if not change_feed_state:
raise ValueError("Missing changeFeedState in feed options")
if not is_key_exists_and_not_none(self._options, "changeFeedStateContext"):
raise ValueError("Missing changeFeedStateContext in feed options")

if isinstance(change_feed_state, ChangeFeedStateV1):
if continuation_token:
if is_base64_encoded(continuation_token):
raise ValueError("Incompatible continuation token")
else:
change_feed_state.apply_server_response_continuation(continuation_token)
change_feed_state_context = self._options.pop("changeFeedStateContext")

self._change_feed_fetcher = ChangeFeedFetcherV1(
self._client,
self._collection_link,
self._options,
fetch_function
)
else:
if continuation_token:
if not is_base64_encoded(continuation_token):
raise ValueError("Incompatible continuation token")
continuation = continuation_token if continuation_token is not None else change_feed_state_context.pop("continuation", None)

effective_change_feed_context = {"continuationFeedRange": continuation_token}
effective_change_feed_state = ChangeFeedState.from_json(change_feed_state.container_rid, effective_change_feed_context)
# replace with the effective change feed state
self._options["continuationFeedRange"] = effective_change_feed_state
# Analyze and validate the continuation token.
# Two types of continuation token are currently supported:
# v1 version: the continuation token is just the _etag,
# which is returned when the customer uses partition_key_range_id;
# this path is being deprecated and does not support split/merge.
# v2 version: the continuation token is a base64-encoded composite token that includes the full change feed state.
if continuation is not None:
if is_base64_encoded(continuation):
change_feed_state_context["continuationFeedRange"] = continuation
else:
change_feed_state_context["continuationPkRangeId"] = continuation

self._validate_change_feed_state_context(change_feed_state_context)
self._options["changeFeedStateContext"] = change_feed_state_context

self._change_feed_fetcher = ChangeFeedFetcherV2(
self._client,
self._collection_link,
self._options,
fetch_function
)
super(ChangeFeedIterable, self).__init__(self._fetch_next, self._unpack, continuation_token=continuation_token)

async def _unpack(self, block):
Expand All @@ -112,7 +104,59 @@ async def _fetch_next(self, *args): # pylint: disable=unused-argument
:return: List of results.
:rtype: list
"""
if self._change_feed_fetcher is None:
await self._initialize_change_feed_fetcher()

block = await self._change_feed_fetcher.fetch_next_block()
if not block:
raise StopAsyncIteration
return block

async def _initialize_change_feed_fetcher(self):
    """Build the change feed state and the matching fetcher.

    Pops ``changeFeedStateContext`` from the feed options, awaits the
    ``containerProperties`` entry (and the ``partitionKey`` entry when it is
    present and not None), builds a ``ChangeFeedState`` from the resolved
    context, stores it in the options under ``changeFeedState`` and assigns
    ``self._change_feed_fetcher``.
    """
    state_context = self._options.pop("changeFeedStateContext")

    # The context stores awaitables for these entries; resolve them here.
    container_properties = await state_context.pop("containerProperties")
    if is_key_exists_and_not_none(state_context, "partitionKey"):
        state_context["partitionKey"] = await state_context.pop("partitionKey")

    pk_properties = container_properties.get("partitionKey")
    partition_key_definition = PartitionKey(
        path=pk_properties["paths"],
        kind=pk_properties["kind"])

    change_feed_state = ChangeFeedState.from_json(
        self._collection_link,
        container_properties["_rid"],
        partition_key_definition,
        state_context)
    self._options["changeFeedState"] = change_feed_state

    # A V1 state gets the V1 fetcher; any other state gets the V2 fetcher.
    if isinstance(change_feed_state, ChangeFeedStateV1):
        fetcher_type = ChangeFeedFetcherV1
    else:
        fetcher_type = ChangeFeedFetcherV2

    self._change_feed_fetcher = fetcher_type(
        self._client,
        self._collection_link,
        self._options,
        self._fetch_function)

# NOTE(review): ``any`` in the annotation below is the builtin function, not
# ``typing.Any`` - confirm and switch once ``Any`` is imported in this module.
def _validate_change_feed_state_context(self, change_feed_state_context: dict[str, any]) -> None:
    """Reject incompatible combinations of change feed parameters.

    :param change_feed_state_context: The change feed parameters to validate.
    :raises ValueError: If a v1 continuation is combined with ``feedRange``,
        or if, without any continuation, more than one of
        ``partitionKeyRangeId``, ``partitionKey`` and ``feedRange`` is set.
    """
    if is_key_exists_and_not_none(change_feed_state_context, "continuationPkRangeId"):
        # A v1-format continuation token cannot be combined with a feed range.
        if is_key_exists_and_not_none(change_feed_state_context, "feedRange"):
            raise ValueError("feed_range and continuation are incompatible")
        return

    if is_key_exists_and_not_none(change_feed_state_context, "continuationFeedRange"):
        # A v2-format continuation token carries the full change feed state,
        # so any other parameters (even incompatible ones) are ignored.
        return

    # No continuation was passed: at most one scoping parameter may be set.
    populated_keys = [
        key for key in ("partitionKeyRangeId", "partitionKey", "feedRange")
        if key in change_feed_state_context and change_feed_state_context[key] is not None
    ]
    if len(populated_keys) > 1:
        raise ValueError(
            "partition_key_range_id, partition_key, feed_range are exclusive parameters, please only set one of them")


Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,12 @@
from abc import ABC, abstractmethod
from typing import Optional, Union, List, Any

from azure.cosmos import http_constants
from azure.cosmos import http_constants, PartitionKey
from azure.cosmos._change_feed.aio.change_feed_start_from import ChangeFeedStartFromETagAndFeedRange, \
ChangeFeedStartFromInternal
from azure.cosmos._change_feed.aio.composite_continuation_token import CompositeContinuationToken
from azure.cosmos._change_feed.aio.feed_range_composite_continuation_token import FeedRangeCompositeContinuation
from azure.cosmos._change_feed.feed_range import FeedRangeEpk, FeedRangePartitionKey, FeedRange
from azure.cosmos._routing.aio.routing_map_provider import SmartRoutingMapProvider
from azure.cosmos._routing.routing_range import Range
from azure.cosmos._utils import is_key_exists_and_not_none
Expand All @@ -49,15 +50,22 @@ def populate_feed_options(self, feed_options: dict[str, any]) -> None:
pass

@abstractmethod
async def populate_request_headers(self, routing_provider: SmartRoutingMapProvider, request_headers: dict[str, any]) -> None:
async def populate_request_headers(
self,
routing_provider: SmartRoutingMapProvider,
request_headers: dict[str, any]) -> None:
pass

@abstractmethod
def apply_server_response_continuation(self, continuation: str) -> None:
    """Apply a continuation value to this change feed state.

    Declared abstract with no behavior of its own.
    NOTE(review): presumably ``continuation`` is the continuation returned
    by the service for the last response - confirm against the overrides.

    :param continuation: The continuation string to apply.
    """

@staticmethod
def from_json(container_link: str, container_rid: str, data: dict[str, Any]):
def from_json(
container_link: str,
container_rid: str,
partition_key_definition: PartitionKey,
data: dict[str, Any]):
if is_key_exists_and_not_none(data, "partitionKeyRangeId") or is_key_exists_and_not_none(data, "continuationPkRangeId"):
return ChangeFeedStateV1.from_json(container_link, container_rid, data)
else:
Expand All @@ -69,11 +77,11 @@ def from_json(container_link: str, container_rid: str, data: dict[str, Any]):
if version is None:
raise ValueError("Invalid base64 encoded continuation string [Missing version]")
elif version == "V2":
return ChangeFeedStateV2.from_continuation(container_link, container_rid, continuation_json)
return ChangeFeedStateV2.from_continuation(container_link, container_rid, partition_key_definition, continuation_json)
else:
raise ValueError("Invalid base64 encoded continuation string [Invalid version]")
# when there is no continuation token, by default construct ChangeFeedStateV2
return ChangeFeedStateV2.from_initial_state(container_link, container_rid, data)
return ChangeFeedStateV2.from_initial_state(container_link, container_rid, partition_key_definition, data)

class ChangeFeedStateV1(ChangeFeedState):
"""Change feed state v1 implementation. This is used when partition key range id is used or the continuation is just simple _etag
Expand Down Expand Up @@ -110,7 +118,10 @@ def from_json(cls, container_link: str, container_rid: str, data: dict[str, Any]
data.get("continuationPkRangeId")
)

async def populate_request_headers(self, routing_provider: SmartRoutingMapProvider, headers: dict[str, Any]) -> None:
async def populate_request_headers(
self,
routing_provider: SmartRoutingMapProvider,
headers: dict[str, Any]) -> None:
headers[http_constants.HttpHeaders.AIM] = http_constants.HttpHeaders.IncrementalFeedHeaderValue

# When a merge happens, the child partition will contain documents ordered by LSN but the _ts/creation time
Expand Down Expand Up @@ -140,7 +151,8 @@ def __init__(
self,
container_link: str,
container_rid: str,
feed_range: Range,
partition_key_definition: PartitionKey,
feed_range: FeedRange,
change_feed_start_from: ChangeFeedStartFromInternal,
continuation: Optional[FeedRangeCompositeContinuation] = None):

Expand All @@ -151,7 +163,9 @@ def __init__(
self._continuation = continuation
if self._continuation is None:
composite_continuation_token_queue = collections.deque()
composite_continuation_token_queue.append(CompositeContinuationToken(self._feed_range, None))
composite_continuation_token_queue.append(CompositeContinuationToken(
self._feed_range.get_normalized_range(partition_key_definition),
None))
self._continuation =\
FeedRangeCompositeContinuation(self._container_rid, self._feed_range, composite_continuation_token_queue)

Expand All @@ -168,7 +182,10 @@ def to_dict(self) -> dict[str, Any]:
self.continuation_property_name: self._continuation.to_dict()
}

async def populate_request_headers(self, routing_provider: SmartRoutingMapProvider, headers: dict[str, any]) -> None:
async def populate_request_headers(
self,
routing_provider: SmartRoutingMapProvider,
headers: dict[str, any]) -> None:
headers[http_constants.HttpHeaders.AIM] = http_constants.HttpHeaders.IncrementalFeedHeaderValue

# When a merge happens, the child partition will contain documents ordered by LSN but the _ts/creation time
Expand Down Expand Up @@ -224,6 +241,7 @@ def from_continuation(
cls,
container_link: str,
container_rid: str,
partition_key_definition: PartitionKey,
continuation_json: dict[str, Any]) -> 'ChangeFeedStateV2':

container_rid_from_continuation = continuation_json.get(ChangeFeedStateV2.container_rid_property_name)
Expand All @@ -244,6 +262,7 @@ def from_continuation(
return ChangeFeedStateV2(
container_link=container_link,
container_rid=container_rid,
partition_key_definition=partition_key_definition,
feed_range=continuation.feed_range,
change_feed_start_from=change_feed_start_from,
continuation=continuation)
Expand All @@ -253,26 +272,29 @@ def from_initial_state(
cls,
container_link: str,
collection_rid: str,
partition_key_definition: PartitionKey,
data: dict[str, Any]) -> 'ChangeFeedStateV2':

if is_key_exists_and_not_none(data, "feedRange"):
feed_range_str = base64.b64decode(data["feedRange"]).decode('utf-8')
feed_range_json = json.loads(feed_range_str)
feed_range = Range.ParseFromDict(feed_range_json)
elif is_key_exists_and_not_none(data, "partitionKeyFeedRange"):
feed_range = data["partitionKeyFeedRange"]
feed_range = FeedRangeEpk(Range.ParseFromDict(feed_range_json))
elif is_key_exists_and_not_none(data, "partitionKey"):
feed_range = FeedRangePartitionKey(data["partitionKey"])
else:
# default to full range
feed_range = Range(
"",
"FF",
True,
False)
feed_range = FeedRangeEpk(
Range(
"",
"FF",
True,
False))

change_feed_start_from = ChangeFeedStartFromInternal.from_start_time(data.get("startTime"))
return cls(
container_link=container_link,
container_rid=collection_rid,
partition_key_definition=partition_key_definition,
feed_range=feed_range,
change_feed_start_from=change_feed_start_from,
continuation=None)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,20 +27,21 @@
from typing import Any

from azure.cosmos._change_feed.aio.composite_continuation_token import CompositeContinuationToken
from azure.cosmos._change_feed.feed_range import FeedRange, FeedRangeEpk, FeedRangePartitionKey
from azure.cosmos._routing.aio.routing_map_provider import SmartRoutingMapProvider
from azure.cosmos._routing.routing_range import Range
from azure.cosmos._utils import is_key_exists_and_not_none


class FeedRangeCompositeContinuation(object):
_version_property_name = "V"
_container_rid_property_name = "Rid"
_continuation_property_name = "Continuation"
_feed_range_property_name = "Range"
_version_property_name = "v"
_container_rid_property_name = "rid"
_continuation_property_name = "continuation"

def __init__(
self,
container_rid: str,
feed_range: Range,
feed_range: FeedRange,
continuation: collections.deque[CompositeContinuationToken]):
if container_rid is None:
raise ValueError("container_rid is missing")
Expand All @@ -55,38 +56,49 @@ def __init__(
def current_token(self):
return self._current_token

def get_feed_range(self) -> FeedRange:
    """Return the feed range for this continuation.

    When the stored feed range is a ``FeedRangeEpk``, a new ``FeedRangeEpk``
    wrapping the current token's feed range is returned; otherwise the stored
    feed range is returned as-is.
    """
    if not isinstance(self._feed_range, FeedRangeEpk):
        return self._feed_range
    return FeedRangeEpk(self.current_token.feed_range)

def to_dict(self) -> dict[str, Any]:
return {
self._version_property_name: "v1", #TODO: should this start from v2
json_data = {
self._version_property_name: "v2",
self._container_rid_property_name: self._container_rid,
self._continuation_property_name: [childToken.to_dict() for childToken in self._continuation],
self._feed_range_property_name: self._feed_range.to_dict()
}

json_data.update(self._feed_range.to_dict())
return json_data

@classmethod
def from_json(cls, data) -> 'FeedRangeCompositeContinuation':
version = data.get(cls._version_property_name)
if version is None:
raise ValueError(f"Invalid feed range composite continuation token [Missing {cls._version_property_name}]")
if version != "v1":
if version != "v2":
raise ValueError("Invalid feed range composite continuation token [Invalid version]")

container_rid = data.get(cls._container_rid_property_name)
if container_rid is None:
raise ValueError(f"Invalid feed range composite continuation token [Missing {cls._container_rid_property_name}]")

feed_range_data = data.get(cls._feed_range_property_name)
if feed_range_data is None:
raise ValueError(f"Invalid feed range composite continuation token [Missing {cls._feed_range_property_name}]")
feed_range = Range.ParseFromDict(feed_range_data)

continuation_data = data.get(cls._continuation_property_name)
if continuation_data is None:
raise ValueError(f"Invalid feed range composite continuation token [Missing {cls._continuation_property_name}]")
if not isinstance(continuation_data, list) or len(continuation_data) == 0:
raise ValueError(f"Invalid feed range composite continuation token [The {cls._continuation_property_name} must be non-empty array]")
continuation = [CompositeContinuationToken.from_json(child_range_continuation_token) for child_range_continuation_token in continuation_data]

# parsing feed range
if is_key_exists_and_not_none(data, FeedRangeEpk.type_property_name):
feed_range = FeedRangeEpk.from_json({ FeedRangeEpk.type_property_name: data[FeedRangeEpk.type_property_name] })
elif is_key_exists_and_not_none(data, FeedRangePartitionKey.type_property_name):
feed_range = FeedRangePartitionKey.from_json({ FeedRangePartitionKey.type_property_name: data[FeedRangePartitionKey.type_property_name] })
else:
raise ValueError("Invalid feed range composite continuation token [Missing feed range scope]")

return cls(container_rid=container_rid, feed_range=feed_range, continuation=deque(continuation))

async def handle_feed_range_gone(self, routing_provider: SmartRoutingMapProvider, collection_link: str) -> None:
Expand Down Expand Up @@ -130,5 +142,5 @@ def apply_not_modified_response(self) -> None:
self._initial_no_result_range = self._current_token.feed_range

@property
def feed_range(self) -> Range:
def feed_range(self) -> FeedRange:
return self._feed_range
Loading

0 comments on commit 7a1a1eb

Please sign in to comment.