feat: support http model endpoints
drduhe committed Oct 18, 2023
1 parent e56f6f0 commit 1209287
Showing 12 changed files with 277 additions and 55 deletions.
36 changes: 13 additions & 23 deletions src/aws/osml/model_runner/app.py
@@ -11,15 +11,6 @@
from typing import Any, Dict, List, Optional, Tuple

import shapely.geometry.base
from aws_embedded_metrics.logger.metrics_logger import MetricsLogger
from aws_embedded_metrics.metric_scope import metric_scope
from aws_embedded_metrics.unit import Unit
from dacite import Config as dacite_Config
from dacite import from_dict
from geojson import Feature
from osgeo import gdal
from osgeo.gdal import Dataset

from aws.osml.gdal import (
GDALConfigEnv,
GDALDigitalElevationModelTileFactory,
@@ -28,8 +19,17 @@
set_gdal_default_configuration,
)
from aws.osml.photogrammetry import DigitalElevationModel, ElevationModel, SensorModel, SRTMTileSet
from aws_embedded_metrics.logger.metrics_logger import MetricsLogger
from aws_embedded_metrics.metric_scope import metric_scope
from aws_embedded_metrics.unit import Unit
from dacite import Config as dacite_Config
from dacite import from_dict
from geojson import Feature
from osgeo import gdal
from osgeo.gdal import Dataset

from .api import ImageRequest, InvalidImageRequestException, ModelInvokeMode, RegionRequest, SinkMode
from .api import ImageRequest, InvalidImageRequestException, RegionRequest, SinkMode, \
VALID_MODEL_HOSTING_OPTIONS
from .app_config import MetricLabels, ServiceConfig
from .common import (
EndpointUtils,
@@ -732,22 +732,12 @@ def validate_model_hosting(self, image_request: JobItem, metrics: MetricsLogger
:return: None
"""
# TODO: The long term goal is to support AWS provided models hosted by this service as well
# as customer provided models where we're managing the endpoints internally. For an
# initial release we can limit processing to customer managed SageMaker Model
# Endpoints hence this check. The other type options should not be advertised in the
# API but we are including the name/type structure in the API to allow expansion
# through a non-breaking API change.
if (
not image_request.model_invoke_mode
or image_request.model_invoke_mode is ModelInvokeMode.NONE
or image_request.model_invoke_mode.casefold() != "SM_ENDPOINT".casefold()
):
error = "Application only supports SageMaker Model Endpoints"
if not image_request.model_invoke_mode or image_request.model_invoke_mode not in VALID_MODEL_HOSTING_OPTIONS:
error = f"Application only supports ${VALID_MODEL_HOSTING_OPTIONS} Endpoints"
self.status_monitor.process_event(
image_request,
ImageRequestStatus.FAILED,
"Application only supports SageMaker Model Endpoints",
error,
)
if isinstance(metrics, MetricsLogger):
metrics.put_metric(MetricLabels.UNSUPPORTED_MODEL_HOST, 1, str(Unit.COUNT.value))
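The new check relies on VALID_MODEL_HOSTING_OPTIONS and ModelInvokeMode.HTTP_ENDPOINT from the api package, neither of which is defined in this diff. A minimal sketch of what those definitions might look like, inferred only from how they are used above (the member values are assumptions):

from enum import Enum


class ModelInvokeMode(str, Enum):
    # Hypothetical sketch; the real enum lives in src/aws/osml/model_runner/api and is not shown here.
    NONE = "NONE"
    SM_ENDPOINT = "SM_ENDPOINT"
    HTTP_ENDPOINT = "HTTP_ENDPOINT"


# Hypothetical sketch; only SageMaker and HTTP hosting pass validate_model_hosting.
VALID_MODEL_HOSTING_OPTIONS = [ModelInvokeMode.SM_ENDPOINT, ModelInvokeMode.HTTP_ENDPOINT]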
4 changes: 3 additions & 1 deletion src/aws/osml/model_runner/inference/__init__.py
@@ -5,6 +5,8 @@
# flake8: noqa

from .detector import Detector
from .endpoint_factory import FeatureDetectorFactory
from .feature_selection import FeatureSelector
from .feature_utils import calculate_processing_bounds, get_source_property
from .sm_endpoint_detector import SMDetector
from .http_detector import HTTPDetector
from .sm_detector import SMDetector
11 changes: 9 additions & 2 deletions src/aws/osml/model_runner/inference/detector.py
@@ -15,8 +15,15 @@ class Detector(abc.ABC):
The mechanism by which detected features are sent to their destination.
"""

def __init__(self) -> str:
return f"{self.mode.value}"
def __init__(self, endpoint: str) -> None:
"""
Endpoint Detector base class.
:param endpoint: str = the endpoint that will be invoked
"""
self.endpoint = endpoint
self.request_count = 0
self.error_count = 0

@property
@abc.abstractmethod
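With the endpoint and the request/error counters now owned by the base class, a concrete detector only needs to supply its mode (plus its feature lookup); a minimal sketch under that assumption, using a hypothetical subclass that is not part of this commit:

from aws.osml.model_runner.api import ModelInvokeMode
from aws.osml.model_runner.inference import Detector


class NoopDetector(Detector):
    # Hypothetical subclass for illustration only.
    def __init__(self) -> None:
        super().__init__(endpoint="noop")  # base class stores the endpoint and zeroes the counters

    @property
    def mode(self) -> ModelInvokeMode:
        return ModelInvokeMode.NONE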
27 changes: 27 additions & 0 deletions src/aws/osml/model_runner/inference/endpoint_builder.py
@@ -0,0 +1,27 @@
from abc import ABC, abstractmethod
from typing import Optional

from .detector import Detector


class FeatureEndpointBuilder(ABC):
"""
This is the abstract base for builders that construct Detectors for the various types of endpoints.
"""

def __init__(self) -> None:
"""
Constructor for the builder accepting required properties or formats for detectors
:return: None
"""
pass

@abstractmethod
def build(self) -> Optional[Detector]:
"""
Constructs the detector from the available information. Note that in cases where not enough information is
available to provide any solution, this method will return None.
:return: Optional[Detector] = the detector to generate features based on the provided build data
"""
36 changes: 36 additions & 0 deletions src/aws/osml/model_runner/inference/endpoint_factory.py
@@ -0,0 +1,36 @@
from typing import Dict, Optional

from aws.osml.model_runner.api import ModelInvokeMode

from .detector import Detector
from .http_detector import HTTPDetectorBuilder
from .sm_detector import SMDetectorBuilder


class FeatureDetectorFactory:
def __init__(
self, endpoint: str, endpoint_mode: ModelInvokeMode, assumed_credentials: Optional[Dict[str, str]] = None
) -> None:
"""
:param endpoint: URL of the inference model endpoint
:param endpoint_mode: the type of endpoint (HTTP, SageMaker)
:param assumed_credentials: optional credentials to use with the model
"""

self.endpoint = endpoint
self.endpoint_mode = endpoint_mode
self.assumed_credentials = assumed_credentials

def build(self) -> Optional[Detector]:
"""
:return: a feature detector based on the parameters defined during initialization
"""
if self.endpoint_mode == ModelInvokeMode.SM_ENDPOINT:
return SMDetectorBuilder(
endpoint=self.endpoint,
assumed_credentials=self.assumed_credentials,
).build()
if self.endpoint_mode == ModelInvokeMode.HTTP_ENDPOINT:
return HTTPDetectorBuilder(
endpoint=self.endpoint,
).build()
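A short usage sketch of the factory; the endpoint URL is hypothetical, and build() falls through to None for any mode it does not recognize:

from aws.osml.model_runner.api import ModelInvokeMode
from aws.osml.model_runner.inference import FeatureDetectorFactory

# Hypothetical endpoint URL; in practice this comes from the image request.
detector = FeatureDetectorFactory(
    endpoint="https://model-host.example.com/invocations",
    endpoint_mode=ModelInvokeMode.HTTP_ENDPOINT,
).build()
assert detector is not None and detector.mode == ModelInvokeMode.HTTP_ENDPOINT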
109 changes: 109 additions & 0 deletions src/aws/osml/model_runner/inference/http_detector.py
@@ -0,0 +1,109 @@
import logging
from io import BufferedReader
from json import JSONDecodeError
from typing import Optional

import geojson
import urllib3
from aws_embedded_metrics.logger.metrics_logger import MetricsLogger
from aws_embedded_metrics.metric_scope import metric_scope
from aws_embedded_metrics.unit import Unit
from geojson import FeatureCollection

from aws.osml.model_runner.api import ModelInvokeMode
from aws.osml.model_runner.app_config import MetricLabels, ServiceConfig
from aws.osml.model_runner.common import Timer

from .detector import Detector
from .endpoint_builder import FeatureEndpointBuilder
from .feature_utils import create_mock_feature_collection

logger = logging.getLogger(__name__)


class HTTPDetector(Detector):
def __init__(self, endpoint: str) -> None:
"""
An HTTP model endpoint invoking object, intended to query HTTP model endpoints.
:param endpoint: str = the full URL to invoke the model
:return: None
"""
self.http_pool = urllib3.PoolManager(cert_reqs="CERT_NONE")
super().__init__(endpoint=endpoint)

@property
def mode(self) -> ModelInvokeMode:
return ModelInvokeMode.HTTP_ENDPOINT

@metric_scope
def find_features(self, payload: BufferedReader, metrics: MetricsLogger) -> FeatureCollection:
"""
Query the established endpoint to find features based on a payload
:param payload: BufferedReader = the BufferedReader object that holds the
data that will be sent to the feature generator
:param metrics: MetricsLogger = the metrics logger object to capture the log data on the system
:return: FeatureCollection = a feature collection containing the center point of a tile
"""
retry_count = 0
logger.info("Invoking HTTP Endpoint: {}".format(self.endpoint))
if isinstance(metrics, MetricsLogger):
metrics.set_dimensions()
metrics.put_dimensions({"HTTPModelEndpoint": self.endpoint})

try:
self.request_count += 1
if isinstance(metrics, MetricsLogger):
metrics.put_metric(MetricLabels.MODEL_INVOCATION, 1, str(Unit.COUNT.value))

with Timer(
task_str="Invoke HTTP Endpoint",
metric_name=MetricLabels.ENDPOINT_LATENCY,
logger=logger,
metrics_logger=metrics,
):
# If we are not running against a real model
if self.endpoint == ServiceConfig.noop_model_name:
return create_mock_feature_collection(payload)
else:
response = self.http_pool.request(
method="POST",
url=self.endpoint,
body=payload,
)
self.request_count = 1
if isinstance(metrics, MetricsLogger):
metrics.put_metric(MetricLabels.ENDPOINT_RETRY_COUNT, retry_count, str(Unit.COUNT.value))
return geojson.loads(response.data.decode("utf-8"))
except JSONDecodeError as err:
self.error_count += 1
if isinstance(metrics, MetricsLogger):
metrics.put_metric(MetricLabels.FEATURE_DECODE, 1, str(Unit.COUNT.value))
metrics.put_metric(MetricLabels.MODEL_ERROR, 1, str(Unit.COUNT.value))
logger.error(
"Unable to decode response from model. URL: {}, Status: {}, Headers: {}, Response: {}".format(
self.endpoint, response.status, response.info(), response.data
)
)
logger.exception(err)
self.error_count += 1

# Return an empty feature collection if the process errored out
return FeatureCollection([])


class HTTPDetectorBuilder(FeatureEndpointBuilder):
def __init__(
self,
endpoint: str,
):
super().__init__()
self.endpoint = endpoint

def build(self) -> Optional[Detector]:
return HTTPDetector(
endpoint=self.endpoint,
)
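A usage sketch for the new detector; the tile path and endpoint URL are hypothetical, and find_features receives its MetricsLogger automatically through the metric_scope decorator:

from aws.osml.model_runner.inference import HTTPDetector

detector = HTTPDetector(endpoint="https://model-host.example.com/invocations")
# open(..., "rb") already returns a BufferedReader, which is what find_features expects
with open("tile.ntf", "rb") as payload:
    feature_collection = detector.find_features(payload)
print(len(feature_collection["features"]))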
46 changes: 33 additions & 13 deletions ..._runner/inference/sm_endpoint_detector.py → ...sml/model_runner/inference/sm_detector.py
100755 → 100644
@@ -1,9 +1,11 @@
# Copyright 2023 Amazon.com, Inc. or its affiliates.

# Copyright 2023 Amazon.com, Inc. or its affiliates.

import logging
from io import BufferedReader
from json import JSONDecodeError
from typing import Dict
from typing import Dict, Optional

import boto3
import geojson
@@ -18,22 +20,22 @@
from aws.osml.model_runner.common import Timer

from .detector import Detector
from .endpoint_builder import FeatureEndpointBuilder
from .feature_utils import create_mock_feature_collection

logger = logging.getLogger(__name__)


class SMDetector(Detector):
def __init__(self, model_name: str, assumed_credentials: Dict[str, str] = None) -> None:
def __init__(self, endpoint: str, assumed_credentials: Dict[str, str] = None) -> None:
"""
A sagemaker model endpoint invoking object, intended to query sagemaker endpoints.
:param model_name: str = the name of the sagemaker endpoint that will be invoked
:param endpoint: str = the name of the sagemaker endpoint that will be invoked
:param assumed_credentials: Dict[str, str] = Optional credentials to invoke the sagemaker model
:return: None
"""
super().__init__()
if assumed_credentials is not None:
# Here we will be invoking the SageMaker endpoints using an IAM role other than the
# one for this process. Use those credentials when creating the Boto3 SageMaker client.
@@ -47,13 +49,11 @@ def __init__(self, model_name: str, assumed_credentials: Dict[str, str] = None)
aws_session_token=assumed_credentials.get("SessionToken"),
)
else:
# If no invocation role is provided the assumption is that the default role for this
# If no invocation role is provided, the assumption is that the default role for this
# container will be sufficient to invoke the SageMaker endpoints. This will typically
# be the case for AWS managed models running in the same account as the model runner.
self.sm_client = boto3.client("sagemaker-runtime", config=BotoConfig.sagemaker)
self.model_name = model_name
self.request_count = 0
self.error_count = 0
super().__init__(endpoint=endpoint)

@property
def mode(self) -> ModelInvokeMode:
@@ -65,15 +65,15 @@ def find_features(self, payload: BufferedReader, metrics: MetricsLogger) -> Feat
Query the established endpoint mode to find features based on a payload
:param payload: BufferedReader = the BufferedReader object that holds the
data that will be sent to the feature generator
data that will be sent to the feature generator
:param metrics: MetricsLogger = the metrics logger object to capture the log data on the system
:return: FeatureCollection = a feature collection containing the center point of a tile
"""
logger.info("Invoking Model: {}".format(self.model_name))
logger.info("Invoking Model: {}".format(self.endpoint))
if isinstance(metrics, MetricsLogger):
metrics.set_dimensions()
metrics.put_dimensions({"ModelName": self.model_name})
metrics.put_dimensions({"ModelName": self.endpoint})

try:
self.request_count += 1
@@ -87,13 +87,13 @@ def find_features(self, payload: BufferedReader, metrics: MetricsLogger) -> Feat
metrics_logger=metrics,
):
# If we are not running against a real model
if self.model_name == ServiceConfig.noop_model_name:
if self.endpoint == ServiceConfig.noop_model_name:
# We are expecting the body of the message to contain a geojson FeatureCollection
return create_mock_feature_collection(payload)
else:
# Use the sagemaker model endpoint to invoke the model and return detection points
# as a geojson FeatureCollection
model_response = self.sm_client.invoke_endpoint(EndpointName=self.model_name, Body=payload)
model_response = self.sm_client.invoke_endpoint(EndpointName=self.endpoint, Body=payload)
retry_count = model_response.get("ResponseMetadata", {}).get("RetryAttempts", 0)
if isinstance(metrics, MetricsLogger):
metrics.put_metric(MetricLabels.ENDPOINT_RETRY_COUNT, retry_count, str(Unit.COUNT.value))
@@ -124,3 +124,23 @@ def find_features(self, payload: BufferedReader, metrics: MetricsLogger) -> Feat

# Return an empty feature collection if the process errored out
return FeatureCollection([])


class SMDetectorBuilder(FeatureEndpointBuilder):
def __init__(self, endpoint: str, assumed_credentials: Dict[str, str] = None):
"""
:param endpoint: The URL to the SageMaker endpoint
:param assumed_credentials: The credentials to use with the SageMaker endpoint
"""
super().__init__()
self.endpoint = endpoint
self.assumed_credentials = assumed_credentials

def build(self) -> Optional[Detector]:
"""
:return: a SageMaker detector based on the parameters defined during initialization
"""
return SMDetector(
endpoint=self.endpoint,
assumed_credentials=self.assumed_credentials,
)
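The SageMaker path gets an equivalent builder; a usage sketch (the endpoint name and tile path are hypothetical, and SMDetectorBuilder is imported from its module because it is not re-exported in inference/__init__.py):

from aws.osml.model_runner.inference.sm_detector import SMDetectorBuilder

# "my-object-detector" is a hypothetical SageMaker endpoint name; assumed credentials are optional.
detector = SMDetectorBuilder(endpoint="my-object-detector").build()
with open("tile.ntf", "rb") as payload:  # hypothetical tile, as in the HTTP example above
    feature_collection = detector.find_features(payload)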
Loading
