Skip to content

Commit

Permalink
[airbyte-cdk] Stream should not extract state using legacy get_update…
Browse files Browse the repository at this point in the history
…d_state if no cursor (#36342)
  • Loading branch information
brianjlai authored Mar 21, 2024
1 parent df17c85 commit 728c92c
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 3 deletions.
5 changes: 4 additions & 1 deletion airbyte-cdk/python/airbyte_cdk/sources/streams/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,10 @@ def read( # type: ignore # ignoring typing for ConnectorStateManager because o
hasattr(record_data_or_message, "type") and record_data_or_message.type == MessageType.RECORD
):
record_data = record_data_or_message if isinstance(record_data_or_message, Mapping) else record_data_or_message.record
stream_state = self.get_updated_state(stream_state, record_data)
if self.cursor_field:
# Some connectors have streams that implement get_updated_state() but do not define a cursor_field. This
# should be fixed in the stream implementation, but the CDK should also guard against it.
stream_state = self.get_updated_state(stream_state, record_data)
record_counter += 1

if sync_mode == SyncMode.incremental:
Expand Down
14 changes: 12 additions & 2 deletions airbyte-cdk/python/unit_tests/sources/streams/test_stream_read.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,15 @@ def _incremental_concurrent_stream(slice_to_partition_mapping, slice_logger, log
return stream


def _stream_with_no_cursor_field(slice_to_partition_mapping, slice_logger, logger, message_repository):
    """Build a ``_MockStream`` whose ``get_updated_state`` raises if it is ever called.

    The stream is constructed from ``slice_to_partition_mapping`` only. The
    ``slice_logger``, ``logger`` and ``message_repository`` arguments are accepted
    but unused, so this constructor shares the signature of the other stream
    constructors passed to ``pytest.param`` in this module.
    """

    def _must_not_be_called(current_stream_state: MutableMapping[str, Any], latest_record: Mapping[str, Any]) -> MutableMapping[str, Any]:
        # Any invocation is a test failure: the read path called the legacy state hook.
        raise Exception("I shouldn't be invoked by a full_refresh stream")

    stream = _MockStream(slice_to_partition_mapping)
    # Instance-level override, so the replacement is called without ``self``.
    stream.get_updated_state = _must_not_be_called
    return stream


@pytest.mark.parametrize(
"constructor",
[
Expand Down Expand Up @@ -232,9 +241,10 @@ def test_full_refresh_read_a_single_slice(constructor):
[
pytest.param(_stream, id="synchronous_reader"),
pytest.param(_concurrent_stream, id="concurrent_reader"),
pytest.param(_stream_with_no_cursor_field, id="no_cursor_field"),
],
)
def test_full_refresh_read_a_two_slices(constructor):
def test_full_refresh_read_two_slices(constructor):
# This test verifies that a concurrent stream adapted from a Stream behaves the same as the Stream object
# It is done by running the same test cases on both streams
configured_stream = ConfiguredAirbyteStream(stream=AirbyteStream(name="mock_stream", supported_sync_modes=[SyncMode.full_refresh], json_schema={}), sync_mode=SyncMode.full_refresh,destination_sync_mode=DestinationSyncMode.overwrite)
Expand All @@ -261,7 +271,7 @@ def test_full_refresh_read_a_two_slices(constructor):
]

# Temporary check to only validate the final state message for synchronous sources since it has not been implemented for concurrent yet
if constructor == _stream:
if constructor == _stream or constructor == _stream_with_no_cursor_field:
expected_records.append(
AirbyteMessage(
type=MessageType.STATE,
Expand Down
4 changes: 4 additions & 0 deletions airbyte-cdk/python/unit_tests/sources/test_abstract_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,10 @@ def read_records(self, **kwargs) -> Iterable[Mapping[str, Any]]: # type: ignore
def primary_key(self) -> Optional[Union[str, List[str], List[List[str]]]]:
    """Return the primary key of this test stream: the single field ``pk``."""
    return "pk"

@property
def cursor_field(self) -> Union[str, List[str]]:
    """Return the cursor field path of this test stream: ``["updated_at"]``."""
    return ["updated_at"]


class MockStreamWithState(MockStream):
cursor_field = "cursor"
Expand Down

0 comments on commit 728c92c

Please sign in to comment.