diff --git a/src/aws/osml/model_runner/api/image_request.py b/src/aws/osml/model_runner/api/image_request.py index ec3a1f2b..c6423c25 100755 --- a/src/aws/osml/model_runner/api/image_request.py +++ b/src/aws/osml/model_runner/api/image_request.py @@ -1,11 +1,13 @@ # Copyright 2023-2024 Amazon.com, Inc. or its affiliates. import logging +from dataclasses import dataclass, field from json import dumps, loads from typing import Any, Dict, List, Optional import shapely.geometry import shapely.wkt +from dacite import from_dict from shapely.geometry.base import BaseGeometry from aws.osml.model_runner.common import ( @@ -21,148 +23,207 @@ from .inference import ModelInvokeMode from .request_utils import shared_properties_are_valid -from .sink import SinkType +from .sink import VALID_SYNC_TYPES, SinkType logger = logging.getLogger(__name__) +DEFAULT_TILE_SIZE = (1024, 1024) +DEFAULT_TILE_OVERLAP = (50, 50) -class ImageRequest(object): + +@dataclass +class ImageRequest: """ Request for the Model Runner to process an image. - This class contains the attributes that make up an image processing request along with + This class contains the attributes that make up an image processing request, along with constructors and factory methods used to create these requests from common constructs. - """ - def __init__(self, *initial_data: Dict[str, Any], **kwargs: Any): - """ - This constructor allows users to create these objects using a combination of dictionaries - and keyword arguments. + Attributes: + job_id: The unique identifier for the image processing job. + image_id: A combined identifier for the image, usually composed of the job ID and image URL. + image_url: The URL location of the image to be processed. + image_read_role: The IAM role used to read the image from the provided URL. + outputs: A list of output configurations where results should be stored. + model_name: The name of the model to use for image processing. + model_invoke_mode: The mode in which the model is invoked, such as synchronous or asynchronous. + tile_size: Dimensions of the tiles into which the image is split for processing. + tile_overlap: Overlap between tiles, defined in dimensions. + tile_format: The format of the tiles (e.g., NITF, GeoTIFF). + tile_compression: Compression type to use for the tiles (e.g., None, JPEG). + model_invocation_role: IAM role assumed for invoking the model. + feature_properties: Additional properties to include in the feature processing. + roi: Region of interest within the image, defined as a geometric shape. + post_processing: List of post-processing steps to apply to the features detected. + """ - :param initial_data: Dict[str, Any] = dictionaries that contain attributes/values that map to this class's - attributes - :param kwargs: Any = keyword arguments provided on the constructor to set specific attributes - """ - default_post_processing = [ + job_id: str = "" + image_id: str = "" + image_url: str = "" + image_read_role: str = "" + outputs: List[Dict[str, Any]] = field(default_factory=list) + model_name: str = "" + model_invoke_mode: ModelInvokeMode = ModelInvokeMode.NONE + tile_size: ImageDimensions = DEFAULT_TILE_SIZE + tile_overlap: ImageDimensions = DEFAULT_TILE_OVERLAP + tile_format: str = ImageFormats.NITF.value + tile_compression: str = ImageCompression.NONE.value + model_invocation_role: str = "" + feature_properties: List[Dict[str, Any]] = field(default_factory=list) + roi: Optional[BaseGeometry] = None + post_processing: List[MRPostProcessing] = field( + default_factory=lambda: [ MRPostProcessing(step=MRPostprocessingStep.FEATURE_DISTILLATION, algorithm=FeatureDistillationNMS()) ] + ) - self.job_id: str = "" - self.image_id: str = "" - self.image_url: str = "" - self.image_read_role: str = "" - self.outputs: List[dict] = [] - self.model_name: str = "" - self.model_invoke_mode: ModelInvokeMode = ModelInvokeMode.NONE - self.tile_size: ImageDimensions = (1024, 1024) - self.tile_overlap: ImageDimensions = (50, 50) - self.tile_format: ImageFormats = ImageFormats.NITF - self.tile_compression: ImageCompression = ImageCompression.NONE - self.model_invocation_role: str = "" - self.feature_properties: List[dict] = [] - self.roi: Optional[BaseGeometry] = None - self.post_processing: List[MRPostProcessing] = default_post_processing - - for dictionary in initial_data: - for key in dictionary: - setattr(self, key, dictionary[key]) - for key in kwargs: - setattr(self, key, kwargs[key]) + @staticmethod + def from_external_message(image_request: Dict[str, Any]) -> "ImageRequest": + """ + Constructs an ImageRequest from a dictionary that represents an external message. + + :param image_request: Dictionary of values from the decoded JSON request. + :return: ImageRequest instance. + """ + properties: Dict[str, Any] = { + "job_id": image_request.get("jobId", ""), + "image_url": image_request.get("imageUrls", [""])[0], + "image_id": f"{image_request.get('jobId', '')}:{image_request.get('imageUrls', [''])[0]}", + "image_read_role": image_request.get("imageReadRole", ""), + "model_name": image_request["imageProcessor"]["name"], + "model_invoke_mode": ImageRequest._parse_model_invoke_mode(image_request["imageProcessor"].get("type")), + "model_invocation_role": image_request["imageProcessor"].get("assumedRole", ""), + "tile_size": ImageRequest._parse_tile_dimension(image_request.get("imageProcessorTileSize")), + "tile_overlap": ImageRequest._parse_tile_dimension(image_request.get("imageProcessorTileOverlap")), + "tile_format": ImageRequest._parse_tile_format(image_request.get("imageProcessorTileFormat")), + "tile_compression": ImageRequest._parse_tile_compression(image_request.get("imageProcessorTileCompression")), + "roi": ImageRequest._parse_roi(image_request.get("regionOfInterest")), + "outputs": ImageRequest._parse_outputs(image_request), + "feature_properties": image_request.get("featureProperties", []), + "post_processing": ImageRequest._parse_post_processing(image_request.get("postProcessing")), + } + return from_dict(ImageRequest, properties) @staticmethod - def from_external_message(image_request: Dict[str, Any]): + def _parse_tile_dimension(value: Optional[str]) -> ImageDimensions: """ - This method is used to construct an ImageRequest given a dictionary reconstructed from the - JSON representation of a request that appears on the Image Job Queue. The structure of - that message is generally governed by AWS API best practices and may evolve over time as - the public APIs for this service mature. + Converts a string value to a tuple of integers representing tile dimensions. - :param image_request: Dict[str, Any] = dictionary of values from the decoded JSON request + :param value: String value representing tile dimension. + :return: Tuple of integers as tile dimensions. + """ + return (int(value), int(value)) if value else DEFAULT_TILE_SIZE - :return: the ImageRequest + @staticmethod + def _parse_roi(roi: Optional[str]) -> Optional[BaseGeometry]: """ - properties: Dict[str, Any] = {} - if "imageProcessorTileSize" in image_request: - tile_dimension = int(image_request["imageProcessorTileSize"]) - properties["tile_size"] = (tile_dimension, tile_dimension) + Parses the region of interest from a WKT string. - if "imageProcessorTileOverlap" in image_request: - overlap_dimension = int(image_request["imageProcessorTileOverlap"]) - properties["tile_overlap"] = (overlap_dimension, overlap_dimension) + :param roi: WKT string representing the region of interest. + :return: Parsed BaseGeometry object or None. + """ + return shapely.wkt.loads(roi) if roi else None - if "imageProcessorTileFormat" in image_request: - properties["tile_format"] = image_request["imageProcessorTileFormat"] + @staticmethod + def _parse_tile_format(tile_format: Optional[str]) -> Optional[ImageFormats]: + """ + Parses the region desired tile format to use for processing. + + :param tile_format: String representing the tile format to use. + :return: Parsed ImageFormats object or ImageFormats.NITF. + """ + return ImageFormats[tile_format].value if tile_format else ImageFormats.NITF.value - if "imageProcessorTileCompression" in image_request: - properties["tile_compression"] = image_request["imageProcessorTileCompression"] + @staticmethod + def _parse_tile_compression(tile_compression: Optional[str]) -> Optional[ImageCompression]: + """ + Parses the region desired tile compression format to use for processing. - properties["job_id"] = image_request["jobId"] + :param tile_compression: String representing the tile compression format to use. + :return: Parsed ImageFormats object or ImageCompression.NONE. + """ + return ImageCompression[tile_compression].value if tile_compression else ImageCompression.NONE.value - properties["image_url"] = image_request["imageUrls"][0] - properties["image_id"] = image_request["jobId"] + ":" + properties["image_url"] - if "imageReadRole" in image_request: - properties["image_read_role"] = image_request["imageReadRole"] + @staticmethod + def _parse_model_invoke_mode(model_invoke_mode: Optional[str]) -> Optional[ModelInvokeMode]: + """ + Parses the region desired tile compression format to use for processing. - properties["model_name"] = image_request["imageProcessor"]["name"] - properties["model_invoke_mode"] = image_request["imageProcessor"]["type"] - if "assumedRole" in image_request["imageProcessor"]: - properties["model_invocation_role"] = image_request["imageProcessor"]["assumedRole"] + :param model_invoke_mode: String representing the tile compression format to use. + :return: Parsed ModelInvokeMode object or ModelInvokeMode.SM_ENDPOINT. + """ + return ModelInvokeMode[model_invoke_mode] if model_invoke_mode else ModelInvokeMode.SM_ENDPOINT - if "regionOfInterest" in image_request: - properties["roi"] = shapely.wkt.loads(image_request["regionOfInterest"]) + @staticmethod + def _parse_outputs(image_request: Dict[str, Any]) -> List[Dict[str, Any]]: + """ + Parses the output configuration from the image request, including support for legacy inputs. - # Support explicit outputs + :param image_request: Dictionary of image request attributes. + :return: List of output configurations. + """ if image_request.get("outputs"): - properties["outputs"] = image_request["outputs"] - # Support legacy image request - elif image_request.get("outputBucket") and image_request.get("outputPrefix"): - properties["outputs"] = [ + return image_request["outputs"] + + # Support legacy image request fields: outputBucket and outputPrefix + if image_request.get("outputBucket") and image_request.get("outputPrefix"): + return [ { "type": SinkType.S3.value, "bucket": image_request["outputBucket"], "prefix": image_request["outputPrefix"], } ] - if image_request.get("featureProperties"): - properties["feature_properties"] = image_request["featureProperties"] - if image_request.get("postProcessing"): - image_request["postProcessing"] = loads( - dumps(image_request["postProcessing"]) - .replace("algorithmType", "algorithm_type") - .replace("iouThreshold", "iou_threshold") - .replace("skipBoxThreshold", "skip_box_threshold") - ) - properties["post_processing"] = deserialize_post_processing_list(image_request.get("postProcessing")) - - return ImageRequest(properties) + # No outputs were defined in the request + logger.warning("No output syncs were present in this request.") + return [] + + @staticmethod + def _parse_post_processing(post_processing: Optional[Dict[str, Any]]) -> List[MRPostProcessing]: + """ + Deserializes and cleans up post-processing data. + + :param post_processing: Dictionary of post-processing configurations. + :return: List of MRPostProcessing instances. + """ + if not post_processing: + return [MRPostProcessing(step=MRPostprocessingStep.FEATURE_DISTILLATION, algorithm=FeatureDistillationNMS())] + cleaned_post_processing = loads( + dumps(post_processing) + .replace("algorithmType", "algorithm_type") + .replace("iouThreshold", "iou_threshold") + .replace("skipBoxThreshold", "skip_box_threshold") + ) + return deserialize_post_processing_list(cleaned_post_processing) def is_valid(self) -> bool: """ - Check to see if this request contains required attributes and meaningful values + Validates whether the ImageRequest instance has all required attributes. - :return: bool = True if the request contains all the mandatory attributes with acceptable values, - False otherwise + :return: True if valid, False otherwise. """ if not shared_properties_are_valid(self): logger.error("Invalid shared properties in ImageRequest") return False - - if not self.job_id or not self.outputs: - logger.error("Missing job id or outputs properties in ImageRequest") + if not self.job_id: + logger.error("Missing job id in ImageRequest") return False - - num_feature_detection_options = len(self.get_feature_distillation_option()) - if num_feature_detection_options > 1: - logger.error(f"{num_feature_detection_options} feature distillation options in ImageRequest") + if len(self.get_feature_distillation_option()) > 1: + logger.error("Multiple feature distillation options in ImageRequest") return False - + if len(self.outputs) > 0: + for output in self.outputs: + sink_type = output.get("type") + if sink_type not in VALID_SYNC_TYPES: + logger.error(f"Invalid sink type '{sink_type}' in ImageRequest") + return False return True def get_shared_values(self) -> Dict[str, Any]: """ - Returns a formatted dict that contains the properties of an image + Retrieves a dictionary of shared values related to the image. - :return: Dict[str, Any] = the properties of an image + :return: Dictionary of shared image properties. """ return { "image_id": self.image_id, @@ -180,8 +241,9 @@ def get_shared_values(self) -> Dict[str, Any]: def get_feature_distillation_option(self) -> List[FeatureDistillationAlgorithm]: """ - Parses the post-processing property and extracts the relevant feature distillation selection, if present - :return: + Extracts the feature distillation options from the post-processing configuration. + + :return: List of FeatureDistillationAlgorithm instances. """ return [ op.algorithm diff --git a/src/aws/osml/model_runner/api/request_utils.py b/src/aws/osml/model_runner/api/request_utils.py index b95393a7..934f0b1c 100755 --- a/src/aws/osml/model_runner/api/request_utils.py +++ b/src/aws/osml/model_runner/api/request_utils.py @@ -1,4 +1,5 @@ # Copyright 2023-2024 Amazon.com, Inc. or its affiliates. +import logging import boto3 @@ -8,34 +9,44 @@ from .exceptions import InvalidS3ObjectException from .inference import VALID_MODEL_HOSTING_OPTIONS +logger = logging.getLogger(__name__) + def shared_properties_are_valid(request) -> bool: """ - There are some attributes that are shared between ImageRequests and RegionRequests. This - function exists to validate if ImageRequests/RegionRequests have all the metadata info - in order to process it + Validates that the request contains all mandatory attributes with acceptable values. - :param request: an object of either ImageRequests or RegionRequests + This function checks attributes shared between ImageRequests and RegionRequests, ensuring + they contain all required metadata for processing. - :return: bool = True if the request contains all the mandatory attributes with acceptable values, - False otherwise + :param request: An object of either ImageRequests or RegionRequests. + :return: True if the request is valid; False otherwise. Logs warnings for each failed validation. """ if not request.image_id or not request.image_url: + logger.error("Validation failed: `image_id` or `image_url` is missing.") return False if not request.model_name: + logger.error("Validation failed: `model_name` is missing.") return False if not request.model_invoke_mode or request.model_invoke_mode not in VALID_MODEL_HOSTING_OPTIONS: + logger.error( + f"Validation failed: `model_invoke_mode` is either missing or invalid. " + f"Expected one of {VALID_MODEL_HOSTING_OPTIONS}, but got '{request.model_invoke_mode}'." + ) return False if not request.tile_size or len(request.tile_size) != 2: + logger.error("Validation failed: `tile_size` is missing or does not contain two dimensions.") return False if request.tile_size[0] <= 0 or request.tile_size[1] <= 0: + logger.error("Validation failed: `tile_size` dimensions must be positive values.") return False if not request.tile_overlap or len(request.tile_overlap) != 2: + logger.error("Validation failed: `tile_overlap` is missing or does not contain two dimensions.") return False if ( @@ -44,18 +55,29 @@ def shared_properties_are_valid(request) -> bool: or request.tile_overlap[1] < 0 or request.tile_overlap[1] >= request.tile_size[1] ): + logger.error("Validation failed: `tile_overlap` values must be non-negative and less than `tile_size` dimensions.") return False if not request.tile_format or request.tile_format not in VALID_IMAGE_FORMATS: + logger.error( + f"Validation failed: `tile_format` is either missing or invalid. " + f"Expected one of {VALID_IMAGE_FORMATS}, but got '{request.tile_format}'." + ) return False if request.tile_compression and request.tile_compression not in VALID_IMAGE_COMPRESSION: + logger.error( + f"Validation failed: `tile_compression` is invalid. " + f"Expected one of {VALID_IMAGE_COMPRESSION}, but got '{request.tile_compression}'." + ) return False if request.image_read_role and not request.image_read_role.startswith("arn:"): + logger.error("Validation failed: `image_read_role` does not start with 'arn:'.") return False if request.model_invocation_role and not request.model_invocation_role.startswith("arn:"): + logger.error("Validation failed: `model_invocation_role` does not start with 'arn:'.") return False return True @@ -80,7 +102,7 @@ def get_image_path(image_url: str, assumed_role: str) -> str: return image_url -def validate_image_path(image_url: str, assumed_role: str) -> bool: +def validate_image_path(image_url: str, assumed_role: str = None) -> bool: """ Validate if an image exists in S3 bucket diff --git a/src/aws/osml/model_runner/api/sink.py b/src/aws/osml/model_runner/api/sink.py index 3a72545c..e15ffaa2 100755 --- a/src/aws/osml/model_runner/api/sink.py +++ b/src/aws/osml/model_runner/api/sink.py @@ -22,3 +22,6 @@ class SinkType(str, AutoStringEnum): S3 = auto() # Mode not set to auto due to contract having been set as "Kinesis" already KINESIS = "Kinesis" + + +VALID_SYNC_TYPES = {sink_type.value for sink_type in SinkType} diff --git a/test/aws/osml/model_runner/api/test_image_request.py b/test/aws/osml/model_runner/api/test_image_request.py index ba82e86f..c6891de4 100755 --- a/test/aws/osml/model_runner/api/test_image_request.py +++ b/test/aws/osml/model_runner/api/test_image_request.py @@ -1,56 +1,183 @@ # Copyright 2023-2024 Amazon.com, Inc. or its affiliates. import unittest +from unittest import TestCase -legacy_execution_role = "arn:aws:iam::012345678910:role/OversightMLBetaInvokeRole" +import boto3 +import pytest +from botocore.stub import Stubber +from aws.osml.model_runner.api import InvalidS3ObjectException +from aws.osml.model_runner.api.image_request import ImageRequest, ModelInvokeMode +from aws.osml.model_runner.api.request_utils import validate_image_path +from aws.osml.model_runner.app_config import BotoConfig +from aws.osml.model_runner.sink import Sink, SinkFactory -class TestImageRequest(unittest.TestCase): - def setUp(self): + +class TestImageRequest(TestCase): + def test_invalid_data(self): """ - Set up virtual DDB resources/tables for each test to use + Test ImageRequest with missing or invalid image_id. """ - self.sample_request_data = self.build_request_data() + ir = self.build_request_data() + ir.image_id = None + assert not ir.is_valid() - def tearDown(self): - self.sample_request_data = None + def test_invalid_job_id(self): + """ + Test ImageRequest with missing job_id. + """ + ir = self.build_request_data() + ir.job_id = None + assert not ir.is_valid() - def test_invalid_data(self): - from aws.osml.model_runner.api.image_request import ImageRequest + def test_valid_data(self): + """ + Test ImageRequest with valid data to ensure it passes validation. + """ + ir = self.build_request_data() + assert ir.is_valid() - ir = ImageRequest( - self.sample_request_data, - image_id="", + def test_invalid_tile_size(self): + """ + Test ImageRequest with invalid tile size to check error handling. + """ + ir = self.build_request_data() + ir.tile_size = None + assert not ir.is_valid() + + def test_from_external_message(self): + """ + Test ImageRequest created from external message deserialization. + """ + ir = ImageRequest.from_external_message( + { + "jobName": "test-job-name", + "jobId": "test-job-id", + "imageUrls": ["test-image-url"], + "outputs": [ + {"type": "S3", "bucket": "test-bucket", "prefix": "test-bucket-prefix"}, + {"type": "Kinesis", "stream": "test-stream", "batchSize": 1000}, + ], + "imageProcessor": {"name": "test-model", "type": "SM_ENDPOINT"}, + "imageProcessorTileSize": 1024, + "imageProcessorTileOverlap": 50, + } ) + assert ir.is_valid() + assert ir.job_id == "test-job-id" + assert ir.model_name == "test-model" + assert ir.tile_size == (1024, 1024) + assert ir.tile_overlap == (50, 50) - assert not ir.is_valid() + def test_default_initialization(self): + """ + Test ImageRequest default initialization to ensure default values are set correctly. + """ + ir = ImageRequest() + assert ir.tile_size == (1024, 1024) + assert ir.tile_overlap == (50, 50) + assert ir.tile_format == "NITF" + assert ir.model_invoke_mode == ModelInvokeMode.NONE - def test_invalid_job_id(self): - from aws.osml.model_runner.api.image_request import ImageRequest + def test_feature_distillation_parsing(self): + """ + Test that ImageRequest can correctly parse and handle feature distillation options. + """ + ir = self.build_request_data() + distillation_option = ir.get_feature_distillation_option() + assert isinstance(distillation_option, list) + assert len(distillation_option) == 1 - ir = ImageRequest( - self.sample_request_data, - job_id=None, + def test_image_request_from_minimal_message_legacy_output(self): + """ + Test ImageRequest creation from a minimal message using legacy output fields. + """ + ir = ImageRequest.from_external_message( + { + "jobName": "test-job-name", + "jobId": "test-job-id", + "imageUrls": ["test-image-url"], + "imageProcessor": {"name": "test-model", "type": "SM_ENDPOINT"}, + "imageProcessorTileSize": 1024, + "imageProcessorTileOverlap": 50, + "outputBucket": "test-bucket", + "outputPrefix": "images/outputs", + } ) - assert not ir.is_valid() + assert ir.is_valid() + assert len(ir.outputs) == 1 + + # Check S3 Sink creation from outputs + sinks = SinkFactory.outputs_to_sinks(ir.outputs) + s3_sink: Sink = sinks[0] + assert s3_sink.name() == "S3" + assert getattr(s3_sink, "bucket") == "test-bucket" + assert getattr(s3_sink, "prefix") == "images/outputs" + + def test_image_request_invalid_sink(self): + """ + Test ImageRequest creation with an invalid sink type. + """ + request = ImageRequest.from_external_message( + { + "jobName": "test-job-name", + "jobId": "test-job-id", + "imageUrls": ["test-image-url"], + "outputs": [{"type": "SQS", "queue": "FakeQueue"}], + "imageProcessor": {"name": "test-model", "type": "SM_ENDPOINT"}, + "imageProcessorTileSize": 1024, + "imageProcessorTileOverlap": 50, + } + ) + + # Should fail with an invalid sync type provided + assert not request.is_valid() + + def test_image_request_invalid_image_path(self): + """ + Test validation of an invalid S3 image path. + """ + s3_client = boto3.client("s3", config=BotoConfig.default) + s3_client_stub = Stubber(s3_client) + s3_client_stub.activate() + + image_path = "s3://test-results-bucket/test/data/small.ntf" + + s3_client_stub.add_client_error( + "head_object", + service_error_code="404", + service_message="Not Found", + expected_params={"Bucket": image_path}, + ) + + with pytest.raises(InvalidS3ObjectException): + validate_image_path(image_path, None) + + s3_client_stub.deactivate() @staticmethod def build_request_data(): - return { - "job_id": "test-job", - "image_id": "test-image-id", - "image_url": "test-image-url", - "image_read_role": "arn:aws:iam::012345678910:role/OversightMLS3ReadOnly", - "output_bucket": "unit-test", - "output_prefix": "region-request", - "tile_size": (10, 10), - "tile_overlap": (1, 1), - "tile_format": "NITF", - "model_name": "test-model-name", - "model_invoke_mode": "SM_ENDPOINT", - "model_invocation_role": "arn:aws:iam::012345678910:role/OversightMLModelInvoker", - } + """ + Helper method to build sample request data for tests. + """ + return ImageRequest( + job_id="test-job-id", + image_id="test-image-id", + image_url="test-image-url", + image_read_role="arn:aws:iam::012345678910:role/TestRole", + outputs=[ + {"type": "S3", "bucket": "test-bucket", "prefix": "test-bucket-prefix"}, + {"type": "Kinesis", "stream": "test-stream", "batchSize": 1000}, + ], + tile_size=(1024, 1024), + tile_overlap=(50, 50), + tile_format="NITF", + model_name="test-model-name", + model_invoke_mode="SM_ENDPOINT", + model_invocation_role="arn:aws:iam::012345678910:role/TestRole", + ) if __name__ == "__main__": diff --git a/test/test_api.py b/test/test_api.py deleted file mode 100755 index 3bda0b25..00000000 --- a/test/test_api.py +++ /dev/null @@ -1,279 +0,0 @@ -# Copyright 2023-2024 Amazon.com, Inc. or its affiliates. - -from typing import Any, Dict -from unittest import TestCase, main -from unittest.mock import patch - -import boto3 -import pytest -import shapely.geometry -from botocore.stub import Stubber - -TEST_S3_FULL_BUCKET_PATH = "s3://test-results-bucket/test/data/small.ntf" - -base_request = { - "jobName": "test-job", - "jobId": "5f4e8a55-95cf-4d96-95cd-9b037f767eff", - "imageUrls": ["s3://fake-bucket/images/test-image-id"], - "imageProcessor": {"name": "test-model-name", "type": "SM_ENDPOINT"}, -} - - -class TestModelRunnerAPI(TestCase): - def test_region_request_constructor(self): - from aws.osml.model_runner.api.image_request import ModelInvokeMode - from aws.osml.model_runner.api.region_request import RegionRequest - from aws.osml.model_runner.common.typing import ImageCompression, ImageFormats - - region_request_template = { - "model_name": "test-model-name", - "model_invoke_mode": "SM_ENDPOINT", - "model_invocation_role": "arn:aws:iam::012345678910:role/OversightMLBetaModelInvokerRole", - } - - rr = RegionRequest( - region_request_template, - image_id="test-image-id", - image_url="s3://fake-bucket/images/test-image-id", - image_read_role="arn:aws:iam::012345678910:role/OversightMLBetaS3ReadOnly", - region_bounds=[0, 1, 2, 3], - ) - - # Check to ensure we've created a valid request - assert rr.is_valid() - - # Checks to ensure the dictionary provided values are set - assert rr.model_name == "test-model-name" - assert rr.model_invoke_mode == ModelInvokeMode.SM_ENDPOINT - assert rr.model_invocation_role == "arn:aws:iam::012345678910:role/OversightMLBetaModelInvokerRole" - - # Checks to ensure the keyword arguments are set - assert rr.image_id == "test-image-id" - assert rr.image_url == "s3://fake-bucket/images/test-image-id" - assert rr.image_read_role == "arn:aws:iam::012345678910:role/OversightMLBetaS3ReadOnly" - assert rr.region_bounds == [0, 1, 2, 3] - - # Checks to ensure the defaults are set - assert rr.tile_size == (1024, 1024) - assert rr.tile_overlap == (50, 50) - assert rr.tile_format == ImageFormats.NITF - assert rr.tile_compression == ImageCompression.NONE - - def test_image_request_constructor(self): - from aws.osml.model_runner.api.image_request import ImageRequest - from aws.osml.model_runner.common.typing import ImageCompression - from aws.osml.model_runner.sink.sink import Sink - from aws.osml.model_runner.sink.sink_factory import SinkFactory - - image_request_template = { - "model_name": "test-model-name", - "model_invoke_mode": "SM_ENDPOINT", - "image_read_role": "arn:aws:iam::012345678910:role/OversightMLBetaS3ReadOnly", - } - fake_s3_sink = { - "type": "S3", - "bucket": "fake-bucket", - "prefix": "images/outputs", - "mode": "Aggregate", - } - ir = ImageRequest( - image_request_template, - job_name="test-job", - job_id="5f4e8a55-95cf-4d96-95cd-9b037f767eff", - image_id="5f4e8a55-95cf-4d96-95cd-9b037f767eff:s3://fake-bucket/images/test-image-id", - image_url="s3://fake-bucket/images/test-image-id", - outputs=[fake_s3_sink], - ) - - assert ir.is_valid() - assert ir.image_url == "s3://fake-bucket/images/test-image-id" - assert ir.image_id == "5f4e8a55-95cf-4d96-95cd-9b037f767eff:s3://fake-bucket/images/test-image-id" - assert ir.image_read_role == "arn:aws:iam::012345678910:role/OversightMLBetaS3ReadOnly" - assert ir.tile_size == (1024, 1024) - assert ir.tile_overlap == (50, 50) - assert ir.model_name == "test-model-name" - assert ir.model_invoke_mode == "SM_ENDPOINT" - assert ir.model_invocation_role == "" - assert ir.tile_format == "NITF" - assert ir.tile_compression == ImageCompression.NONE - assert ir.job_id == "5f4e8a55-95cf-4d96-95cd-9b037f767eff" - assert len(ir.outputs) == 1 - sinks = SinkFactory.outputs_to_sinks(ir.outputs) - s3_sink: Sink = sinks[0] - assert s3_sink.name() == "S3" - assert s3_sink.__getattribute__("bucket") == "fake-bucket" - assert s3_sink.__getattribute__("prefix") == "images/outputs" - assert ir.roi is None - - @patch("aws.osml.model_runner.common.credentials_utils.sts_client") - def test_image_request_from_message(self, mock_sts): - from aws.osml.model_runner.api.image_request import ImageRequest - from aws.osml.model_runner.common.typing import ImageCompression - from aws.osml.model_runner.sink.sink import Sink - from aws.osml.model_runner.sink.sink_factory import SinkFactory - - test_access_key_id = "123456789" - test_secret_access_key = "987654321" - test_secret_token = "SecretToken123" - mock_sts.assume_role.return_value = { - "Credentials": { - "AccessKeyId": test_access_key_id, - "SecretAccessKey": test_secret_access_key, - "SessionToken": test_secret_token, - } - } - updates: Dict[str, Any] = { - "jobStatus": "SUBMITTED", - "processingSubmitted": "2021-09-14T00:18:32.130000+00:00", - "imageReadRole": "arn:aws:iam::012345678910:role/OversightMLS3ReadOnly", - "outputs": [ - { - "type": "S3", - "bucket": "fake-bucket", - "prefix": "images/outputs", - "assumedRole": "arn:aws:iam::012345678910:role/OversightMLBetaS3ReadOnlyRole", - } - ], - "imageProcessorTileSize": 2048, - "imageProcessorTileOverlap": 100, - "imageProcessorTileFormat": "PNG", - "imageProcessorTileCompression": ImageCompression.NONE, - "regionOfInterest": "POLYGON((0.5 0.5,5 0,5 5,0 5,0.5 0.5), (1.5 1,4 3,4 1,1.5 1))", - } - message_body = base_request.copy() - message_body.update(updates) - - ir = ImageRequest.from_external_message(message_body) - - assert ir.is_valid() - assert ir.image_url == "s3://fake-bucket/images/test-image-id" - assert ir.image_id == "5f4e8a55-95cf-4d96-95cd-9b037f767eff:s3://fake-bucket/images/test-image-id" - assert ir.image_read_role == "arn:aws:iam::012345678910:role/OversightMLS3ReadOnly" - assert ir.tile_size == (2048, 2048) - assert ir.tile_overlap == (100, 100) - assert ir.model_name == "test-model-name" - assert ir.model_invoke_mode == "SM_ENDPOINT" - assert ir.tile_format == "PNG" - assert ir.tile_compression == ImageCompression.NONE - assert ir.job_id == "5f4e8a55-95cf-4d96-95cd-9b037f767eff" - assert len(ir.outputs) == 1 - sinks = SinkFactory.outputs_to_sinks(ir.outputs) - s3_sink: Sink = sinks[0] - assert s3_sink.name() == "S3" - assert s3_sink.__getattribute__("bucket") == "fake-bucket" - assert s3_sink.__getattribute__("prefix") == "images/outputs" - assert isinstance(ir.roi, shapely.geometry.Polygon) - - def test_image_request_from_minimal_message_legacy_output(self): - from aws.osml.model_runner.api.image_request import ImageRequest - from aws.osml.model_runner.common.typing import ImageCompression - from aws.osml.model_runner.sink.sink import Sink - from aws.osml.model_runner.sink.sink_factory import SinkFactory - - updates: Dict[str, Any] = {"outputBucket": "fake-bucket", "outputPrefix": "images/outputs"} - message_body = base_request.copy() - message_body.update(updates) - - ir = ImageRequest.from_external_message(message_body) - - assert ir.is_valid() - assert ir.image_url == "s3://fake-bucket/images/test-image-id" - assert ir.image_id == "5f4e8a55-95cf-4d96-95cd-9b037f767eff:s3://fake-bucket/images/test-image-id" - assert ir.image_read_role == "" - assert ir.tile_size == (1024, 1024) - assert ir.tile_overlap == (50, 50) - assert ir.model_name == "test-model-name" - assert ir.model_invoke_mode == "SM_ENDPOINT" - assert ir.model_invocation_role == "" - assert ir.tile_format == "NITF" - assert ir.tile_compression == ImageCompression.NONE - assert ir.job_id == "5f4e8a55-95cf-4d96-95cd-9b037f767eff" - assert len(ir.outputs) == 1 - sinks = SinkFactory.outputs_to_sinks(ir.outputs) - s3_sink: Sink = sinks[0] - assert s3_sink.name() == "S3" - assert s3_sink.__getattribute__("bucket") == "fake-bucket" - assert s3_sink.__getattribute__("prefix") == "images/outputs" - assert ir.roi is None - - def test_image_request_multiple_sinks(self): - from aws.osml.model_runner.api.image_request import ImageRequest - from aws.osml.model_runner.common.typing import ImageCompression - from aws.osml.model_runner.sink.sink import Sink - from aws.osml.model_runner.sink.sink_factory import SinkFactory - - updates: Dict[str, Any] = { - "outputs": [ - { - "type": "S3", - "bucket": "fake-bucket", - "prefix": "images/outputs", - "mode": "Aggregate", - }, - {"type": "Kinesis", "stream": "FakeStream", "batchSize": 500}, - ] - } - message_body = base_request.copy() - message_body.update(updates) - - ir = ImageRequest.from_external_message(message_body) - - assert ir.is_valid() - assert ir.image_url == "s3://fake-bucket/images/test-image-id" - assert ir.image_id == "5f4e8a55-95cf-4d96-95cd-9b037f767eff:s3://fake-bucket/images/test-image-id" - assert ir.image_read_role == "" - assert ir.tile_size == (1024, 1024) - assert ir.tile_overlap == (50, 50) - assert ir.model_name == "test-model-name" - assert ir.model_invoke_mode == "SM_ENDPOINT" - assert ir.model_invocation_role == "" - assert ir.tile_format == "NITF" - assert ir.tile_compression == ImageCompression.NONE - assert ir.job_id == "5f4e8a55-95cf-4d96-95cd-9b037f767eff" - assert len(ir.outputs) == 2 - sinks = SinkFactory.outputs_to_sinks(ir.outputs) - s3_sink: Sink = sinks[0] - assert s3_sink.name() == "S3" - assert s3_sink.__getattribute__("bucket") == "fake-bucket" - assert s3_sink.__getattribute__("prefix") == "images/outputs" - kinesis_sink: Sink = sinks[1] - assert kinesis_sink.name() == "Kinesis" - assert kinesis_sink.__getattribute__("stream") == "FakeStream" - assert kinesis_sink.__getattribute__("batch_size") == 500 - assert ir.roi is None - - def test_image_request_invalid_sink(self): - from aws.osml.model_runner.api.exceptions import InvalidImageRequestException - from aws.osml.model_runner.api.image_request import ImageRequest - from aws.osml.model_runner.sink.sink_factory import SinkFactory - - updates: Dict[str, Any] = {"outputs": [{"type": "SQS", "queue": "FakeQueue"}]} - message_body = base_request.copy() - message_body.update(updates) - - with self.assertRaises(InvalidImageRequestException): - ir = ImageRequest.from_external_message(message_body) - SinkFactory.outputs_to_sinks(ir.outputs) - - def test_image_request_invalid_image_path(self): - from aws.osml.model_runner.api.exceptions import InvalidS3ObjectException - from aws.osml.model_runner.api.request_utils import validate_image_path - from aws.osml.model_runner.app_config import BotoConfig - - s3_client = boto3.client("s3", config=BotoConfig.default) - s3_client_stub = Stubber(s3_client) - s3_client_stub.activate() - - s3_client_stub.add_client_error( - "head_object", - service_error_code="404", - service_message="Not Found", - expected_params={"Bucket": TEST_S3_FULL_BUCKET_PATH}, - ) - - with pytest.raises(InvalidS3ObjectException): - validate_image_path(TEST_S3_FULL_BUCKET_PATH, None) - - -if __name__ == "__main__": - main()