Skip to content

Commit

Permalink
feat: implement region status tracking
Browse files Browse the repository at this point in the history
  • Loading branch information
drduhe committed Sep 18, 2024
1 parent bc3cdf7 commit b2c6d69
Show file tree
Hide file tree
Showing 31 changed files with 673 additions and 521 deletions.
2 changes: 1 addition & 1 deletion scripts/run_container.sh
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ echo "/_/ /_/_/ |_| \____/\____/_/ /_/\__/\__,_/_/_/ /_/\___/_/ ";

# Inputs
PATTERN="${1:-"MRDataplane"}"
IMAGE_NAME=PATTERN="${2:-"osml-model-runner:local"}"
IMAGE_NAME="${2:-"osml-model-runner:local"}"
AWS_REGION="${3:-"us-west-2"}"

# Get the latest task definition ARN based on a string pattern
Expand Down
68 changes: 27 additions & 41 deletions src/aws/osml/model_runner/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,7 @@
GeojsonDetectionField,
ImageDimensions,
ImageRegion,
ImageRequestStatus,
RegionRequestStatus,
RequestStatus,
ThreadingLocalContextFilter,
Timer,
build_embedded_metrics_config,
Expand All @@ -61,7 +60,7 @@
from .inference import FeatureSelector, calculate_processing_bounds, get_source_property
from .queue import RequestQueue
from .sink import SinkFactory
from .status import StatusMonitor
from .status import ImageStatusMonitor, RegionStatusMonitor
from .tile_worker import TilingStrategy, VariableOverlapTilingStrategy, process_tiles, setup_tile_workers

# Set up metrics configuration
Expand Down Expand Up @@ -92,7 +91,8 @@ def __init__(self, tiling_strategy: TilingStrategy = VariableOverlapTilingStrate
self.endpoint_statistics_table = EndpointStatisticsTable(ServiceConfig.endpoint_statistics_table)
self.region_request_queue = RequestQueue(ServiceConfig.region_queue, wait_seconds=10)
self.region_requests_iter = iter(self.region_request_queue)
self.status_monitor = StatusMonitor()
self.image_status_monitor = ImageStatusMonitor()
self.region_status_monitor = RegionStatusMonitor()
self.elevation_model = ModelRunner.create_elevation_model()
self.endpoint_utils = EndpointUtils()
self.running = False
Expand Down Expand Up @@ -243,7 +243,7 @@ def monitor_work_queues(self) -> None:
minimal_job_item = JobItem(
image_id=min_image_id,
job_id=min_job_id,
processing_time=Decimal(0),
processing_duration=Decimal(0),
)
self.fail_image_request_send_messages(minimal_job_item, err)
self.image_request_queue.finish_request(receipt_handle)
Expand Down Expand Up @@ -295,7 +295,7 @@ def process_image_request(self, image_request: ImageRequest) -> None:

# Start the image processing
self.job_table.start_image_request(image_request_item)
self.status_monitor.process_event(image_request_item, ImageRequestStatus.STARTED, "Started image request")
self.image_status_monitor.process_event(image_request_item, RequestStatus.STARTED, "Started image request")

# Check we have a valid image request, throws if not
self.validate_model_hosting(image_request_item)
Expand Down Expand Up @@ -335,7 +335,7 @@ def process_image_request(self, image_request: ImageRequest) -> None:
# Update the image request job to have new derived image data
image_request_item = self.job_table.update_image_request(image_request_item)

self.status_monitor.process_event(image_request_item, ImageRequestStatus.IN_PROGRESS, "Processing regions")
self.image_status_monitor.process_event(image_request_item, RequestStatus.IN_PROGRESS, "Processing regions")

# Place the resulting region requests on the appropriate work queue
self.queue_region_request(all_regions, image_request, raster_dataset, sensor_model, image_extension)
Expand All @@ -348,7 +348,7 @@ def process_image_request(self, image_request: ImageRequest) -> None:
minimal_job_item = JobItem(
image_id=image_request.image_id,
job_id=image_request.job_id,
processing_time=Decimal(0),
processing_duration=Decimal(0),
)
self.fail_image_request(minimal_job_item, err)

Expand All @@ -374,7 +374,7 @@ def queue_region_request(
:param raster_dataset: Dataset = the raster dataset containing the region
:param sensor_model: Optional[SensorModel] = the sensor model for this raster dataset
:return None
:return: None
"""
# Set aside the first region
first_region = all_regions.pop(0)
Expand Down Expand Up @@ -506,23 +506,24 @@ def process_region_request(
tile_queue, tile_workers = setup_tile_workers(region_request, sensor_model, self.elevation_model)

# Process all our tiles
total_tile_count, tile_error_count = process_tiles(
total_tile_count, tile_error_count, failed_tiles = process_tiles(
self.tiling_strategy, region_request, tile_queue, tile_workers, raster_dataset, sensor_model
)

# Update table w/ total tile counts
region_request_item.total_tiles = Decimal(total_tile_count)
region_request_item.completed_tiles = Decimal(total_tile_count - tile_error_count)
region_request_item.failed_tiles = Decimal(tile_error_count)
region_request_item.failed_tiles = failed_tiles
region_request_item = self.region_request_table.update_region_request(region_request_item)

# Update the image request to complete this region
image_request_item = self.job_table.complete_region_request(region_request.image_id, bool(tile_error_count))

# Update the region request table with the final status of this region
region_request_item = self.region_request_table.complete_region_request(
region_request_item, self.calculate_region_status(total_tile_count, tile_error_count)
)
region_status = self.region_status_monitor.get_status(region_request_item)
region_request_item = self.region_request_table.complete_region_request(region_request_item, region_status)

self.region_status_monitor.process_event(region_request_item, region_status, "Completed region processing")

# Write CloudWatch Metrics to the Logs
if isinstance(metrics, MetricsLogger):
Expand Down Expand Up @@ -631,7 +632,7 @@ def fail_image_request_send_messages(self, image_request_item: JobItem, err: Exc
:return: None
"""
logger.exception(f"Failed to start image processing!: {err}")
self.status_monitor.process_event(image_request_item, ImageRequestStatus.FAILED, str(err))
self.image_status_monitor.process_event(image_request_item, RequestStatus.FAILED, str(err))

def complete_image_request(
self, region_request: RegionRequest, image_format: str, raster_dataset: gdal.Dataset, sensor_model: SensorModel
Expand All @@ -641,8 +642,8 @@ def complete_image_request(
completion logic for the associated ImageRequest.
:param region_request: RegionRequest = the region request to update.
:param image_format: the format of the image
:param raster_dataset: the image data
:param image_format: Format of the image data
:param raster_dataset: the image data raster
:param sensor_model: the image sensor model
:return: None
Expand Down Expand Up @@ -679,9 +680,9 @@ def complete_image_request(

# Ensure we have a valid start time for our record
# TODO: Figure out why we wouldn't have a valid start time?!?!
if completed_image_request_item.processing_time is not None:
image_request_status = self.status_monitor.get_image_request_status(completed_image_request_item)
self.status_monitor.process_event(
if completed_image_request_item.processing_duration is not None:
image_request_status = self.image_status_monitor.get_status(completed_image_request_item)
self.image_status_monitor.process_event(
completed_image_request_item, image_request_status, "Completed image processing"
)
self.generate_image_processing_metrics(completed_image_request_item, image_format)
Expand Down Expand Up @@ -716,7 +717,7 @@ def generate_image_processing_metrics(
}
)

metrics.put_metric(MetricLabels.DURATION, float(image_request_item.processing_time), str(Unit.SECONDS.value))
metrics.put_metric(MetricLabels.DURATION, float(image_request_item.processing_duration), str(Unit.SECONDS.value))
metrics.put_metric(MetricLabels.INVOCATIONS, 1, str(Unit.COUNT.value))
if image_request_item.region_error > 0:
metrics.put_metric(MetricLabels.ERRORS, 1, str(Unit.COUNT.value))
Expand All @@ -738,9 +739,9 @@ def fail_region_request(
if isinstance(metrics, MetricsLogger):
metrics.put_metric(MetricLabels.ERRORS, 1, str(Unit.COUNT.value))
try:
region_request_item = self.region_request_table.complete_region_request(
region_request_item, RegionRequestStatus.FAILED
)
region_status = RequestStatus.FAILED
region_request_item = self.region_request_table.complete_region_request(region_request_item, region_status)
self.region_status_monitor.process_event(region_request_item, region_status, "Completed region processing")
return self.job_table.complete_region_request(region_request_item.image_id, error=True)
except Exception as status_error:
logger.error("Unable to update region status in job table")
Expand All @@ -757,9 +758,9 @@ def validate_model_hosting(self, image_request: JobItem):
"""
if not image_request.model_invoke_mode or image_request.model_invoke_mode not in VALID_MODEL_HOSTING_OPTIONS:
error = f"Application only supports ${VALID_MODEL_HOSTING_OPTIONS} Endpoints"
self.status_monitor.process_event(
self.image_status_monitor.process_event(
image_request,
ImageRequestStatus.FAILED,
RequestStatus.FAILED,
error,
)
raise UnsupportedModelException(error)
Expand Down Expand Up @@ -1003,18 +1004,3 @@ def get_extents(ds: gdal.Dataset, sm: SensorModel) -> Dict[str, Any]:
}
except Exception as e:
logger.error(f"Error in getting extents: {e}")

@staticmethod
def calculate_region_status(total_tile_count: int, tile_error_count: int) -> RegionRequestStatus:
    """
    Determine the final processing status of a region from its tile counts.

    :param total_tile_count: number of tiles that were processed
    :param tile_error_count: number of tiles with errors
    :return: RegionRequestStatus = FAILED when the error count equals the total count,
        PARTIAL when some (but not all) tiles errored, SUCCESS otherwise
    """
    # Equal counts are checked first; this also covers the case where both counts are zero.
    if total_tile_count == tile_error_count:
        return RegionRequestStatus.FAILED

    # At least one tile errored, but not every tile.
    if 0 < tile_error_count < total_tile_count:
        return RegionRequestStatus.PARTIAL

    return RegionRequestStatus.SUCCESS
3 changes: 1 addition & 2 deletions src/aws/osml/model_runner/common/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,5 @@
ImageDimensions,
ImageFormats,
ImageRegion,
ImageRequestStatus,
RegionRequestStatus,
RequestStatus,
)
14 changes: 1 addition & 13 deletions src/aws/osml/model_runner/common/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class ImageFormats(str, AutoStringEnum):
GTIFF = auto()


class ImageRequestStatus(str, AutoStringEnum):
class RequestStatus(str, AutoStringEnum):
"""
Enumeration defining the status of an image or region request
"""
Expand All @@ -47,18 +47,6 @@ class ImageRequestStatus(str, AutoStringEnum):
FAILED = auto()


class RegionRequestStatus(str, AutoStringEnum):
"""
Enumeration of the statuses that can be assigned to a region request.
"""

STARTING = auto()
PARTIAL = auto()
IN_PROGRESS = auto()
SUCCESS = auto()
FAILED = auto()


class GeojsonDetectionField(str, Enum):
"""
Enumeration defining the model geojson field to index depending on the shape
Expand Down
33 changes: 22 additions & 11 deletions src/aws/osml/model_runner/database/job_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,31 @@
@dataclass
class JobItem(DDBItem):
"""
JobItem is a dataclass meant to represent a single item in the JobStatus table
JobItem is a dataclass meant to represent a single item in the JobStatus table.
The data schema is defined as follows:
image_id: str = unique image_id for the job
image_id: str = unique identifier for the image associated with the job
job_id: Optional[str] = unique identifier for the job
image_url: Optional[str] = S3 URL or another source location for the image
image_read_role: Optional[str] = IAM role ARN for accessing the image from its source
model_invoke_mode: Optional[str] = mode in which the model is invoked (e.g., batch or streaming)
start_time: Optional[Decimal] = time in epoch milliseconds when the job started
expire_time: Optional[Decimal] = time in epoch seconds when the job will expire
end_time: Optional[Decimal] = time in epoch milliseconds when the job ended
region_success: Optional[Decimal] = current count of regions that have succeeded for this image
region_error: Optional[Decimal] = current count of regions that have errored for this image
region_success: Optional[Decimal] = current count of regions that have successfully processed for this image
region_error: Optional[Decimal] = current count of regions that have errored during processing
region_count: Optional[Decimal] = total count of regions expected for this image
width: Optional[Decimal] = width of the image
height: Optional[Decimal] = height of the image
feature_distillation_options: Optional[str] = the options used in selecting features (NMS/SOFT_NMS, thresholds, etc.)
roi_wkt: a Well Known Text representation of the requested processing bounds
width: Optional[Decimal] = width of the image in pixels
height: Optional[Decimal] = height of the image in pixels
extents: Optional[str] = string representation of the image extents
tile_size: Optional[str] = size of the tiles used during processing
tile_overlap: Optional[str] = overlap between tiles during processing
model_name: Optional[str] = name of the model used for processing
outputs: Optional[str] = details about the job output
processing_duration: Optional[Decimal] = time in seconds taken to complete processing
feature_properties: Optional[str] = additional feature properties or metadata from the image processing
feature_distillation_option: Optional[str] = the options used in selecting features (e.g., NMS/SOFT_NMS, thresholds)
roi_wkt: Optional[str] = a Well-Known Text (WKT) representation of the requested processing bounds
"""

image_id: str
Expand All @@ -54,7 +65,7 @@ class JobItem(DDBItem):
tile_overlap: Optional[str] = None
model_name: Optional[str] = None
outputs: Optional[str] = None
processing_time: Optional[Decimal] = None
processing_duration: Optional[Decimal] = None
feature_properties: Optional[str] = None
feature_distillation_option: Optional[str] = None
roi_wkt: Optional[str] = None
Expand Down Expand Up @@ -96,7 +107,7 @@ def start_image_request(self, image_request_item: JobItem) -> JobItem:

# Update the job item to have the correct start parameters
image_request_item.start_time = start_time_millisec
image_request_item.processing_time = Decimal(0)
image_request_item.processing_duration = Decimal(0)
image_request_item.expire_time = expire_time_epoch_sec
image_request_item.region_success = Decimal(0)
image_request_item.region_error = Decimal(0)
Expand Down Expand Up @@ -212,7 +223,7 @@ def update_image_request(self, image_request_item: JobItem) -> JobItem:
"""
# Update the processing time on our message
if image_request_item.start_time is not None:
image_request_item.processing_time = self.get_processing_time(image_request_item.start_time)
image_request_item.processing_duration = self.get_processing_time(image_request_item.start_time)

return from_dict(JobItem, self.update_ddb_item(image_request_item))

Expand Down
Loading

0 comments on commit b2c6d69

Please sign in to comment.