From f75e9d6c11ea7a609e50719fa324678b569e4708 Mon Sep 17 00:00:00 2001 From: Laurent Savaete Date: Tue, 21 Jun 2022 15:33:28 -0300 Subject: [PATCH] feat(taps): Add api costs hook (#704) * Add api costs hook * Correct typing hints for older pythons * Rename api to sync Co-authored-by: Edgar R. M. * Apply suggestions from code review Co-authored-by: Edgar R. M. * Rename cost methods * Add sync costs calculation test * Use a single loop for logging costs Co-authored-by: Edgar R. M. * Update tap_base.py * Add test for log_sync_costs Co-authored-by: Edgar R. M. * Add missing import Co-authored-by: Edgar R. M. Co-authored-by: Eric Boucher --- singer_sdk/streams/core.py | 15 ++++++++++ singer_sdk/streams/rest.py | 56 ++++++++++++++++++++++++++++++++++++++ singer_sdk/tap_base.py | 1 + tests/core/test_streams.py | 30 ++++++++++++++++++++ 4 files changed, 102 insertions(+) diff --git a/singer_sdk/streams/core.py b/singer_sdk/streams/core.py index 1bb5a9c3e..6b9d35f3f 100644 --- a/singer_sdk/streams/core.py +++ b/singer_sdk/streams/core.py @@ -76,6 +76,9 @@ class Stream(metaclass=abc.ABCMeta): parent_stream_type: Optional[Type["Stream"]] = None ignore_parent_replication_key: bool = False + # Internal API cost aggregator + _sync_costs: Dict[str, int] = {} + def __init__( self, tap: TapBaseClass, @@ -864,6 +867,18 @@ def _write_request_duration_log( extra_tags["context"] = context self._write_metric_log(metric=request_duration_metric, extra_tags=extra_tags) + def log_sync_costs(self) -> None: + """Log a summary of Sync costs. + + The costs are calculated via `calculate_sync_cost`. + This method can be overridden to log results in a custom + format. It is only called once at the end of the life of + the stream. + """ + if len(self._sync_costs) > 0: + msg = f"Total Sync costs for stream {self.name}: {self._sync_costs}" + self.logger.info(msg) + def _check_max_record_limit(self, record_count: int) -> None: """TODO. diff --git a/singer_sdk/streams/rest.py b/singer_sdk/streams/rest.py index 752ff334c..983abefd6 100644 --- a/singer_sdk/streams/rest.py +++ b/singer_sdk/streams/rest.py @@ -330,6 +330,7 @@ def request_records(self, context: Optional[dict]) -> Iterable[dict]: context, next_page_token=next_page_token ) resp = decorated_request(prepared_request, context) + self.update_sync_costs(prepared_request, resp, context) yield from self.parse_response(resp) previous_token = copy.deepcopy(next_page_token) next_page_token = self.get_next_page_token( @@ -343,8 +344,63 @@ def request_records(self, context: Optional[dict]) -> Iterable[dict]: # Cycle until get_next_page_token() no longer returns a value finished = not next_page_token + def update_sync_costs( + self, + request: requests.PreparedRequest, + response: requests.Response, + context: Optional[Dict], + ) -> Dict[str, int]: + """Update internal calculation of Sync costs. + + Args: + request: the Request object that was just called. + response: the `requests.Response` object + context: the context passed to the call + + Returns: + A dict of costs (for the single request) whose keys are + the "cost domains". See `calculate_sync_cost` for details. + """ + call_costs = self.calculate_sync_cost(request, response, context) + self._sync_costs = { + k: self._sync_costs.get(k, 0) + call_costs.get(k, 0) + for k in call_costs.keys() + } + return self._sync_costs + # Overridable: + def calculate_sync_cost( + self, + request: requests.PreparedRequest, + response: requests.Response, + context: Optional[Dict], + ) -> Dict[str, int]: + """Calculate the cost of the last API call made. + + This method can optionally be implemented in streams to calculate + the costs (in arbitrary units to be defined by the tap developer) + associated with a single API/network call. The request and response objects + are available in the callback, as well as the context. + + The method returns a dict where the keys are arbitrary cost dimensions, + and the values the cost along each dimension for this one call. For + instance: { "rest": 0, "graphql": 42 } for a call to github's graphql API. + All keys should be present in the dict. + + This method can be overridden by tap streams. By default it won't do + anything. + + Args: + request: the API Request object that was just called. + response: the `requests.Response` object + context: the context passed to the call + + Returns: + A dict of accumulated costs whose keys are the "cost domains". + """ + return {} + def prepare_request_payload( self, context: Optional[dict], next_page_token: Optional[_TToken] ) -> Optional[dict]: diff --git a/singer_sdk/tap_base.py b/singer_sdk/tap_base.py index 5e9d814f4..8f4f59e27 100644 --- a/singer_sdk/tap_base.py +++ b/singer_sdk/tap_base.py @@ -378,6 +378,7 @@ def sync_all(self) -> None: stream.sync() stream.finalize_state_progress_markers() + stream.log_sync_costs() # Command Line Execution diff --git a/tests/core/test_streams.py b/tests/core/test_streams.py index 8e1d44030..2b87dd75b 100644 --- a/tests/core/test_streams.py +++ b/tests/core/test_streams.py @@ -1,5 +1,6 @@ """Stream tests.""" +import logging from typing import Any, Dict, Iterable, List, Optional, cast import pendulum @@ -357,3 +358,32 @@ def test_cached_jsonpath(): # cached objects should point to the same memory location assert recompiled is compiled + + +def test_sync_costs_calculation(tap: SimpleTestTap, caplog): + """Test sync costs are added up correctly.""" + fake_request = requests.PreparedRequest() + fake_response = requests.Response() + + stream = RestTestStream(tap) + + def calculate_test_cost( + request: requests.PreparedRequest, + response: requests.Response, + context: Optional[Dict], + ): + return {"dim1": 1, "dim2": 2} + + stream.calculate_sync_cost = calculate_test_cost + stream.update_sync_costs(fake_request, fake_response, None) + stream.update_sync_costs(fake_request, fake_response, None) + assert stream._sync_costs == {"dim1": 2, "dim2": 4} + + with caplog.at_level(logging.INFO, logger=tap.name): + stream.log_sync_costs() + + assert len(caplog.records) == 1 + + for record in caplog.records: + assert record.levelname == "INFO" + assert f"Total Sync costs for stream {stream.name}" in record.message