Skip to content

Commit

Permalink
feature: http endpoint support
Browse files Browse the repository at this point in the history
  • Loading branch information
drduhe committed Oct 26, 2023
1 parent f51ac8c commit 7da9a82
Show file tree
Hide file tree
Showing 17 changed files with 4,650 additions and 23 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/osml-model-runner-test-build.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
name: "OSML Models Build Workflow"
name: "OSML Model Runner Test Build Workflow"

on:
pull_request:
Expand Down
23 changes: 12 additions & 11 deletions bin/process_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,16 +40,17 @@

# set up a cli tool for the script using argparse
parser = argparse.ArgumentParser("process_image")
parser.add_argument("--image", help="The target image URL to process with OSML Model Runner.", type=str, default="small")
parser.add_argument("--model", help="The target model to use for object detection.", type=str, default="centerpoint")
parser.add_argument("--image", help="Target image URL to process with OSML Model Runner.", type=str, default="small")
parser.add_argument("--model", help="Target model to use for object detection.", type=str, default="centerpoint")
parser.add_argument("--skip_integ", help="Whether or not to compare image with known results.", action="store_true")
parser.add_argument("--tile_format", help="The target tile format to use for tiling.", type=str)
parser.add_argument("--tile_compression", help="The compression used for the target image.", type=str)
parser.add_argument("--tile_size", help="The tile size to split the image into for model processing.", type=str)
parser.add_argument("--tile_overlap", help="The tile overlap to consider when processing regions.", type=str)
parser.add_argument("--feature_selection_options", help="The feature selection options JSON string.", type=str)
parser.add_argument("--region", help="The AWS region OSML is deployed to.", type=str, default=default_region)
parser.add_argument("--account", help="The AWS account OSML is deployed to.", type=str, default=default_account)
parser.add_argument("--tile_format", help="Target tile format to use for tiling.", type=str)
parser.add_argument("--tile_compression", help="Compression used for the target image.", type=str)
parser.add_argument("--tile_size", help="Tile size to split the image into for model processing.", type=str)
parser.add_argument("--tile_overlap", help="Tile overlap to consider when processing regions.", type=str)
parser.add_argument("--feature_selection_options", help="Feature selection options JSON string.", type=str)
parser.add_argument("--region", help="AWS region OSML is deployed to.", type=str, default=default_region)
parser.add_argument("--account", help="AWS account OSML is deployed to.", type=str, default=default_account)
parser.add_argument("--endpoint_type", help="Type of model endpoint to test, sm or http.", type=str, default="sm")
args = parser.parse_args()

# standard test images deployed by CDK
Expand All @@ -67,7 +68,7 @@
"sicd_interferometric_hh_ntf": f"s3://{image_bucket}/sicd-interferometric-hh.nitf",
}

# call into root directory of this package so that we can run this script from anywhere.
# call into the root directory of this package so that we can run this script from anywhere.
os.chdir(os.path.join(os.path.dirname(os.path.abspath(__file__)), ".."))

# set the python path to include the project source
Expand Down Expand Up @@ -95,6 +96,6 @@
test = "src/aws/osml/process_image/test_process_image.py"
else:
# run integration test against known results
test = f"src/aws/osml/integ/{args.model}/test_{args.model}_model.py"
test = f"src/aws/osml/integ/{args.endpoint_type}_{args.model}/test_{args.endpoint_type}_{args.model}_model.py"

subprocess.run(["python3", "-m", "pytest", "-o", "log_cli=true", "-vv", test])
4 changes: 2 additions & 2 deletions environment-py311.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,5 @@ channels:
- conda-forge
dependencies:
- conda-forge::python=3.11
- conda-forge::gdal=3.7.0
- conda-forge::proj=9.2.1
- conda-forge::gdal=3.7.2
- conda-forge::proj=9.3.0
6 changes: 3 additions & 3 deletions environment.yml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name: osml_model_runner_test
name: osml_models
channels:
- conda-forge
dependencies:
- conda-forge::gdal=3.7.0
- conda-forge::proj=9.2.1
- conda-forge::gdal=3.7.2
- conda-forge::proj=9.3.0
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ install_requires =
geojson==3.0.1
pytest==7.3.1
setuptools==68.0.0
toml==0.10.2


[options.packages.find]
Expand Down
File renamed without changes.
69 changes: 69 additions & 0 deletions src/aws/osml/integ/http_centerpoint/test_http_centerpoint_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# Copyright 2023 Amazon.com, Inc. or its affiliates.

import logging

from aws.osml.utils import (
OSMLConfig,
count_features,
count_region_request_items,
ddb_client,
elb_client,
kinesis_client,
run_model_on_image,
s3_client,
sqs_client,
validate_expected_region_request_items,
validate_features_match,
)

logger = logging.getLogger()
logger.setLevel(logging.INFO)


def test_model_runner_centerpoint_http_model() -> None:
"""
Run the test using the CenterPointModel and validate the number of features
and region requests using the HTTP endpoint
:return: None
"""

if OSMLConfig.HTTP_CENTERPOINT_MODEL_URL:
http_endpoint_url = OSMLConfig.HTTP_CENTERPOINT_MODEL_URL
else:
http_endpoint_dns = get_load_balancer_dns_url(OSMLConfig.HTTP_CENTERPOINT_MODEL_ELB_NAME)
http_endpoint_url = f"http://{http_endpoint_dns}{OSMLConfig.HTTP_CENTERPOINT_MODEL_INFERENCE_PATH}"

# launch our image request and validate it completes
image_id, job_id, image_processing_request, shard_iter = run_model_on_image(
sqs_client(), http_endpoint_url, "HTTP_ENDPOINT", kinesis_client()
)

# count the created features in the table for this image
count_features(image_id=image_id, ddb_client=ddb_client())

# verify the results we created in the appropriate syncs
validate_features_match(
image_processing_request=image_processing_request,
job_id=job_id,
shard_iter=shard_iter,
s3_client=s3_client(),
kinesis_client=kinesis_client(),
)

# validate the number of region requests that were created in the process and check if they are succeeded
region_request_count = count_region_request_items(image_id=image_id, ddb_client=ddb_client())
validate_expected_region_request_items(region_request_count)


def get_load_balancer_dns_url(load_balancer_name: str) -> str:
"""
Get the DNS URL for the given load balancer
:param load_balancer_name: The name of the load balancer
:return: The DNS URL for the load balancer
"""
logger.debug("Retrieving DNS name for '{}'...".format(load_balancer_name))
res = elb_client().describe_load_balancers(Names=[load_balancer_name])
dns_name = res.get("LoadBalancers", [])[0].get("DNSName")
logger.debug("Found DNS name: {}".format(dns_name))
return dns_name
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def test_model_runner_aircraft_model() -> None:

# Launch our image request and validate it completes
image_id, job_id, image_processing_request, kinesis_shard = run_model_on_image(
sqs_client(), OSMLConfig.SM_AIRCRAFT_MODEL, kinesis_client()
sqs_client(), OSMLConfig.SM_AIRCRAFT_MODEL, "SM_ENDPOINT", kinesis_client()
)

# Count the features that were create in the table for this image
Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def test_model_runner_center_point_model() -> None:

# launch our image request and validate it completes
image_id, job_id, image_processing_request, shard_iter = run_model_on_image(
sqs_client(), OSMLConfig.SM_CENTERPOINT_MODEL, kinesis_client()
sqs_client(), OSMLConfig.SM_CENTERPOINT_MODEL, "SM_ENDPOINT", kinesis_client()
)

# count the features that were create in the table for this image
Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def test_model_runner_flood_model() -> None:

# Launch our image request and validate it completes
image_id, job_id, image_processing_request, kinesis_shard = run_model_on_image(
sqs_client(), OSMLConfig.SM_FLOOD_MODEL, kinesis_client()
sqs_client(), OSMLConfig.SM_FLOOD_MODEL, "SM_ENDPOINT", kinesis_client()
)

# Count the features that were create in the table for this image
Expand Down
2 changes: 1 addition & 1 deletion src/aws/osml/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# __init__.py file.
# flake8: noqa

from .clients import cw_client, ddb_client, kinesis_client, s3_client, sm_client, sqs_client
from .clients import cw_client, ddb_client, elb_client, kinesis_client, s3_client, sm_client, sqs_client
from .integ_utils import (
build_image_processing_request,
count_features,
Expand Down
10 changes: 10 additions & 0 deletions src/aws/osml/utils/clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,3 +75,13 @@ def cw_client() -> boto3.client:
"""
session = get_session_credentials()
return session.client("cloudwatch", region_name=OSMLConfig.REGION)


def elb_client() -> boto3.client:
"""
Get resources from the default ElasticLoadBalancing session
:return: boto3.client = ELB client
"""
session = get_session_credentials()
return session.client("elbv2", region_name=OSMLConfig.REGION)
9 changes: 7 additions & 2 deletions src/aws/osml/utils/osml_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@ class OSMLConfig:
SM_FLOOD_MODEL: str = os.getenv("SM_FLOOD_MODEL", "flood")
SM_AIRCRAFT_MODEL: str = os.getenv("SM_AIRCRAFT_MODEL", "aircraft")

# HTTP model config
HTTP_CENTERPOINT_MODEL_URL: str = os.getenv("HTTP_CENTER_POINT_MODEL_URL", None)
HTTP_CENTERPOINT_MODEL_ELB_NAME: str = os.getenv("HTTP_CENTER_POINT_MODEL_ELB_NAME", "test-http-model-endpoint")
HTTP_CENTERPOINT_MODEL_INFERENCE_PATH = os.getenv("HTTP_CENTERPOINT_MODEL_INFERENCE_PATH", "/invocations")

# bucket name prefixes
S3_RESULTS_BUCKET: str = os.getenv("S3_RESULTS_BUCKET")
S3_RESULTS_BUCKET_PREFIX: str = os.getenv("S3_RESULTS_BUCKET_PREFIX", "test-results")
Expand Down Expand Up @@ -62,5 +67,5 @@ class OSMLLoadTestConfig:
S3_LOAD_TEST_RESULT_BUCKET: str = os.getenv("S3_LOAD_TEST_RESULT_BUCKET")

# processing workflow
PERIODIC_SLEEP_SECS: str = os.getenv("PERIODIC_SLEEP_SECS", "60") # in seconds
PROCESSING_WINDOW_MIN: str = os.getenv("PROCESSING_WINDOW_MIN", "1") # in hours
PERIODIC_SLEEP_SECS: str = os.getenv("PERIODIC_SLEEP_SECS", "60")
PROCESSING_WINDOW_MIN: str = os.getenv("PROCESSING_WINDOW_MIN", "1")
Loading

0 comments on commit 7da9a82

Please sign in to comment.