Skip to content

Commit

Permalink
feat(taps): Add api costs hook (#704)
Browse files Browse the repository at this point in the history
* Add api costs hook

* Correct typing hints for older pythons

* Rename api to sync

Co-authored-by: Edgar R. M. <[email protected]>

* Apply suggestions from code review

Co-authored-by: Edgar R. M. <[email protected]>

* Rename cost methods

* Add sync costs calculation test

* Use a single loop for logging costs

Co-authored-by: Edgar R. M. <[email protected]>

* Update tap_base.py

* Add test for log_sync_costs

Co-authored-by: Edgar R. M. <[email protected]>

* Add missing import

Co-authored-by: Edgar R. M. <[email protected]>
Co-authored-by: Eric Boucher <[email protected]>
  • Loading branch information
3 people authored Jun 21, 2022
1 parent 3cdb614 commit f75e9d6
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 0 deletions.
15 changes: 15 additions & 0 deletions singer_sdk/streams/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@ class Stream(metaclass=abc.ABCMeta):
parent_stream_type: Optional[Type["Stream"]] = None
ignore_parent_replication_key: bool = False

# Internal API cost aggregator
_sync_costs: Dict[str, int] = {}

def __init__(
self,
tap: TapBaseClass,
Expand Down Expand Up @@ -864,6 +867,18 @@ def _write_request_duration_log(
extra_tags["context"] = context
self._write_metric_log(metric=request_duration_metric, extra_tags=extra_tags)

def log_sync_costs(self) -> None:
"""Log a summary of Sync costs.
The costs are calculated via `calculate_sync_cost`.
This method can be overridden to log results in a custom
format. It is only called once at the end of the life of
the stream.
"""
if len(self._sync_costs) > 0:
msg = f"Total Sync costs for stream {self.name}: {self._sync_costs}"
self.logger.info(msg)

def _check_max_record_limit(self, record_count: int) -> None:
"""TODO.
Expand Down
56 changes: 56 additions & 0 deletions singer_sdk/streams/rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,7 @@ def request_records(self, context: Optional[dict]) -> Iterable[dict]:
context, next_page_token=next_page_token
)
resp = decorated_request(prepared_request, context)
self.update_sync_costs(prepared_request, resp, context)
yield from self.parse_response(resp)
previous_token = copy.deepcopy(next_page_token)
next_page_token = self.get_next_page_token(
Expand All @@ -343,8 +344,63 @@ def request_records(self, context: Optional[dict]) -> Iterable[dict]:
# Cycle until get_next_page_token() no longer returns a value
finished = not next_page_token

def update_sync_costs(
self,
request: requests.PreparedRequest,
response: requests.Response,
context: Optional[Dict],
) -> Dict[str, int]:
"""Update internal calculation of Sync costs.
Args:
request: the Request object that was just called.
response: the `requests.Response` object
context: the context passed to the call
Returns:
A dict of costs (for the single request) whose keys are
the "cost domains". See `calculate_sync_cost` for details.
"""
call_costs = self.calculate_sync_cost(request, response, context)
self._sync_costs = {
k: self._sync_costs.get(k, 0) + call_costs.get(k, 0)
for k in call_costs.keys()
}
return self._sync_costs

# Overridable:

def calculate_sync_cost(
self,
request: requests.PreparedRequest,
response: requests.Response,
context: Optional[Dict],
) -> Dict[str, int]:
"""Calculate the cost of the last API call made.
This method can optionally be implemented in streams to calculate
the costs (in arbitrary units to be defined by the tap developer)
associated with a single API/network call. The request and response objects
are available in the callback, as well as the context.
The method returns a dict where the keys are arbitrary cost dimensions,
and the values the cost along each dimension for this one call. For
instance: { "rest": 0, "graphql": 42 } for a call to github's graphql API.
All keys should be present in the dict.
This method can be overridden by tap streams. By default it won't do
anything.
Args:
request: the API Request object that was just called.
response: the `requests.Response` object
context: the context passed to the call
Returns:
A dict of accumulated costs whose keys are the "cost domains".
"""
return {}

def prepare_request_payload(
self, context: Optional[dict], next_page_token: Optional[_TToken]
) -> Optional[dict]:
Expand Down
1 change: 1 addition & 0 deletions singer_sdk/tap_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,7 @@ def sync_all(self) -> None:

stream.sync()
stream.finalize_state_progress_markers()
stream.log_sync_costs()

# Command Line Execution

Expand Down
30 changes: 30 additions & 0 deletions tests/core/test_streams.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Stream tests."""

import logging
from typing import Any, Dict, Iterable, List, Optional, cast

import pendulum
Expand Down Expand Up @@ -357,3 +358,32 @@ def test_cached_jsonpath():

# cached objects should point to the same memory location
assert recompiled is compiled


def test_sync_costs_calculation(tap: SimpleTestTap, caplog):
"""Test sync costs are added up correctly."""
fake_request = requests.PreparedRequest()
fake_response = requests.Response()

stream = RestTestStream(tap)

def calculate_test_cost(
request: requests.PreparedRequest,
response: requests.Response,
context: Optional[Dict],
):
return {"dim1": 1, "dim2": 2}

stream.calculate_sync_cost = calculate_test_cost
stream.update_sync_costs(fake_request, fake_response, None)
stream.update_sync_costs(fake_request, fake_response, None)
assert stream._sync_costs == {"dim1": 2, "dim2": 4}

with caplog.at_level(logging.INFO, logger=tap.name):
stream.log_sync_costs()

assert len(caplog.records) == 1

for record in caplog.records:
assert record.levelname == "INFO"
assert f"Total Sync costs for stream {stream.name}" in record.message

0 comments on commit f75e9d6

Please sign in to comment.