Skip to content

Commit

Permalink
refactor: refatoring updating image_request.py and updating tests
Browse files Browse the repository at this point in the history
  • Loading branch information
drduhe committed Oct 28, 2024
1 parent 7f23832 commit 315181d
Show file tree
Hide file tree
Showing 5 changed files with 346 additions and 414 deletions.
249 changes: 154 additions & 95 deletions src/aws/osml/model_runner/api/image_request.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
# Copyright 2023-2024 Amazon.com, Inc. or its affiliates.

import logging
from dataclasses import dataclass, field
from json import dumps, loads
from typing import Any, Dict, List, Optional

import shapely.geometry
import shapely.wkt
from dacite import from_dict
from shapely.geometry.base import BaseGeometry

from aws.osml.model_runner.common import (
Expand All @@ -21,148 +23,204 @@

from .inference import ModelInvokeMode
from .request_utils import shared_properties_are_valid
from .sink import SinkType
from .sink import VALID_SYNC_TYPES, SinkType

logger = logging.getLogger(__name__)


class ImageRequest(object):
@dataclass
class ImageRequest:
"""
Request for the Model Runner to process an image.
This class contains the attributes that make up an image processing request along with
This class contains the attributes that make up an image processing request, along with
constructors and factory methods used to create these requests from common constructs.
"""
def __init__(self, *initial_data: Dict[str, Any], **kwargs: Any):
"""
This constructor allows users to create these objects using a combination of dictionaries
and keyword arguments.
Attributes:
job_id: The unique identifier for the image processing job.
image_id: A combined identifier for the image, usually composed of the job ID and image URL.
image_url: The URL location of the image to be processed.
image_read_role: The IAM role used to read the image from the provided URL.
outputs: A list of output configurations where results should be stored.
model_name: The name of the model to use for image processing.
model_invoke_mode: The mode in which the model is invoked, such as synchronous or asynchronous.
tile_size: Dimensions of the tiles into which the image is split for processing.
tile_overlap: Overlap between tiles, defined in dimensions.
tile_format: The format of the tiles (e.g., NITF, GeoTIFF).
tile_compression: Compression type to use for the tiles (e.g., None, JPEG).
model_invocation_role: IAM role assumed for invoking the model.
feature_properties: Additional properties to include in the feature processing.
roi: Region of interest within the image, defined as a geometric shape.
post_processing: List of post-processing steps to apply to the features detected.
"""

:param initial_data: Dict[str, Any] = dictionaries that contain attributes/values that map to this class's
attributes
:param kwargs: Any = keyword arguments provided on the constructor to set specific attributes
"""
default_post_processing = [
job_id: str = ""
image_id: str = ""
image_url: str = ""
image_read_role: str = ""
outputs: List[Dict[str, Any]] = field(default_factory=list)
model_name: str = ""
model_invoke_mode: ModelInvokeMode = ModelInvokeMode.NONE
tile_size: ImageDimensions = (1024, 1024)
tile_overlap: ImageDimensions = (50, 50)
tile_format: str = ImageFormats.NITF.value
tile_compression: str = ImageCompression.NONE.value
model_invocation_role: str = ""
feature_properties: List[Dict[str, Any]] = field(default_factory=list)
roi: Optional[BaseGeometry] = None
post_processing: List[MRPostProcessing] = field(
default_factory=lambda: [
MRPostProcessing(step=MRPostprocessingStep.FEATURE_DISTILLATION, algorithm=FeatureDistillationNMS())
]
)

self.job_id: str = ""
self.image_id: str = ""
self.image_url: str = ""
self.image_read_role: str = ""
self.outputs: List[dict] = []
self.model_name: str = ""
self.model_invoke_mode: ModelInvokeMode = ModelInvokeMode.NONE
self.tile_size: ImageDimensions = (1024, 1024)
self.tile_overlap: ImageDimensions = (50, 50)
self.tile_format: ImageFormats = ImageFormats.NITF
self.tile_compression: ImageCompression = ImageCompression.NONE
self.model_invocation_role: str = ""
self.feature_properties: List[dict] = []
self.roi: Optional[BaseGeometry] = None
self.post_processing: List[MRPostProcessing] = default_post_processing

for dictionary in initial_data:
for key in dictionary:
setattr(self, key, dictionary[key])
for key in kwargs:
setattr(self, key, kwargs[key])
@staticmethod
def from_external_message(image_request: Dict[str, Any]) -> "ImageRequest":
"""
Constructs an ImageRequest from a dictionary that represents an external message.
:param image_request: Dictionary of values from the decoded JSON request.
:return: ImageRequest instance.
"""
properties: Dict[str, Any] = {
"job_id": image_request.get("jobId", ""),
"image_url": image_request.get("imageUrls", [""])[0],
"image_id": f"{image_request.get('jobId', '')}:{image_request.get('imageUrls', [''])[0]}",
"image_read_role": image_request.get("imageReadRole", ""),
"model_name": image_request["imageProcessor"]["name"],
"model_invoke_mode": ImageRequest._parse_model_invoke_mode(image_request["imageProcessor"].get("type")),
"model_invocation_role": image_request["imageProcessor"].get("assumedRole", ""),
"tile_size": ImageRequest._parse_tile_dimension(image_request.get("imageProcessorTileSize")),
"tile_overlap": ImageRequest._parse_tile_dimension(image_request.get("imageProcessorTileOverlap")),
"tile_format": ImageRequest._parse_tile_format(image_request.get("imageProcessorTileFormat")),
"tile_compression": ImageRequest._parse_tile_compression(image_request.get("imageProcessorTileCompression")),
"roi": ImageRequest._parse_roi(image_request.get("regionOfInterest")),
"outputs": ImageRequest._parse_outputs(image_request),
"feature_properties": image_request.get("featureProperties", []),
"post_processing": ImageRequest._parse_post_processing(image_request.get("postProcessing")),
}
return from_dict(ImageRequest, properties)

@staticmethod
def from_external_message(image_request: Dict[str, Any]):
def _parse_tile_dimension(value: Optional[str]) -> ImageDimensions:
"""
This method is used to construct an ImageRequest given a dictionary reconstructed from the
JSON representation of a request that appears on the Image Job Queue. The structure of
that message is generally governed by AWS API best practices and may evolve over time as
the public APIs for this service mature.
Converts a string value to a tuple of integers representing tile dimensions.
:param image_request: Dict[str, Any] = dictionary of values from the decoded JSON request
:param value: String value representing tile dimension.
:return: Tuple of integers as tile dimensions.
"""
return (int(value), int(value)) if value else None

:return: the ImageRequest
@staticmethod
def _parse_roi(roi: Optional[str]) -> Optional[BaseGeometry]:
"""
properties: Dict[str, Any] = {}
if "imageProcessorTileSize" in image_request:
tile_dimension = int(image_request["imageProcessorTileSize"])
properties["tile_size"] = (tile_dimension, tile_dimension)
Parses the region of interest from a WKT string.
if "imageProcessorTileOverlap" in image_request:
overlap_dimension = int(image_request["imageProcessorTileOverlap"])
properties["tile_overlap"] = (overlap_dimension, overlap_dimension)
:param roi: WKT string representing the region of interest.
:return: Parsed BaseGeometry object or None.
"""
return shapely.wkt.loads(roi) if roi else None

if "imageProcessorTileFormat" in image_request:
properties["tile_format"] = image_request["imageProcessorTileFormat"]
@staticmethod
def _parse_tile_format(tile_format: Optional[str]) -> Optional[ImageFormats]:
"""
Parses the region desired tile format to use for processing.
if "imageProcessorTileCompression" in image_request:
properties["tile_compression"] = image_request["imageProcessorTileCompression"]
:param tile_format: String representing the tile format to use.
:return: Parsed ImageFormats object or ImageFormats.NITF.
"""
return ImageFormats[tile_format].value if tile_format else ImageFormats.NITF.value

properties["job_id"] = image_request["jobId"]
@staticmethod
def _parse_tile_compression(tile_compression: Optional[str]) -> Optional[ImageCompression]:
"""
Parses the region desired tile compression format to use for processing.
properties["image_url"] = image_request["imageUrls"][0]
properties["image_id"] = image_request["jobId"] + ":" + properties["image_url"]
if "imageReadRole" in image_request:
properties["image_read_role"] = image_request["imageReadRole"]
:param tile_compression: String representing the tile compression format to use.
:return: Parsed ImageFormats object or ImageCompression.NONE.
"""
return ImageCompression[tile_compression].value if tile_compression else ImageCompression.NONE.value

properties["model_name"] = image_request["imageProcessor"]["name"]
properties["model_invoke_mode"] = image_request["imageProcessor"]["type"]
if "assumedRole" in image_request["imageProcessor"]:
properties["model_invocation_role"] = image_request["imageProcessor"]["assumedRole"]
@staticmethod
def _parse_model_invoke_mode(model_invoke_mode: Optional[str]) -> Optional[ModelInvokeMode]:
"""
Parses the region desired tile compression format to use for processing.
if "regionOfInterest" in image_request:
properties["roi"] = shapely.wkt.loads(image_request["regionOfInterest"])
:param model_invoke_mode: String representing the tile compression format to use.
:return: Parsed ModelInvokeMode object or ModelInvokeMode.SM_ENDPOINT.
"""
return ModelInvokeMode[model_invoke_mode] if model_invoke_mode else ModelInvokeMode.SM_ENDPOINT

# Support explicit outputs
@staticmethod
def _parse_outputs(image_request: Dict[str, Any]) -> List[Dict[str, Any]]:
"""
Parses the output configuration from the image request, including support for legacy inputs.
:param image_request: Dictionary of image request attributes.
:return: List of output configurations.
"""
if image_request.get("outputs"):
properties["outputs"] = image_request["outputs"]
# Support legacy image request
elif image_request.get("outputBucket") and image_request.get("outputPrefix"):
properties["outputs"] = [
return image_request["outputs"]

# Support legacy image request fields: outputBucket and outputPrefix
if image_request.get("outputBucket") and image_request.get("outputPrefix"):
return [
{
"type": SinkType.S3.value,
"bucket": image_request["outputBucket"],
"prefix": image_request["outputPrefix"],
}
]
if image_request.get("featureProperties"):
properties["feature_properties"] = image_request["featureProperties"]
if image_request.get("postProcessing"):
image_request["postProcessing"] = loads(
dumps(image_request["postProcessing"])
.replace("algorithmType", "algorithm_type")
.replace("iouThreshold", "iou_threshold")
.replace("skipBoxThreshold", "skip_box_threshold")
)
properties["post_processing"] = deserialize_post_processing_list(image_request.get("postProcessing"))

return ImageRequest(properties)
# No outputs were defined in the request
logger.warning("No output syncs were present in this request.")
return []

@staticmethod
def _parse_post_processing(post_processing: Optional[Dict[str, Any]]) -> List[MRPostProcessing]:
"""
Deserializes and cleans up post-processing data.
:param post_processing: Dictionary of post-processing configurations.
:return: List of MRPostProcessing instances.
"""
if not post_processing:
return [MRPostProcessing(step=MRPostprocessingStep.FEATURE_DISTILLATION, algorithm=FeatureDistillationNMS())]
cleaned_post_processing = loads(
dumps(post_processing)
.replace("algorithmType", "algorithm_type")
.replace("iouThreshold", "iou_threshold")
.replace("skipBoxThreshold", "skip_box_threshold")
)
return deserialize_post_processing_list(cleaned_post_processing)

def is_valid(self) -> bool:
"""
Check to see if this request contains required attributes and meaningful values
Validates whether the ImageRequest instance has all required attributes.
:return: bool = True if the request contains all the mandatory attributes with acceptable values,
False otherwise
:return: True if valid, False otherwise.
"""
if not shared_properties_are_valid(self):
logger.error("Invalid shared properties in ImageRequest")
return False

if not self.job_id or not self.outputs:
logger.error("Missing job id or outputs properties in ImageRequest")
if not self.job_id:
logger.error("Missing job id in ImageRequest")
return False

num_feature_detection_options = len(self.get_feature_distillation_option())
if num_feature_detection_options > 1:
logger.error(f"{num_feature_detection_options} feature distillation options in ImageRequest")
if len(self.get_feature_distillation_option()) > 1:
logger.error("Multiple feature distillation options in ImageRequest")
return False

if len(self.outputs) > 0:
for output in self.outputs:
sink_type = output.get("type")
if sink_type not in VALID_SYNC_TYPES:
logger.error(f"Invalid sink type '{sink_type}' in ImageRequest")
return False
return True

def get_shared_values(self) -> Dict[str, Any]:
"""
Returns a formatted dict that contains the properties of an image
Retrieves a dictionary of shared values related to the image.
:return: Dict[str, Any] = the properties of an image
:return: Dictionary of shared image properties.
"""
return {
"image_id": self.image_id,
Expand All @@ -180,8 +238,9 @@ def get_shared_values(self) -> Dict[str, Any]:

def get_feature_distillation_option(self) -> List[FeatureDistillationAlgorithm]:
"""
Parses the post-processing property and extracts the relevant feature distillation selection, if present
:return:
Extracts the feature distillation options from the post-processing configuration.
:return: List of FeatureDistillationAlgorithm instances.
"""
return [
op.algorithm
Expand Down
Loading

0 comments on commit 315181d

Please sign in to comment.