From f0338cd6749d3879f750be86d52ccd46908fc2a2 Mon Sep 17 00:00:00 2001 From: drduhe Date: Tue, 10 Oct 2023 21:35:31 -0600 Subject: [PATCH] Addressing feedback --- src/aws/osml/model_runner/app.py | 20 ++-- .../model_runner/inference/http_detector.py | 97 ++++++++++++++++--- .../model_runner/tile_worker/tile_worker.py | 5 + 3 files changed, 99 insertions(+), 23 deletions(-) diff --git a/src/aws/osml/model_runner/app.py b/src/aws/osml/model_runner/app.py index 77b13dc4..dc5f29d3 100755 --- a/src/aws/osml/model_runner/app.py +++ b/src/aws/osml/model_runner/app.py @@ -11,14 +11,6 @@ from typing import Any, Dict, List, Optional, Tuple import shapely.geometry.base -from aws.osml.gdal import ( - GDALConfigEnv, - GDALDigitalElevationModelTileFactory, - get_image_extension, - load_gdal_dataset, - set_gdal_default_configuration, -) -from aws.osml.photogrammetry import DigitalElevationModel, ElevationModel, SensorModel, SRTMTileSet from aws_embedded_metrics.logger.metrics_logger import MetricsLogger from aws_embedded_metrics.metric_scope import metric_scope from aws_embedded_metrics.unit import Unit @@ -28,8 +20,16 @@ from osgeo import gdal from osgeo.gdal import Dataset -from .api import ImageRequest, InvalidImageRequestException, RegionRequest, SinkMode, \ - VALID_MODEL_HOSTING_OPTIONS +from aws.osml.gdal import ( + GDALConfigEnv, + GDALDigitalElevationModelTileFactory, + get_image_extension, + load_gdal_dataset, + set_gdal_default_configuration, +) +from aws.osml.photogrammetry import DigitalElevationModel, ElevationModel, SensorModel, SRTMTileSet + +from .api import VALID_MODEL_HOSTING_OPTIONS, ImageRequest, InvalidImageRequestException, RegionRequest, SinkMode from .app_config import MetricLabels, ServiceConfig from .common import ( EndpointUtils, diff --git a/src/aws/osml/model_runner/inference/http_detector.py b/src/aws/osml/model_runner/inference/http_detector.py index 61e8a94e..1b3daf61 100644 --- a/src/aws/osml/model_runner/inference/http_detector.py +++ b/src/aws/osml/model_runner/inference/http_detector.py @@ -9,6 +9,9 @@ from aws_embedded_metrics.metric_scope import metric_scope from aws_embedded_metrics.unit import Unit from geojson import FeatureCollection +from requests.exceptions import RetryError +from urllib3.exceptions import MaxRetryError +from urllib3.util.retry import Retry from aws.osml.model_runner.api import ModelInvokeMode from aws.osml.model_runner.app_config import MetricLabels, ServiceConfig @@ -21,16 +24,74 @@ logger = logging.getLogger(__name__) +class CountingRetry(urllib3.Retry): + def __init__(self, *args, **kwargs): + """ + Retry class implementation that counts the number of retries. + + :return: None + """ + super(CountingRetry, self).__init__(*args, **kwargs) + self.retry_counts = 0 + + def increment(self, *args, **kwargs) -> Retry: + # Call the parent's increment function + result = super(CountingRetry, self).increment(*args, **kwargs) + result.retry_counts = self.retry_counts + 1 + + return result + + @classmethod + def from_retry(cls, retry_instance: Retry) -> "CountingRetry": + """Create a CountingRetry instance from a Retry instance.""" + if isinstance(retry_instance, cls): + return retry_instance # No conversion needed if it's already a CountingRetry instance + + # Create a CountingRetry instance with the same configurations + return cls( + total=retry_instance.total, + connect=retry_instance.connect, + read=retry_instance.read, + redirect=retry_instance.redirect, + status=retry_instance.status, + other=retry_instance.other, + allowed_methods=retry_instance.allowed_methods, + status_forcelist=retry_instance.status_forcelist, + backoff_factor=retry_instance.backoff_factor, + raise_on_redirect=retry_instance.raise_on_redirect, + raise_on_status=retry_instance.raise_on_status, + history=retry_instance.history, + respect_retry_after_header=retry_instance.respect_retry_after_header, + remove_headers_on_redirect=retry_instance.remove_headers_on_redirect, + ) + + class HTTPDetector(Detector): - def __init__(self, endpoint: str) -> None: + def __init__(self, endpoint: str, name: Optional[str] = None, retry: Optional[urllib3.Retry] = None) -> None: """ - A HTTP model endpoint invoking object, intended to query sagemaker endpoints. + An HTTP model endpoint interface object, intended to query sagemaker endpoints. - :param endpoint: str = the full URL to invoke the model + :param endpoint: Full url to invoke the model + :param name: Name to give the model endpoint + :param retry: Retry policy to use when invoking the model :return: None """ - self.http_pool = urllib3.PoolManager(cert_reqs="CERT_NONE") + # Setup Retry with exponential backoff + # - We will retry for a maximum of eight times. + # - We start with a backoff of 1 second. + # - We will double the backoff after each failed retry attempt. + # - We cap the maximum backoff time to 255 seconds. + # - We can adjust these values as required. + if retry is None: + self.retry = CountingRetry(total=8, backoff_factor=1, raise_on_status=True) + else: + self.retry = CountingRetry.from_retry(retry) + self.http_pool = urllib3.PoolManager(cert_reqs="CERT_NONE", retries=self.retry) + if name: + self.name = name + else: + self.name = "http" super().__init__(endpoint=endpoint) @property @@ -42,17 +103,15 @@ def find_features(self, payload: BufferedReader, metrics: MetricsLogger) -> Feat """ Query the established endpoint mode to find features based on a payload - :param payload: BufferedReader = the BufferedReader object that holds the - data that will be sent to the feature generator - :param metrics: MetricsLogger = the metrics logger object to capture the log data on the system + :param payload: BufferedReader object that holds the data that will be sent to the feature generator + :param metrics: Metrics logger object to capture the log data on the system - :return: FeatureCollection = a feature collection containing the center point of a tile + :return: GeoJSON FeatureCollection containing the center point of a tile """ - retry_count = 0 - logger.info("Invoking HTTP Endpoint: {}".format(self.endpoint)) + logger.info("Invoking Model: {}".format(self.name)) if isinstance(metrics, MetricsLogger): metrics.set_dimensions() - metrics.put_dimensions({"HTTPModelEndpoint": self.endpoint}) + metrics.put_dimensions({"ModelName": self.name}) try: self.request_count += 1 @@ -74,10 +133,23 @@ def find_features(self, payload: BufferedReader, metrics: MetricsLogger) -> Feat url=self.endpoint, body=payload, ) - self.request_count = 1 + # get the history of retries and count them + retry_count = self.retry.retry_counts if isinstance(metrics, MetricsLogger): metrics.put_metric(MetricLabels.ENDPOINT_RETRY_COUNT, retry_count, str(Unit.COUNT.value)) return geojson.loads(response.data.decode("utf-8")) + except RetryError as err: + self.error_count += 1 + if isinstance(metrics, MetricsLogger): + metrics.put_metric(MetricLabels.MODEL_ERROR, 1, str(Unit.COUNT.value)) + logger.error("Retry failed - failed due to {}".format(err)) + logger.exception(err) + except MaxRetryError as err: + self.error_count += 1 + if isinstance(metrics, MetricsLogger): + metrics.put_metric(MetricLabels.MODEL_ERROR, 1, str(Unit.COUNT.value)) + logger.error("Max retries reached - failed due to {}".format(err.reason)) + logger.exception(err) except JSONDecodeError as err: self.error_count += 1 if isinstance(metrics, MetricsLogger): @@ -89,7 +161,6 @@ def find_features(self, payload: BufferedReader, metrics: MetricsLogger) -> Feat ) ) logger.exception(err) - self.error_count += 1 # Return an empty feature collection if the process errored out return FeatureCollection([]) diff --git a/src/aws/osml/model_runner/tile_worker/tile_worker.py b/src/aws/osml/model_runner/tile_worker/tile_worker.py index a5239899..286d751d 100755 --- a/src/aws/osml/model_runner/tile_worker/tile_worker.py +++ b/src/aws/osml/model_runner/tile_worker/tile_worker.py @@ -40,6 +40,11 @@ def run(self) -> None: if image_info is None: logging.info("All images processed. Stopping tile worker.") + logging.info( + "Feature Detector Stats: {} requests with {} errors".format( + self.feature_detector.request_count, self.feature_detector.error_count + ) + ) break try: