From 4445a8451e7855edfac28453863a64e7ec29d97f Mon Sep 17 00:00:00 2001 From: Maxim Zhiltsov Date: Wed, 20 Mar 2024 18:57:24 +0200 Subject: [PATCH 1/4] Fix validations for boxes from points task creation --- .../src/handlers/job_creation.py | 27 +++++++++---------- .../exchange-oracle/src/utils/annotations.py | 1 + 2 files changed, 13 insertions(+), 15 deletions(-) diff --git a/packages/examples/cvat/exchange-oracle/src/handlers/job_creation.py b/packages/examples/cvat/exchange-oracle/src/handlers/job_creation.py index 3e63a04de6..f71e7aac09 100644 --- a/packages/examples/cvat/exchange-oracle/src/handlers/job_creation.py +++ b/packages/examples/cvat/exchange-oracle/src/handlers/job_creation.py @@ -180,7 +180,7 @@ def __init__(self, manifest: TaskManifest, escrow_address: str, chain_id: int): self.embedded_point_color = (0, 255, 255) self.oracle_data_bucket = BucketAccessInfo.parse_obj(Config.storage_config) - self.min_class_samples_for_roi_estimation = 50 + self.min_class_samples_for_roi_estimation = 25 self.max_discarded_threshold = 0.05 """ @@ -253,7 +253,7 @@ def _validate_gt_labels(self): if gt_labels - manifest_labels: raise DatasetValidationError( "GT labels do not match job labels. Unknown labels: {}".format( - self._format_list(gt_labels - manifest_labels), + self._format_list(list(gt_labels - manifest_labels)), ) ) @@ -296,11 +296,10 @@ def _validate_gt_annotations(self): sample_boxes = [a for a in gt_sample.annotations if isinstance(a, dm.Bbox)] valid_boxes = [] for bbox in sample_boxes: - if (0 <= bbox.x < bbox.x + bbox.w < img_w) and ( - 0 <= bbox.y < bbox.y + bbox.h < img_h + if not ( + (0 <= bbox.x < bbox.x + bbox.w <= img_w) + and (0 <= bbox.y < bbox.y + bbox.h <= img_h) ): - valid_boxes.append(bbox) - else: excluded_gt_info.add_error( "Sample '{}': GT bbox #{} ({}) - invalid coordinates. " "The image will be skipped".format( @@ -633,6 +632,13 @@ def _prepare_gt(self): gt_dataset.put(gt_sample.wrap(annotations=matched_boxes)) + if excluded_gt_info.excluded_count: + self.logger.warning( + "Some GT annotations were excluded due to the errors found: {}".format( + self._format_list([e.message for e in excluded_gt_info.errors], separator="\n") + ) + ) + if ( excluded_gt_info.excluded_count > excluded_gt_info.total_count * self.max_discarded_threshold @@ -645,13 +651,6 @@ def _prepare_gt(self): ) ) - if excluded_gt_info.excluded_count: - self.logger.warning( - "Some GT annotations were excluded due to the errors found: {}".format( - self._format_list([e.message for e in excluded_gt_info.errors], separator="\n") - ) - ) - gt_labels_without_anns = [ gt_label_cat[label_id] for label_id, label_count in gt_count_per_class.items() @@ -827,10 +826,8 @@ def _upload_task_meta(self): ) storage_client = self._make_cloud_storage_client(self.oracle_data_bucket) - bucket_name = self.oracle_data_bucket.bucket_name for file_data, filename in file_list: storage_client.create_file( - bucket_name, compose_data_bucket_filename(self.escrow_address, self.chain_id, filename), file_data, ) diff --git a/packages/examples/cvat/exchange-oracle/src/utils/annotations.py b/packages/examples/cvat/exchange-oracle/src/utils/annotations.py index 41d1bb6fa8..2dbf33d88c 100644 --- a/packages/examples/cvat/exchange-oracle/src/utils/annotations.py +++ b/packages/examples/cvat/exchange-oracle/src/utils/annotations.py @@ -142,6 +142,7 @@ def remove_duplicated_gt_frames(dataset: dm.Dataset, known_frames: Sequence[str] T = TypeVar("T", bound=dm.Annotation) + def shift_ann(ann: T, offset_x: float, offset_y: float, *, img_w: int, img_h: int) -> T: "Shift annotation coordinates with clipping to the image size" From ecc93af724ba782bc43ae647e4564fd226d16090 Mon Sep 17 00:00:00 2001 From: Maxim Zhiltsov Date: Fri, 22 Mar 2024 14:57:53 +0200 Subject: [PATCH 2/4] Clean code --- .../examples/cvat/exchange-oracle/src/handlers/job_creation.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/packages/examples/cvat/exchange-oracle/src/handlers/job_creation.py b/packages/examples/cvat/exchange-oracle/src/handlers/job_creation.py index f71e7aac09..44852b2ce2 100644 --- a/packages/examples/cvat/exchange-oracle/src/handlers/job_creation.py +++ b/packages/examples/cvat/exchange-oracle/src/handlers/job_creation.py @@ -1078,9 +1078,6 @@ def __init__(self, manifest: TaskManifest, escrow_address: str, chain_id: int): self.roi_background_color = (245, 240, 242) # BGR - CVAT background color self.oracle_data_bucket = BucketAccessInfo.parse_obj(Config.storage_config) - # TODO: add - # credentials=BucketCredentials() - "Exchange Oracle's private bucket info" self.min_label_gt_samples = 2 # TODO: find good threshold From cbc625bff8109aaf2089fda10a4cb81c9d458c20 Mon Sep 17 00:00:00 2001 From: Maxim Zhiltsov Date: Fri, 22 Mar 2024 14:58:48 +0200 Subject: [PATCH 3/4] disable roi estimation for unreliable cases --- .../src/handlers/job_creation.py | 77 ++++++++++++++++--- 1 file changed, 65 insertions(+), 12 deletions(-) diff --git a/packages/examples/cvat/exchange-oracle/src/handlers/job_creation.py b/packages/examples/cvat/exchange-oracle/src/handlers/job_creation.py index 44852b2ce2..dc75a2e272 100644 --- a/packages/examples/cvat/exchange-oracle/src/handlers/job_creation.py +++ b/packages/examples/cvat/exchange-oracle/src/handlers/job_creation.py @@ -178,10 +178,19 @@ def __init__(self, manifest: TaskManifest, escrow_address: str, chain_id: int): self.min_embedded_point_radius_percent = 0.005 self.max_embedded_point_radius_percent = 0.01 self.embedded_point_color = (0, 255, 255) + self.roi_background_color = (245, 240, 242) # BGR - CVAT background color self.oracle_data_bucket = BucketAccessInfo.parse_obj(Config.storage_config) + self.min_class_samples_for_roi_estimation = 25 + self.max_class_roi_image_side_threshold = 0.5 + """ + The maximum allowed percent of the image for the estimated class RoI, + before the default RoI is used. Too big RoI estimations reduce the overall + prediction quality, making them unreliable. + """ + self.max_discarded_threshold = 0.05 """ The maximum allowed percent of discarded @@ -689,25 +698,39 @@ def _estimate_roi_sizes(self): # For big enough datasets, it should be reasonable approximation # (due to the central limit theorem). This can work bad for small datasets, # so we only do this if there are enough class samples. - classes_with_default_roi = [] + classes_with_default_roi: dict[int, str] = {} # label_id -> reason roi_size_estimations_per_label = {} # label id -> (w, h) + default_roi_size = (2, 2) # 2 will yield just the image size after halving for label_id, label_sizes in bbox_sizes_per_label.items(): if len(label_sizes) < self.min_class_samples_for_roi_estimation: - classes_with_default_roi.append(label_id) - estimated_size = (2, 2) # 2 will yield just the image size after halving + estimated_size = default_roi_size + classes_with_default_roi[label_id] = "too few GT provided" else: max_bbox = np.max(label_sizes, axis=0) - estimated_size = max_bbox * self.roi_size_mult + if np.any(max_bbox > self.max_class_roi_image_side_threshold): + estimated_size = default_roi_size + classes_with_default_roi[label_id] = "estimated RoI is unreliable" + else: + estimated_size = 2 * max_bbox * self.roi_size_mult roi_size_estimations_per_label[label_id] = estimated_size if classes_with_default_roi: label_cat = self._gt_dataset.categories()[dm.AnnotationType.label] + labels_by_reason = { + g_reason: list(v[0] for v in g_items) + for g_reason, g_items in groupby( + sorted(classes_with_default_roi.items(), key=lambda v: v[1]), key=lambda v: v[1] + ) + } self.logger.warning( - "Some classes will use the full image instead of RoI" - "- too few GT provided: {}".format( - self._format_list( - [label_cat[label_id].name for label_id in classes_with_default_roi] + "Some classes will use the full image instead of RoI - {}".format( + "; ".join( + "{}: {}".format( + g_reason, + self._format_list([label_cat[label_id].name for label_id in g_labels]), + ) + for g_reason, g_labels in labels_by_reason.items() ) ) ) @@ -832,6 +855,38 @@ def _upload_task_meta(self): file_data, ) + def _extract_roi( + self, source_pixels: np.ndarray, roi_info: boxes_from_points_task.RoiInfo + ) -> np.ndarray: + img_h, img_w, *_ = source_pixels.shape + + roi_pixels = source_pixels[ + max(0, roi_info.roi_y) : min(img_h, roi_info.roi_y + roi_info.roi_h), + max(0, roi_info.roi_x) : min(img_w, roi_info.roi_x + roi_info.roi_w), + ] + + if not ( + (0 <= roi_info.roi_x < roi_info.roi_x + roi_info.roi_w < img_w) + and (0 <= roi_info.roi_y < roi_info.roi_y + roi_info.roi_h < img_h) + ): + # Coords can be outside the original image + # In this case a border should be added to RoI, so that the image was centered on bbox + wrapped_roi_pixels = np.zeros((roi_info.roi_h, roi_info.roi_w, 3), dtype=np.float32) + wrapped_roi_pixels[:, :] = self.roi_background_color + + dst_y = max(-roi_info.roi_y, 0) + dst_x = max(-roi_info.roi_x, 0) + wrapped_roi_pixels[ + dst_y : dst_y + roi_pixels.shape[0], + dst_x : dst_x + roi_pixels.shape[1], + ] = roi_pixels + + roi_pixels = wrapped_roi_pixels + else: + roi_pixels = roi_pixels.copy() + + return roi_pixels + def _draw_roi_point( self, roi_pixels: np.ndarray, roi_info: boxes_from_points_task.RoiInfo ) -> np.ndarray: @@ -865,6 +920,7 @@ def _draw_roi_point( def _extract_and_upload_rois(self): # TODO: maybe optimize via splitting into separate threads (downloading, uploading, processing) + # Watch for the memory used, as the whole dataset can be quite big (gigabytes, terabytes) # Consider also packing RoIs cut into archives assert self._points_dataset is not _unset @@ -910,10 +966,7 @@ def _extract_and_upload_rois(self): image_rois = {} for roi_info in image_roi_infos: - roi_pixels = image_pixels[ - roi_info.roi_y : roi_info.roi_y + roi_info.roi_h, - roi_info.roi_x : roi_info.roi_x + roi_info.roi_w, - ] + roi_pixels = self._extract_roi(image_pixels, roi_info) if self.embed_point_in_roi_image: roi_pixels = self._draw_roi_point(roi_pixels, roi_info) From 71cbf1aadbf8d8021107d242d4fd2a6756918452 Mon Sep 17 00:00:00 2001 From: Maxim Zhiltsov Date: Fri, 22 Mar 2024 21:24:47 +0200 Subject: [PATCH 4/4] Fix manifest parsing --- .../cvat/recording-oracle/src/handlers/validation.py | 5 +++-- .../examples/cvat/recording-oracle/src/utils/assignments.py | 6 ------ 2 files changed, 3 insertions(+), 8 deletions(-) diff --git a/packages/examples/cvat/recording-oracle/src/handlers/validation.py b/packages/examples/cvat/recording-oracle/src/handlers/validation.py index 24e73de7f4..b46c3ff13e 100644 --- a/packages/examples/cvat/recording-oracle/src/handlers/validation.py +++ b/packages/examples/cvat/recording-oracle/src/handlers/validation.py @@ -30,7 +30,8 @@ from src.log import ROOT_LOGGER_NAME from src.services.cloud import make_client as make_cloud_client from src.services.cloud.utils import BucketAccessInfo -from src.utils.assignments import compute_resulting_annotations_hash, parse_manifest +from src.core.manifest import parse_manifest +from src.utils.assignments import compute_resulting_annotations_hash from src.utils.logging import NullLogger, get_function_logger module_logger_name = f"{ROOT_LOGGER_NAME}.cron.webhook" @@ -196,7 +197,7 @@ def _handle_validation_result(self, validation_result: ValidationResult): # TODO: update wrt. M2 API changes, send reason rejected_job_ids=list( jid - for jid, reason in validation_result.rejected_jobs + for jid, reason in validation_result.rejected_jobs.items() if not isinstance( reason, TooFewGtError ) # prevent such jobs from reannotation, can also be handled in ExcOr diff --git a/packages/examples/cvat/recording-oracle/src/utils/assignments.py b/packages/examples/cvat/recording-oracle/src/utils/assignments.py index dc3da37e50..1453425d8d 100644 --- a/packages/examples/cvat/recording-oracle/src/utils/assignments.py +++ b/packages/examples/cvat/recording-oracle/src/utils/assignments.py @@ -1,11 +1,5 @@ from hashlib import sha256 -from src.core.manifest import TaskManifest - - -def parse_manifest(manifest: dict) -> TaskManifest: - return TaskManifest.parse_obj(manifest) - def compute_resulting_annotations_hash(data: bytes) -> str: return sha256(data, usedforsecurity=False).hexdigest()