diff --git a/changelog.d/20240311_093929_klakhov_improve_sam.md b/changelog.d/20240311_093929_klakhov_improve_sam.md new file mode 100644 index 000000000000..bec31a252964 --- /dev/null +++ b/changelog.d/20240311_093929_klakhov_improve_sam.md @@ -0,0 +1,4 @@ +### Fixed + +- Segment Anything decoder was loaded every time CVAT was opened, even when it was not required + () diff --git a/cvat-ui/plugins/sam/src/ts/index.tsx b/cvat-ui/plugins/sam/src/ts/index.tsx index 461a093bf66d..0473ee35dbec 100644 --- a/cvat-ui/plugins/sam/src/ts/index.tsx +++ b/cvat-ui/plugins/sam/src/ts/index.tsx @@ -1,11 +1,12 @@ -// Copyright (C) 2023 CVAT.ai Corporation +// Copyright (C) 2023-2024 CVAT.ai Corporation // // SPDX-License-Identifier: MIT -import { InferenceSession, Tensor } from 'onnxruntime-web'; +import { Tensor } from 'onnxruntime-web'; import { LRUCache } from 'lru-cache'; -import { Job } from 'cvat-core-wrapper'; +import { CVATCore, MLModel, Job } from 'cvat-core-wrapper'; import { PluginEntryPoint, APIWrapperEnterOptions, ComponentBuilder } from 'components/plugins-entrypoint'; +import { InitBody, DecodeBody, WorkerAction } from './inference.worker'; interface SAMPlugin { name: string; @@ -16,14 +17,14 @@ interface SAMPlugin { enter: ( plugin: SAMPlugin, taskID: number, - model: any, + model: MLModel, args: any, ) => Promise; leave: ( plugin: SAMPlugin, - result: any, + result: object, taskID: number, - model: any, + model: MLModel, args: any, ) => Promise; }; @@ -39,42 +40,32 @@ interface SAMPlugin { }; }; data: { - core: any; - jobs: Record; + initialized: boolean; + worker: Worker; + core: CVATCore | null; + jobs: Record; modelID: string; modelURL: string; embeddings: LRUCache; lowResMasks: LRUCache; - session: InferenceSession | null; + lastClicks: ClickType[]; }; callbacks: { onStatusChange: ((status: string) => void) | null; }; } -interface ONNXInput { - image_embeddings: Tensor; - point_coords: Tensor; - point_labels: Tensor; - orig_im_size: Tensor; - mask_input: Tensor;
- has_mask_input: Tensor; - readonly [name: string]: Tensor; -} - interface ClickType { - clickType: -1 | 0 | 1, - height: number | null, - width: number | null, - x: number, - y: number, + clickType: 0 | 1; + x: number; + y: number; } function getModelScale(w: number, h: number): number { // Input images to SAM must be resized so the longest side is 1024 const LONG_SIDE_LENGTH = 1024; - const samScale = LONG_SIDE_LENGTH / Math.max(h, w); - return samScale; + const scale = LONG_SIDE_LENGTH / Math.max(h, w); + return scale; } function modelData( @@ -83,39 +74,27 @@ function modelData( }: { clicks: ClickType[]; tensor: Tensor; - modelScale: { height: number; width: number; samScale: number }; + modelScale: { height: number; width: number; scale: number }; maskInput: Tensor | null; }, -): ONNXInput { +): DecodeBody { const imageEmbedding = tensor; const n = clicks.length; - // If there is no box input, a single padding point with - // label -1 and coordinates (0.0, 0.0) should be concatenated - // so initialize the array to support (n + 1) points. 
- const pointCoords = new Float32Array(2 * (n + 1)); - const pointLabels = new Float32Array(n + 1); + const pointCoords = new Float32Array(2 * n); + const pointLabels = new Float32Array(n); - // Add clicks and scale to what SAM expects + // Scale and add clicks for (let i = 0; i < n; i++) { - pointCoords[2 * i] = clicks[i].x * modelScale.samScale; - pointCoords[2 * i + 1] = clicks[i].y * modelScale.samScale; + pointCoords[2 * i] = clicks[i].x * modelScale.scale; + pointCoords[2 * i + 1] = clicks[i].y * modelScale.scale; pointLabels[i] = clicks[i].clickType; } - // Add in the extra point/label when only clicks and no box - // The extra point is at (0, 0) with label -1 - pointCoords[2 * n] = 0.0; - pointCoords[2 * n + 1] = 0.0; - pointLabels[n] = -1.0; - // Create the tensor - const pointCoordsTensor = new Tensor('float32', pointCoords, [1, n + 1, 2]); - const pointLabelsTensor = new Tensor('float32', pointLabels, [1, n + 1]); - const imageSizeTensor = new Tensor('float32', [ - modelScale.height, - modelScale.width, - ]); + const pointCoordsTensor = new Tensor('float32', pointCoords, [1, n, 2]); + const pointLabelsTensor = new Tensor('float32', pointLabels, [1, n]); + const imageSizeTensor = new Tensor('float32', [modelScale.height, modelScale.width]); const prevMask = maskInput || new Tensor('float32', new Float32Array(256 * 256), [1, 1, 256, 256]); @@ -154,27 +133,55 @@ const samPlugin: SAMPlugin = { async enter( plugin: SAMPlugin, taskID: number, - model: any, { frame }: { frame: number }, + model: MLModel, { frame }: { frame: number }, ): Promise { - if (model.id === plugin.data.modelID) { - if (!plugin.data.session) { - throw new Error('SAM plugin is not ready, session was not initialized'); + return new Promise((resolve, reject) => { + function resolvePromise(): void { + const key = `${taskID}_${frame}`; + if (plugin.data.embeddings.has(key)) { + resolve({ preventMethodCall: true }); + } else { + resolve(null); + } } - const key = `${taskID}_${frame}`; - if 
(plugin.data.embeddings.has(key)) { - return { preventMethodCall: true }; + if (model.id === plugin.data.modelID) { + if (!plugin.data.initialized) { + samPlugin.data.worker.postMessage({ + action: WorkerAction.INIT, + payload: { + decoderURL: samPlugin.data.modelURL, + } as InitBody, + }); + + samPlugin.data.worker.onmessage = (e: MessageEvent) => { + if (e.data.action !== WorkerAction.INIT) { + reject(new Error( + `Caught unexpected action response from worker: ${e.data.action}`, + )); + } + + if (!e.data.error) { + samPlugin.data.initialized = true; + resolvePromise(); + } else { + reject(new Error(`SAM worker was not initialized. ${e.data.error}`)); + } + }; + } else { + resolvePromise(); + } + } else { + resolve(null); } - } - - return null; + }); }, async leave( plugin: SAMPlugin, result: any, taskID: number, - model: any, + model: MLModel, { frame, pos_points, neg_points }: { frame: number, pos_points: number[][], neg_points: number[][], }, @@ -183,106 +190,136 @@ const samPlugin: SAMPlugin = { mask: number[][]; bounds: [number, number, number, number]; }> { - if (model.id !== plugin.data.modelID) { - return result; - } - - const job = Object.values(plugin.data.jobs).find((_job) => ( - _job.taskId === taskID && frame >= _job.startFrame && frame <= _job.stopFrame - )) as Job; - if (!job) { - throw new Error('Could not find a job corresponding to the request'); - } - - const { height: imHeight, width: imWidth } = await job.frames.get(frame); - const key = `${taskID}_${frame}`; - - if (result) { - const bin = window.atob(result.blob); - const uint8Array = new Uint8Array(bin.length); - for (let i = 0; i < bin.length; i++) { - uint8Array[i] = bin.charCodeAt(i); + return new Promise((resolve, reject) => { + if (model.id !== plugin.data.modelID) { + resolve(result); } - const float32Arr = new Float32Array(uint8Array.buffer); - plugin.data.embeddings.set(key, new Tensor('float32', float32Arr, [1, 256, 64, 64])); - } - - const modelScale = { - width: imWidth, - 
height: imHeight, - samScale: getModelScale(imWidth, imHeight), - }; - - const composedClicks = [...pos_points, ...neg_points].map(([x, y], index) => ({ - clickType: index < pos_points.length ? 1 : 0 as 0 | 1 | -1, - height: null, - width: null, - x, - y, - })); - - const feeds = modelData({ - clicks: composedClicks, - tensor: plugin.data.embeddings.get(key) as Tensor, - modelScale, - maskInput: plugin.data.lowResMasks.has(key) ? plugin.data.lowResMasks.get(key) as Tensor : null, - }); - function toMatImage(input: number[], width: number, height: number): number[][] { - const image = Array(height).fill(0); - for (let i = 0; i < image.length; i++) { - image[i] = Array(width).fill(0); - } + const job = Object.values(plugin.data.jobs).find((_job) => ( + _job.taskId === taskID && frame >= _job.startFrame && frame <= _job.stopFrame + )) as Job; - for (let i = 0; i < input.length; i++) { - const row = Math.floor(i / width); - const col = i % width; - image[row][col] = input[i] * 255; + if (!job) { + throw new Error('Could not find a job corresponding to the request'); } - return image; - } - - function onnxToImage(input: any, width: number, height: number): number[][] { - return toMatImage(input, width, height); - } - - const data = await (plugin.data.session as InferenceSession).run(feeds); - const { masks, low_res_masks: lowResMasks } = data; - const imageData = onnxToImage(masks.data, masks.dims[3], masks.dims[2]); - plugin.data.lowResMasks.set(key, lowResMasks); - - const xtl = Number(data.xtl.data[0]); - const xbr = Number(data.xbr.data[0]); - const ytl = Number(data.ytl.data[0]); - const ybr = Number(data.ybr.data[0]); - - return { - mask: imageData, - bounds: [xtl, ytl, xbr, ybr], - }; + plugin.data.jobs = { + // we do not need to store old job instances + [job.id]: job, + }; + + job.frames.get(frame) + .then(({ height: imHeight, width: imWidth }: { height: number; width: number }) => { + const key = `${taskID}_${frame}`; + + if (result) { + const bin = 
window.atob(result.blob); + const uint8Array = new Uint8Array(bin.length); + for (let i = 0; i < bin.length; i++) { + uint8Array[i] = bin.charCodeAt(i); + } + const float32Arr = new Float32Array(uint8Array.buffer); + plugin.data.embeddings.set(key, new Tensor('float32', float32Arr, [1, 256, 64, 64])); + } + + const modelScale = { + width: imWidth, + height: imHeight, + scale: getModelScale(imWidth, imHeight), + }; + + const composedClicks = [...pos_points, ...neg_points].map(([x, y], index) => ({ + clickType: index < pos_points.length ? 1 : 0 as 0 | 1, + x, + y, + })); + + const isLowResMaskSuitable = JSON + .stringify(composedClicks.slice(0, -1)) === JSON.stringify(plugin.data.lastClicks); + const feeds = modelData({ + clicks: composedClicks, + tensor: plugin.data.embeddings.get(key) as Tensor, + modelScale, + maskInput: isLowResMaskSuitable ? plugin.data.lowResMasks.get(key) || null : null, + }); + + function toMatImage(input: number[], width: number, height: number): number[][] { + const image = Array(height).fill(0); + for (let i = 0; i < image.length; i++) { + image[i] = Array(width).fill(0); + } + + for (let i = 0; i < input.length; i++) { + const row = Math.floor(i / width); + const col = i % width; + image[row][col] = input[i] > 0 ? 
255 : 0; + } + + return image; + } + + function onnxToImage(input: any, width: number, height: number): number[][] { + return toMatImage(input, width, height); + } + + plugin.data.worker.postMessage({ + action: WorkerAction.DECODE, + payload: feeds, + }); + + plugin.data.worker.onmessage = ((e) => { + if (e.data.action !== WorkerAction.DECODE) { + const error = 'Caught unexpected action response from worker: ' + + `${e.data.action}, while "${WorkerAction.DECODE}" was expected`; + reject(new Error(error)); + } + + if (!e.data.error) { + const { + masks, lowResMasks, xtl, ytl, xbr, ybr, + } = e.data.payload; + const imageData = onnxToImage(masks.data, masks.dims[3], masks.dims[2]); + plugin.data.lowResMasks.set(key, lowResMasks); + plugin.data.lastClicks = composedClicks; + + resolve({ + mask: imageData, + bounds: [xtl, ytl, xbr, ybr], + }); + } else { + reject(new Error(`Decoder error. ${e.data.error}`)); + } + }); + + plugin.data.worker.onerror = ((error) => { + reject(error); + }); + }); + }); }, }, }, }, data: { + initialized: false, core: null, + worker: new Worker(new URL('./inference.worker', import.meta.url)), jobs: {}, modelID: 'pth-facebookresearch-sam-vit-h', modelURL: '/assets/decoder.onnx', embeddings: new LRUCache({ - // float32 tensor [256, 64, 64] is 4 MB, max 512 MB - max: 128, + // float32 tensor [256, 64, 64] is 4 MB, max 128 MB + max: 32, updateAgeOnGet: true, updateAgeOnHas: true, }), lowResMasks: new LRUCache({ - // float32 tensor [1, 256, 256] is 0.25 MB, max 32 MB - max: 128, + // float32 tensor [1, 256, 256] is 0.25 MB, max 8 MB + max: 32, updateAgeOnGet: true, updateAgeOnHas: true, }), - session: null, + lastClicks: [], }, callbacks: { onStatusChange: null, @@ -292,9 +329,6 @@ const samPlugin: SAMPlugin = { const builder: ComponentBuilder = ({ core }) => { samPlugin.data.core = core; core.plugins.register(samPlugin); - InferenceSession.create(samPlugin.data.modelURL).then((session) => { - samPlugin.data.session = session; - }); return { 
name: samPlugin.name, diff --git a/cvat-ui/plugins/sam/src/ts/inference.worker.ts b/cvat-ui/plugins/sam/src/ts/inference.worker.ts new file mode 100644 index 000000000000..ff3927efaf87 --- /dev/null +++ b/cvat-ui/plugins/sam/src/ts/inference.worker.ts @@ -0,0 +1,90 @@ +// Copyright (C) 2024 CVAT.ai Corporation +// +// SPDX-License-Identifier: MIT + +import { InferenceSession, env, Tensor } from 'onnxruntime-web'; + +let decoder: InferenceSession | null = null; + +env.wasm.wasmPaths = '/assets/'; + +export enum WorkerAction { + INIT = 'init', + DECODE = 'decode', +} + +export interface InitBody { + decoderURL: string; +} + +export interface DecodeBody { + image_embeddings: Tensor; + point_coords: Tensor; + point_labels: Tensor; + orig_im_size: Tensor; + mask_input: Tensor; + has_mask_input: Tensor; + readonly [name: string]: Tensor; +} + +export interface WorkerOutput { + action: WorkerAction; + error?: string; +} + +export interface WorkerInput { + action: WorkerAction; + payload: InitBody | DecodeBody; +} + +const errorToMessage = (error: unknown): string => { + if (error instanceof Error) { + return error.message; + } + if (typeof error === 'string') { + return error; + } + + console.error(error); + return 'Unknown error, please check console'; +}; + +// eslint-disable-next-line no-restricted-globals +if ((self as any).importScripts) { + onmessage = (e: MessageEvent) => { + if (e.data.action === WorkerAction.INIT) { + if (decoder) { + return; + } + + const body = e.data.payload as InitBody; + InferenceSession.create(body.decoderURL).then((decoderSession) => { + decoder = decoderSession; + postMessage({ action: WorkerAction.INIT }); + }).catch((error: unknown) => { + postMessage({ action: WorkerAction.INIT, error: errorToMessage(error) }); + }); + } else if (!decoder) { + postMessage({ + action: e.data.action, + error: 'Worker was not initialized', + }); + } else if (e.data.action === WorkerAction.DECODE) { + decoder.run((e.data.payload as 
DecodeBody)).then((results) => { + postMessage({ + action: WorkerAction.DECODE, + payload: { + masks: results.masks, + lowResMasks: results.low_res_masks, + xtl: Number(results.xtl.data[0]), + ytl: Number(results.ytl.data[0]), + xbr: Number(results.xbr.data[0]), + ybr: Number(results.ybr.data[0]), + }, + }); + }).catch((error: unknown) => { + postMessage({ action: WorkerAction.DECODE, error: errorToMessage(error) }); + }); + } + }; +}