From db31a50eae9d7de43bf66be805049e16d50cfbbf Mon Sep 17 00:00:00 2001 From: Daniel Date: Thu, 28 Nov 2024 15:52:34 +0100 Subject: [PATCH] Relax bounding box requirements for model training (#8222) * Remove requirements of bounding boxes to have the same x/y dimension as well as the same dimensions overall for model training * update changelog * Show warnings if bounding boxes for CNN training are suboptimal. Make errors and warnings more prominent by using Alerts. Include annotation ID and topleft + size for offending boxes. * Change map to forEach * Include bounding box name in warnings --- CHANGELOG.unreleased.md | 1 + .../oxalis/view/jobs/train_ai_model.tsx | 169 +++++++++++++----- 2 files changed, 125 insertions(+), 45 deletions(-) diff --git a/CHANGELOG.unreleased.md b/CHANGELOG.unreleased.md index 72a07b2592..ce95f2c697 100644 --- a/CHANGELOG.unreleased.md +++ b/CHANGELOG.unreleased.md @@ -20,6 +20,7 @@ For upgrade instructions, please check the [migration guide](MIGRATIONS.released - Terms of Service for Webknossos are now accepted at registration, not afterward. [#8193](https://github.com/scalableminds/webknossos/pull/8193) - Removed bounding box size restriction for inferral jobs for super users. [#8200](https://github.com/scalableminds/webknossos/pull/8200) - Improved logging for errors when loading datasets and problems arise during a conversion step. [#8202](https://github.com/scalableminds/webknossos/pull/8202) +- Allowed to train an AI model using differently sized bounding boxes. We recommend all bounding boxes to have equal dimensions or to have dimensions which are multiples of the smallest bounding box. [#8222](https://github.com/scalableminds/webknossos/pull/8222) ### Fixed - Fix performance bottleneck when deleting a lot of trees at once. [#8176](https://github.com/scalableminds/webknossos/pull/8176) diff --git a/frontend/javascripts/oxalis/view/jobs/train_ai_model.tsx b/frontend/javascripts/oxalis/view/jobs/train_ai_model.tsx index f79f9b5d96..97aa43ae16 100644 --- a/frontend/javascripts/oxalis/view/jobs/train_ai_model.tsx +++ b/frontend/javascripts/oxalis/view/jobs/train_ai_model.tsx @@ -34,11 +34,11 @@ import _ from "lodash"; import BoundingBox from "oxalis/model/bucket_data_handling/bounding_box"; import { formatVoxels } from "libs/format_utils"; import * as Utils from "libs/utils"; -import { V3 } from "libs/mjs"; import type { APIAnnotation, APIDataset, ServerVolumeTracing } from "types/api_flow_types"; -import type { Vector3 } from "oxalis/constants"; +import type { Vector3, Vector6 } from "oxalis/constants"; import { serverVolumeToClientVolumeTracing } from "oxalis/model/reducers/volumetracing_reducer"; import { convertUserBoundingBoxesFromServerToFrontend } from "oxalis/model/reducers/reducer_helpers"; +import { computeArrayFromBoundingBox } from "libs/utils"; const { TextArea } = Input; const FormItem = Form.Item; @@ -66,8 +66,8 @@ enum AiModelCategory { const ExperimentalWarning = () => ( @@ -217,19 +217,29 @@ export function TrainAiModelTab userBoundingBoxes); + const userBoundingBoxes = annotationInfos.flatMap(({ userBoundingBoxes, annotation }) => + userBoundingBoxes.map((box) => ({ + ...box, + annotationId: "id" in annotation ? annotation.id : annotation.annotationId, + })), + ); const bboxesVoxelCount = _.sum( (userBoundingBoxes || []).map((bbox) => new BoundingBox(bbox.boundingBox).getVolume()), ); - const { areSomeAnnotationsInvalid, invalidAnnotationsReason } = - areInvalidAnnotationsIncluded(annotationInfos); - const { areSomeBBoxesInvalid, invalidBBoxesReason } = - areInvalidBoundingBoxesIncluded(userBoundingBoxes); - const invalidReasons = [invalidAnnotationsReason, invalidBBoxesReason] - .filter((reason) => reason) - .join("\n"); + const { hasAnnotationErrors, errors: annotationErrors } = + checkAnnotationsForErrorsAndWarnings(annotationInfos); + const { + hasBBoxErrors, + hasBBoxWarnings, + errors: bboxErrors, + warnings: bboxWarnings, + } = checkBoundingBoxesForErrorsAndWarnings(userBoundingBoxes); + const hasErrors = hasAnnotationErrors || hasBBoxErrors; + const hasWarnings = hasBBoxWarnings; + const errors = [...annotationErrors, ...bboxErrors]; + const warnings = bboxWarnings; return (
) : null} + + {hasErrors + ? errors.map((error) => ( + + )) + : null} + {hasWarnings + ? warnings.map((warning) => ( + + )) + : null} + - + @@ -385,16 +425,16 @@ export function CollapsibleWorkflowYamlEditor({ ); } -function areInvalidAnnotationsIncluded( +function checkAnnotationsForErrorsAndWarnings( annotationsWithDatasets: Array>, ): { - areSomeAnnotationsInvalid: boolean; - invalidAnnotationsReason: string | null; + hasAnnotationErrors: boolean; + errors: string[]; } { if (annotationsWithDatasets.length === 0) { return { - areSomeAnnotationsInvalid: true, - invalidAnnotationsReason: "At least one annotation must be defined.", + hasAnnotationErrors: true, + errors: ["At least one annotation must be defined."], }; } const annotationsWithoutBoundingBoxes = annotationsWithDatasets.filter( @@ -407,42 +447,81 @@ function areInvalidAnnotationsIncluded( "id" in annotation ? annotation.id : annotation.annotationId, ); return { - areSomeAnnotationsInvalid: true, - invalidAnnotationsReason: `All annotations must have at least one bounding box. Annotations without bounding boxes are: ${annotationIds.join(", ")}`, + hasAnnotationErrors: true, + errors: [ + `All annotations must have at least one bounding box. Annotations without bounding boxes are:\n${annotationIds.join(", ")}`, + ], }; } - return { areSomeAnnotationsInvalid: false, invalidAnnotationsReason: null }; + return { hasAnnotationErrors: false, errors: [] }; } -function areInvalidBoundingBoxesIncluded(userBoundingBoxes: UserBoundingBox[]): { - areSomeBBoxesInvalid: boolean; - invalidBBoxesReason: string | null; +function checkBoundingBoxesForErrorsAndWarnings( + userBoundingBoxes: (UserBoundingBox & { annotationId: string })[], +): { + hasBBoxErrors: boolean; + hasBBoxWarnings: boolean; + errors: string[]; + warnings: string[]; } { + let hasBBoxErrors = false; + let hasBBoxWarnings = false; + const errors = []; + const warnings = []; if (userBoundingBoxes.length === 0) { - return { - areSomeBBoxesInvalid: true, - invalidBBoxesReason: "At least one bounding box must be defined.", - }; + hasBBoxErrors = true; + errors.push("At least one bounding box must be defined."); } - const getSize = (bbox: UserBoundingBox) => V3.sub(bbox.boundingBox.max, bbox.boundingBox.min); + // Find smallest bounding box dimensions + const minDimensions = userBoundingBoxes.reduce( + (min, { boundingBox: box }) => ({ + x: Math.min(min.x, box.max[0] - box.min[0]), + y: Math.min(min.y, box.max[1] - box.min[1]), + z: Math.min(min.z, box.max[2] - box.min[2]), + }), + { x: Number.POSITIVE_INFINITY, y: Number.POSITIVE_INFINITY, z: Number.POSITIVE_INFINITY }, + ); - const size = getSize(userBoundingBoxes[0]); - // width must equal height - if (size[0] !== size[1]) { - return { - areSomeBBoxesInvalid: true, - invalidBBoxesReason: "The bounding box width must equal its height.", - }; + // Validate minimum size and multiple requirements + type BoundingBoxWithAnnotationId = { boundingBox: Vector6; name: string; annotationId: string }; + const tooSmallBoxes: BoundingBoxWithAnnotationId[] = []; + const nonMultipleBoxes: BoundingBoxWithAnnotationId[] = []; + userBoundingBoxes.forEach(({ boundingBox: box, name, annotationId }) => { + const arrayBox = computeArrayFromBoundingBox(box); + const [_x, _y, _z, width, height, depth] = arrayBox; + if (width < 10 || height < 10 || depth < 10) { + tooSmallBoxes.push({ boundingBox: arrayBox, name, annotationId }); + } + + if ( + width % minDimensions.x !== 0 || + height % minDimensions.y !== 0 || + depth % minDimensions.z !== 0 + ) { + nonMultipleBoxes.push({ boundingBox: arrayBox, name, annotationId }); + } + }); + + const boxWithIdToString = ({ boundingBox, name, annotationId }: BoundingBoxWithAnnotationId) => + `'${name}' of annotation ${annotationId}: ${boundingBox.join(", ")}`; + + if (tooSmallBoxes.length > 0) { + hasBBoxWarnings = true; + const tooSmallBoxesStrings = tooSmallBoxes.map(boxWithIdToString); + warnings.push( + `The following bounding boxes are not at least 10 Vx in each dimension which is suboptimal for the training:\n${tooSmallBoxesStrings.join("\n")}`, + ); } - // all bounding boxes must have the same size - const areSizesIdentical = userBoundingBoxes.every((bbox) => V3.isEqual(getSize(bbox), size)); - if (areSizesIdentical) { - return { areSomeBBoxesInvalid: false, invalidBBoxesReason: null }; + + if (nonMultipleBoxes.length > 0) { + hasBBoxWarnings = true; + const nonMultipleBoxesStrings = nonMultipleBoxes.map(boxWithIdToString); + warnings.push( + `The minimum bounding box dimensions are ${minDimensions.x} x ${minDimensions.y} x ${minDimensions.z}. The following bounding boxes have dimensions which are not a multiple of the minimum dimensions which is suboptimal for the training:\n${nonMultipleBoxesStrings.join("\n")}`, + ); } - return { - areSomeBBoxesInvalid: true, - invalidBBoxesReason: "All bounding boxes must have the same size.", - }; + + return { hasBBoxErrors, hasBBoxWarnings, errors, warnings }; } function AnnotationsCsvInput({