Skip to content

Commit

Permalink
Addressing feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
drduhe committed Oct 18, 2023
1 parent 1209287 commit f0338cd
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 23 deletions.
20 changes: 10 additions & 10 deletions src/aws/osml/model_runner/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,6 @@
from typing import Any, Dict, List, Optional, Tuple

import shapely.geometry.base
from aws.osml.gdal import (
GDALConfigEnv,
GDALDigitalElevationModelTileFactory,
get_image_extension,
load_gdal_dataset,
set_gdal_default_configuration,
)
from aws.osml.photogrammetry import DigitalElevationModel, ElevationModel, SensorModel, SRTMTileSet
from aws_embedded_metrics.logger.metrics_logger import MetricsLogger
from aws_embedded_metrics.metric_scope import metric_scope
from aws_embedded_metrics.unit import Unit
Expand All @@ -28,8 +20,16 @@
from osgeo import gdal
from osgeo.gdal import Dataset

from .api import ImageRequest, InvalidImageRequestException, RegionRequest, SinkMode, \
VALID_MODEL_HOSTING_OPTIONS
from aws.osml.gdal import (
GDALConfigEnv,
GDALDigitalElevationModelTileFactory,
get_image_extension,
load_gdal_dataset,
set_gdal_default_configuration,
)
from aws.osml.photogrammetry import DigitalElevationModel, ElevationModel, SensorModel, SRTMTileSet

from .api import VALID_MODEL_HOSTING_OPTIONS, ImageRequest, InvalidImageRequestException, RegionRequest, SinkMode
from .app_config import MetricLabels, ServiceConfig
from .common import (
EndpointUtils,
Expand Down
97 changes: 84 additions & 13 deletions src/aws/osml/model_runner/inference/http_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
from aws_embedded_metrics.metric_scope import metric_scope
from aws_embedded_metrics.unit import Unit
from geojson import FeatureCollection
from requests.exceptions import RetryError
from urllib3.exceptions import MaxRetryError
from urllib3.util.retry import Retry

from aws.osml.model_runner.api import ModelInvokeMode
from aws.osml.model_runner.app_config import MetricLabels, ServiceConfig
Expand All @@ -21,16 +24,74 @@
logger = logging.getLogger(__name__)


class CountingRetry(urllib3.Retry):
    """
    Retry policy that also carries a running tally of increments.

    The tally is stored in ``retry_counts``. ``increment`` writes the updated
    tally onto the object it returns; it does not assign to ``retry_counts`` on
    the instance it was called on.
    """

    def __init__(self, *args, **kwargs):
        """
        Build the retry policy and start its tally at zero.

        :param args: Positional arguments forwarded unchanged to the parent Retry
        :param kwargs: Keyword arguments forwarded unchanged to the parent Retry
        :return: None
        """
        super().__init__(*args, **kwargs)
        self.retry_counts = 0

    def increment(self, *args, **kwargs) -> Retry:
        """
        Delegate to the parent ``increment`` and stamp the bumped tally on its result.

        :param args: Positional arguments forwarded unchanged to the parent increment
        :param kwargs: Keyword arguments forwarded unchanged to the parent increment
        :return: The object produced by the parent, with ``retry_counts`` set to this
                 instance's tally plus one
        """
        next_retry = super().increment(*args, **kwargs)
        # Only the returned object receives the new tally.
        next_retry.retry_counts = self.retry_counts + 1
        return next_retry

    @classmethod
    def from_retry(cls, retry_instance: Retry) -> "CountingRetry":
        """
        Return a CountingRetry equivalent of the given retry policy.

        :param retry_instance: Retry policy whose settings should be reused
        :return: ``retry_instance`` itself when it is already an instance of this
                 class, otherwise a new instance built from the settings listed below
        """
        if not isinstance(retry_instance, cls):
            # Settings carried over to the new instance; each is read from the
            # attribute of the same name and passed as the same-named keyword.
            copied_settings = (
                "total",
                "connect",
                "read",
                "redirect",
                "status",
                "other",
                "allowed_methods",
                "status_forcelist",
                "backoff_factor",
                "raise_on_redirect",
                "raise_on_status",
                "history",
                "respect_retry_after_header",
                "remove_headers_on_redirect",
            )
            return cls(**{setting: getattr(retry_instance, setting) for setting in copied_settings})

        # Already the right type, so no conversion is needed.
        return retry_instance


class HTTPDetector(Detector):
def __init__(self, endpoint: str) -> None:
def __init__(self, endpoint: str, name: Optional[str] = None, retry: Optional[urllib3.Retry] = None) -> None:
"""
A HTTP model endpoint invoking object, intended to query sagemaker endpoints.
An HTTP model endpoint interface object, intended to query sagemaker endpoints.
:param endpoint: str = the full URL to invoke the model
:param endpoint: Full url to invoke the model
:param name: Name to give the model endpoint
:param retry: Retry policy to use when invoking the model
:return: None
"""
self.http_pool = urllib3.PoolManager(cert_reqs="CERT_NONE")
# Setup Retry with exponential backoff
# - We will retry for a maximum of eight times.
# - We start with a backoff of 1 second.
# - We will double the backoff after each failed retry attempt.
# - We cap the maximum backoff time to 255 seconds.
# - We can adjust these values as required.
if retry is None:
self.retry = CountingRetry(total=8, backoff_factor=1, raise_on_status=True)
else:
self.retry = CountingRetry.from_retry(retry)
self.http_pool = urllib3.PoolManager(cert_reqs="CERT_NONE", retries=self.retry)
if name:
self.name = name
else:
self.name = "http"
super().__init__(endpoint=endpoint)

@property
Expand All @@ -42,17 +103,15 @@ def find_features(self, payload: BufferedReader, metrics: MetricsLogger) -> Feat
"""
Query the established endpoint mode to find features based on a payload
:param payload: BufferedReader = the BufferedReader object that holds the
data that will be sent to the feature generator
:param metrics: MetricsLogger = the metrics logger object to capture the log data on the system
:param payload: BufferedReader object that holds the data that will be sent to the feature generator
:param metrics: Metrics logger object to capture the log data on the system
:return: FeatureCollection = a feature collection containing the center point of a tile
:return: GeoJSON FeatureCollection containing the center point of a tile
"""
retry_count = 0
logger.info("Invoking HTTP Endpoint: {}".format(self.endpoint))
logger.info("Invoking Model: {}".format(self.name))
if isinstance(metrics, MetricsLogger):
metrics.set_dimensions()
metrics.put_dimensions({"HTTPModelEndpoint": self.endpoint})
metrics.put_dimensions({"ModelName": self.name})

try:
self.request_count += 1
Expand All @@ -74,10 +133,23 @@ def find_features(self, payload: BufferedReader, metrics: MetricsLogger) -> Feat
url=self.endpoint,
body=payload,
)
self.request_count = 1
# get the history of retries and count them
retry_count = self.retry.retry_counts
if isinstance(metrics, MetricsLogger):
metrics.put_metric(MetricLabels.ENDPOINT_RETRY_COUNT, retry_count, str(Unit.COUNT.value))
return geojson.loads(response.data.decode("utf-8"))
except RetryError as err:
self.error_count += 1
if isinstance(metrics, MetricsLogger):
metrics.put_metric(MetricLabels.MODEL_ERROR, 1, str(Unit.COUNT.value))
logger.error("Retry failed - failed due to {}".format(err))
logger.exception(err)
except MaxRetryError as err:
self.error_count += 1
if isinstance(metrics, MetricsLogger):
metrics.put_metric(MetricLabels.MODEL_ERROR, 1, str(Unit.COUNT.value))
logger.error("Max retries reached - failed due to {}".format(err.reason))
logger.exception(err)
except JSONDecodeError as err:
self.error_count += 1
if isinstance(metrics, MetricsLogger):
Expand All @@ -89,7 +161,6 @@ def find_features(self, payload: BufferedReader, metrics: MetricsLogger) -> Feat
)
)
logger.exception(err)
self.error_count += 1

# Return an empty feature collection if the process errored out
return FeatureCollection([])
Expand Down
5 changes: 5 additions & 0 deletions src/aws/osml/model_runner/tile_worker/tile_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,11 @@ def run(self) -> None:

if image_info is None:
logging.info("All images processed. Stopping tile worker.")
logging.info(
"Feature Detector Stats: {} requests with {} errors".format(
self.feature_detector.request_count, self.feature_detector.error_count
)
)
break

try:
Expand Down

0 comments on commit f0338cd

Please sign in to comment.