Skip to content

Commit

Permalink
Merge pull request #1747 from humanprotocol/zm/m2-experiment
Browse files Browse the repository at this point in the history
[CVAT-M2] Fix validations for boxes from points task creation
  • Loading branch information
zhiltsov-max authored Mar 25, 2024
2 parents 14f805f + 71cbf1a commit ef1b003
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 38 deletions.
107 changes: 77 additions & 30 deletions packages/examples/cvat/exchange-oracle/src/handlers/job_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,9 +178,18 @@ def __init__(self, manifest: TaskManifest, escrow_address: str, chain_id: int):
self.min_embedded_point_radius_percent = 0.005
self.max_embedded_point_radius_percent = 0.01
self.embedded_point_color = (0, 255, 255)
self.roi_background_color = (245, 240, 242) # BGR - CVAT background color

self.oracle_data_bucket = BucketAccessInfo.parse_obj(Config.storage_config)
self.min_class_samples_for_roi_estimation = 50

self.min_class_samples_for_roi_estimation = 25

self.max_class_roi_image_side_threshold = 0.5
"""
The maximum allowed percent of the image for the estimated class RoI,
before the default RoI is used. Too big RoI estimations reduce the overall
prediction quality, making them unreliable.
"""

self.max_discarded_threshold = 0.05
"""
Expand Down Expand Up @@ -253,7 +262,7 @@ def _validate_gt_labels(self):
if gt_labels - manifest_labels:
raise DatasetValidationError(
"GT labels do not match job labels. Unknown labels: {}".format(
self._format_list(gt_labels - manifest_labels),
self._format_list(list(gt_labels - manifest_labels)),
)
)

Expand Down Expand Up @@ -296,11 +305,10 @@ def _validate_gt_annotations(self):
sample_boxes = [a for a in gt_sample.annotations if isinstance(a, dm.Bbox)]
valid_boxes = []
for bbox in sample_boxes:
if (0 <= bbox.x < bbox.x + bbox.w < img_w) and (
0 <= bbox.y < bbox.y + bbox.h < img_h
if not (
(0 <= bbox.x < bbox.x + bbox.w <= img_w)
and (0 <= bbox.y < bbox.y + bbox.h <= img_h)
):
valid_boxes.append(bbox)
else:
excluded_gt_info.add_error(
"Sample '{}': GT bbox #{} ({}) - invalid coordinates. "
"The image will be skipped".format(
Expand Down Expand Up @@ -633,6 +641,13 @@ def _prepare_gt(self):

gt_dataset.put(gt_sample.wrap(annotations=matched_boxes))

if excluded_gt_info.excluded_count:
self.logger.warning(
"Some GT annotations were excluded due to the errors found: {}".format(
self._format_list([e.message for e in excluded_gt_info.errors], separator="\n")
)
)

if (
excluded_gt_info.excluded_count
> excluded_gt_info.total_count * self.max_discarded_threshold
Expand All @@ -645,13 +660,6 @@ def _prepare_gt(self):
)
)

if excluded_gt_info.excluded_count:
self.logger.warning(
"Some GT annotations were excluded due to the errors found: {}".format(
self._format_list([e.message for e in excluded_gt_info.errors], separator="\n")
)
)

gt_labels_without_anns = [
gt_label_cat[label_id]
for label_id, label_count in gt_count_per_class.items()
Expand Down Expand Up @@ -690,25 +698,39 @@ def _estimate_roi_sizes(self):
# For big enough datasets, it should be a reasonable approximation
# (due to the central limit theorem). This can work badly for small datasets,
# so we only do this if there are enough class samples.
classes_with_default_roi = []
classes_with_default_roi: dict[int, str] = {} # label_id -> reason
roi_size_estimations_per_label = {} # label id -> (w, h)
default_roi_size = (2, 2) # 2 will yield just the image size after halving
for label_id, label_sizes in bbox_sizes_per_label.items():
if len(label_sizes) < self.min_class_samples_for_roi_estimation:
classes_with_default_roi.append(label_id)
estimated_size = (2, 2) # 2 will yield just the image size after halving
estimated_size = default_roi_size
classes_with_default_roi[label_id] = "too few GT provided"
else:
max_bbox = np.max(label_sizes, axis=0)
estimated_size = max_bbox * self.roi_size_mult
if np.any(max_bbox > self.max_class_roi_image_side_threshold):
estimated_size = default_roi_size
classes_with_default_roi[label_id] = "estimated RoI is unreliable"
else:
estimated_size = 2 * max_bbox * self.roi_size_mult

roi_size_estimations_per_label[label_id] = estimated_size

if classes_with_default_roi:
label_cat = self._gt_dataset.categories()[dm.AnnotationType.label]
labels_by_reason = {
g_reason: list(v[0] for v in g_items)
for g_reason, g_items in groupby(
sorted(classes_with_default_roi.items(), key=lambda v: v[1]), key=lambda v: v[1]
)
}
self.logger.warning(
"Some classes will use the full image instead of RoI"
"- too few GT provided: {}".format(
self._format_list(
[label_cat[label_id].name for label_id in classes_with_default_roi]
"Some classes will use the full image instead of RoI - {}".format(
"; ".join(
"{}: {}".format(
g_reason,
self._format_list([label_cat[label_id].name for label_id in g_labels]),
)
for g_reason, g_labels in labels_by_reason.items()
)
)
)
Expand Down Expand Up @@ -827,14 +849,44 @@ def _upload_task_meta(self):
)

storage_client = self._make_cloud_storage_client(self.oracle_data_bucket)
bucket_name = self.oracle_data_bucket.bucket_name
for file_data, filename in file_list:
storage_client.create_file(
bucket_name,
compose_data_bucket_filename(self.escrow_address, self.chain_id, filename),
file_data,
)

def _extract_roi(
self, source_pixels: np.ndarray, roi_info: boxes_from_points_task.RoiInfo
) -> np.ndarray:
img_h, img_w, *_ = source_pixels.shape

roi_pixels = source_pixels[
max(0, roi_info.roi_y) : min(img_h, roi_info.roi_y + roi_info.roi_h),
max(0, roi_info.roi_x) : min(img_w, roi_info.roi_x + roi_info.roi_w),
]

if not (
(0 <= roi_info.roi_x < roi_info.roi_x + roi_info.roi_w < img_w)
and (0 <= roi_info.roi_y < roi_info.roi_y + roi_info.roi_h < img_h)
):
# Coords can be outside the original image
# In this case a border should be added to RoI, so that the image was centered on bbox
wrapped_roi_pixels = np.zeros((roi_info.roi_h, roi_info.roi_w, 3), dtype=np.float32)
wrapped_roi_pixels[:, :] = self.roi_background_color

dst_y = max(-roi_info.roi_y, 0)
dst_x = max(-roi_info.roi_x, 0)
wrapped_roi_pixels[
dst_y : dst_y + roi_pixels.shape[0],
dst_x : dst_x + roi_pixels.shape[1],
] = roi_pixels

roi_pixels = wrapped_roi_pixels
else:
roi_pixels = roi_pixels.copy()

return roi_pixels

def _draw_roi_point(
self, roi_pixels: np.ndarray, roi_info: boxes_from_points_task.RoiInfo
) -> np.ndarray:
Expand Down Expand Up @@ -868,6 +920,7 @@ def _draw_roi_point(

def _extract_and_upload_rois(self):
# TODO: maybe optimize via splitting into separate threads (downloading, uploading, processing)

# Watch for the memory used, as the whole dataset can be quite big (gigabytes, terabytes)
# Consider also packing RoIs cut into archives
assert self._points_dataset is not _unset
Expand Down Expand Up @@ -913,10 +966,7 @@ def _extract_and_upload_rois(self):

image_rois = {}
for roi_info in image_roi_infos:
roi_pixels = image_pixels[
roi_info.roi_y : roi_info.roi_y + roi_info.roi_h,
roi_info.roi_x : roi_info.roi_x + roi_info.roi_w,
]
roi_pixels = self._extract_roi(image_pixels, roi_info)

if self.embed_point_in_roi_image:
roi_pixels = self._draw_roi_point(roi_pixels, roi_info)
Expand Down Expand Up @@ -1081,9 +1131,6 @@ def __init__(self, manifest: TaskManifest, escrow_address: str, chain_id: int):
self.roi_background_color = (245, 240, 242) # BGR - CVAT background color

self.oracle_data_bucket = BucketAccessInfo.parse_obj(Config.storage_config)
# TODO: add
# credentials=BucketCredentials()
"Exchange Oracle's private bucket info"

self.min_label_gt_samples = 2 # TODO: find good threshold

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ def remove_duplicated_gt_frames(dataset: dm.Dataset, known_frames: Sequence[str]

T = TypeVar("T", bound=dm.Annotation)


def shift_ann(ann: T, offset_x: float, offset_y: float, *, img_w: int, img_h: int) -> T:
"Shift annotation coordinates with clipping to the image size"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@
from src.log import ROOT_LOGGER_NAME
from src.services.cloud import make_client as make_cloud_client
from src.services.cloud.utils import BucketAccessInfo
from src.utils.assignments import compute_resulting_annotations_hash, parse_manifest
from src.core.manifest import parse_manifest
from src.utils.assignments import compute_resulting_annotations_hash
from src.utils.logging import NullLogger, get_function_logger

module_logger_name = f"{ROOT_LOGGER_NAME}.cron.webhook"
Expand Down Expand Up @@ -196,7 +197,7 @@ def _handle_validation_result(self, validation_result: ValidationResult):
# TODO: update wrt. M2 API changes, send reason
rejected_job_ids=list(
jid
for jid, reason in validation_result.rejected_jobs
for jid, reason in validation_result.rejected_jobs.items()
if not isinstance(
reason, TooFewGtError
) # prevent such jobs from reannotation, can also be handled in ExcOr
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,5 @@
from hashlib import sha256

from src.core.manifest import TaskManifest


def parse_manifest(manifest: dict) -> TaskManifest:
    """Build a ``TaskManifest`` from a raw manifest dictionary via ``TaskManifest.parse_obj``."""
    parsed_manifest = TaskManifest.parse_obj(manifest)
    return parsed_manifest


def compute_resulting_annotations_hash(data: bytes) -> str:
    """Return the hex-encoded SHA-256 digest of ``data``."""
    # The digest is requested with usedforsecurity=False, i.e. as a plain checksum
    digest = sha256(data, usedforsecurity=False)
    return digest.hexdigest()

0 comments on commit ef1b003

Please sign in to comment.