source-mixpanel-native: fix cohort_members OOMs #2170

Merged · 2 commits · Nov 25, 2024
@@ -2,23 +2,51 @@
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#

from typing import Any, Iterable, List, Mapping, Optional
from typing import Any, Iterable, List, Mapping, Optional, MutableMapping

import requests
from airbyte_cdk.sources.streams.core import IncrementalMixin
from airbyte_cdk.models import SyncMode
from airbyte_cdk.sources.utils.transform import TransformConfig, TypeTransformer

from .base import MixpanelStream
from .cohorts import Cohorts
from .engage import Engage


class CohortMembers(Engage):
# CohortMembers is currently a full refresh stream that uses checkpoints to flush records out of the Airbyte connector.
# This is necessary because some cohorts have enough members that the connector OOMs before it finishes reading that
# cohort's members. In the future, we could make this stream incremental by having cursor values for each cohort within
# the state and performing client-side filtering.
class CohortMembers(MixpanelStream, IncrementalMixin):
"""Return list of users grouped by cohort"""

http_method: str = "POST"
data_field: str = "results"
primary_key: str = "distinct_id"
page_size: int = 50000
_total: Any = None
_cursor_value: str = ''
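# The stream's state is a single last_seen value; records are not yet filtered against it (see the class comment above).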

@property
def cursor_field(self) -> str:
return "last_seen"

@property
def state(self) -> Mapping[str, Any]:
return {self.cursor_field: self._cursor_value}

@state.setter
def state(self, value: Mapping[str, Any]):
self._cursor_value = value[self.cursor_field]

@property
def state_checkpoint_interval(self) -> int:
return 10000
return 100

# enable automatic object mutation to align with desired schema before outputting to the destination
transformer = TypeTransformer(TransformConfig.DefaultSchemaNormalization)

def path(self, **kwargs) -> str:
return "engage"

def request_body_json(
self,
@@ -29,22 +57,63 @@ def request_body_json(
# example: {"filter_by_cohort": {"id": 1343181}}
return {"filter_by_cohort": stream_slice}

def request_params(
self, stream_state: Mapping[str, Any], stream_slice: Mapping[str, any] = None, next_page_token: Mapping[str, Any] = None
) -> MutableMapping[str, Any]:
params = super().request_params(stream_state, stream_slice, next_page_token)
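# Ask for page_size records per request; pagination values (session_id, page) from next_page_token are merged in below.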
params = {**params, "page_size": self.page_size}
if next_page_token:
params.update(next_page_token)

return params

def next_page_token(self, response: requests.Response) -> Optional[Mapping[str, Any]]:
response_json = response.json()
page_number = response_json.get("page")
total = response_json.get("total") # exists only on first page
if total:
self._total = total

if self._total and page_number is not None and self._total > self.page_size * (page_number + 1):
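# More records remain (total is reported only on the first page), so request the next page, passing along the session_id.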
return {
"session_id": response_json.get("session_id"),
"page": page_number + 1,
}
else:
self._total = None
return None

def stream_slices(
self, sync_mode, cursor_field: List[str] = None, stream_state: Mapping[str, Any] = None
) -> Iterable[Optional[Mapping[str, Any]]]:
if sync_mode == SyncMode.incremental:
self.set_cursor(cursor_field)

# full refresh is needed because even though some cohorts might already have been read
# they can still have new members added
# full refresh is needed for Cohorts because even though some cohorts might already have been read
# they can still have new members added.
cohorts = Cohorts(**self.get_stream_params()).read_records(SyncMode.full_refresh)
# A single cohort could be empty (i.e. no members), so we only check for members in non-empty cohorts.
filtered_cohorts = [cohort for cohort in cohorts if cohort["count"] > 0]

for cohort in filtered_cohorts:
yield {"id": cohort["id"]}

def process_response(self, response: requests.Response, stream_slice: Mapping[str, Any] = None, **kwargs) -> Iterable[Mapping]:
records = super().process_response(response, **kwargs)
for record in records:
record["cohort_id"] = stream_slice["id"]
yield record
for record in response.json().get(self.data_field, []):
# Format record
item = {"distinct_id": record["$distinct_id"]}
properties = record["$properties"]
for property_name in properties:
this_property_name = property_name
if property_name.startswith("$"):
# Remove leading '$' for 'reserved' mixpanel property names.
this_property_name = this_property_name[1:]
item[this_property_name] = properties[property_name]

item_cursor: str = item.get(self.cursor_field)
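# The raw last_seen value appears to carry no timezone offset, so a +00:00 suffix is appended to treat it as UTC.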
if item_cursor:
item_cursor += "+00:00"
item[self.cursor_field] = item_cursor

item["cohort_id"] = stream_slice["id"]

# Always yield every record. If/when this stream is actually made incremental,
# we will need to filter which records are yielded based on the cursor value in state.
yield item
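
The class comment above leaves per-cohort incremental syncing as future work: keep a cursor value for each cohort in state and filter records client-side. A minimal sketch of what that state shape and filter could look like, assuming last_seen remains the cursor field (the PerCohortCursor class and its method names below are hypothetical, not part of this PR or the Airbyte CDK):

from typing import Any, Dict, Mapping


class PerCohortCursor:
    """Hypothetical helper: track the newest last_seen value observed per cohort id."""

    def __init__(self) -> None:
        self._cursors: Dict[str, str] = {}

    @property
    def state(self) -> Mapping[str, Any]:
        # Shape the state as {cohort_id: last_seen}, e.g. {"1343181": "2024-11-20T10:15:00+00:00"}.
        return dict(self._cursors)

    @state.setter
    def state(self, value: Mapping[str, Any]) -> None:
        self._cursors = dict(value)

    def should_yield(self, cohort_id: int, last_seen: str) -> bool:
        # Client-side filtering: emit a record only if it is newer than the stored
        # cursor for its cohort. ISO-8601 timestamps with a fixed offset compare
        # correctly as plain strings.
        previous = self._cursors.get(str(cohort_id))
        if previous is None or last_seen > previous:
            self._cursors[str(cohort_id)] = last_seen
            return True
        return False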