From 3a8ec5167484e954d57123bf63b1492f2ac54b29 Mon Sep 17 00:00:00 2001
From: laurenspriem
Date: Wed, 23 Oct 2024 17:05:13 +0530
Subject: [PATCH] [mob][photos] Revert onnx preprocessing

---
 .../face_ml/face_detection/detection.dart     |   4 +-
 .../face_detection_postprocessing.dart        | 138 +++++++++++++
 .../face_detection_service.dart               |  98 +++++----
 .../services/machine_learning/ml_model.dart   |   1 -
 .../clip/clip_image_encoder.dart              |  40 ++--
 .../lib/services/remote_assets_service.dart   |   4 +-
 mobile/lib/utils/image_ml_util.dart           | 190 +++++++++++++++++-
 .../ente/photos/onnx_dart/OnnxDartPlugin.kt   |  34 ++--
 8 files changed, 420 insertions(+), 89 deletions(-)
 create mode 100644 mobile/lib/services/machine_learning/face_ml/face_detection/face_detection_postprocessing.dart

diff --git a/mobile/lib/services/machine_learning/face_ml/face_detection/detection.dart b/mobile/lib/services/machine_learning/face_ml/face_detection/detection.dart
index 9f464c4f544..7c52e530462 100644
--- a/mobile/lib/services/machine_learning/face_ml/face_detection/detection.dart
+++ b/mobile/lib/services/machine_learning/face_ml/face_detection/detection.dart
@@ -92,8 +92,8 @@ class FaceDetectionRelative extends Detection {
     // Calculate the scaling
     final double scaleX = originalSize.width / newSize.width;
     final double scaleY = originalSize.height / newSize.height;
-    final double translateX =
-        ((originalSize.width - newSize.width) / 2) / originalSize.width;
-    final double translateY =
-        ((originalSize.height - newSize.height) / 2) / originalSize.height;
+    const double translateX = 0;
+    const double translateY = 0;
 
     // Transform Box
     _transformBox(box, scaleX, scaleY, translateX, translateY);
diff --git a/mobile/lib/services/machine_learning/face_ml/face_detection/face_detection_postprocessing.dart b/mobile/lib/services/machine_learning/face_ml/face_detection/face_detection_postprocessing.dart
new file mode 100644
index 00000000000..a01017559dc
--- /dev/null
+++ b/mobile/lib/services/machine_learning/face_ml/face_detection/face_detection_postprocessing.dart
@@ -0,0 +1,138 @@
+import 'dart:math' as math show max, min;
+
+import 'package:photos/services/machine_learning/face_ml/face_detection/detection.dart';
+
+List<FaceDetectionRelative> yoloOnnxFilterExtractDetections(
+  double minScoreSigmoidThreshold,
+  int inputWidth,
+  int inputHeight, {
+  required List<List<double>> results, // [25200, 16]
+}) {
+  final outputDetections = <FaceDetectionRelative>[];
+  final output = <List<double>>[];
+
+  // Go through the raw output and check the scores
+  for (final result in results) {
+    // Filter out raw detections with low scores
+    if (result[4] < minScoreSigmoidThreshold) {
+      continue;
+    }
+
+    // Get the raw detection
+    final rawDetection = List<double>.from(result);
+
+    // Append the processed raw detection to the output
+    output.add(rawDetection);
+  }
+
+  if (output.isEmpty) {
+    double maxScore = 0;
+    for (final result in results) {
+      if (result[4] > maxScore) {
+        maxScore = result[4];
+      }
+    }
+  }
+
+  for (final List<double> rawDetection in output) {
+    // Get absolute bounding box coordinates in format [xMin, yMin, xMax, yMax] https://github.com/deepcam-cn/yolov5-face/blob/eb23d18defe4a76cc06449a61cd51004c59d2697/utils/general.py#L216
+    final xMinAbs = rawDetection[0] - rawDetection[2] / 2;
+    final yMinAbs = rawDetection[1] - rawDetection[3] / 2;
+    final xMaxAbs = rawDetection[0] + rawDetection[2] / 2;
+    final yMaxAbs = rawDetection[1] + rawDetection[3] / 2;
+
+    // Get the relative bounding box coordinates in format [xMin, yMin, xMax, yMax]
+    final box = [
+      xMinAbs / inputWidth,
+      yMinAbs / inputHeight,
+      xMaxAbs / inputWidth,
+      yMaxAbs / inputHeight,
+    ];
+
+    // Get the keypoints coordinates in format [x, y]
+    final allKeypoints = <List<double>>[
+      [
+        rawDetection[5] / inputWidth,
+        rawDetection[6] / inputHeight,
+      ],
+      [
+        rawDetection[7] / inputWidth,
+        rawDetection[8] / inputHeight,
+      ],
+      [
+        rawDetection[9] / inputWidth,
+        rawDetection[10] / inputHeight,
+      ],
+      [
+        rawDetection[11] / inputWidth,
+        rawDetection[12] / inputHeight,
+      ],
+      [
+        rawDetection[13] / inputWidth,
+        rawDetection[14] / inputHeight,
+      ],
+    ];
+
+    // Get the score
+    final score =
+        rawDetection[4]; // Or should it be rawDetection[4]*rawDetection[15]?
+
+    // Create the relative detection
+    final detection = FaceDetectionRelative(
+      score: score,
+      box: box,
+      allKeypoints: allKeypoints,
+    );
+
+    // Append the relative detection to the output
+    outputDetections.add(detection);
+  }
+
+  return outputDetections;
+}
+
+List<FaceDetectionRelative> naiveNonMaxSuppression({
+  required List<FaceDetectionRelative> detections,
+  required double iouThreshold,
+}) {
+  // Sort the detections by score, the highest first
+  detections.sort((a, b) => b.score.compareTo(a.score));
+
+  // Loop through the detections and calculate the IOU
+  for (var i = 0; i < detections.length - 1; i++) {
+    for (var j = i + 1; j < detections.length; j++) {
+      final iou = _calculateIOU(detections[i], detections[j]);
+      if (iou >= iouThreshold) {
+        detections.removeAt(j);
+        j--;
+      }
+    }
+  }
+  return detections;
+}
+
+double _calculateIOU(
+  FaceDetectionRelative detectionA,
+  FaceDetectionRelative detectionB,
+) {
+  final areaA = detectionA.width * detectionA.height;
+  final areaB = detectionB.width * detectionB.height;
+
+  final intersectionMinX = math.max(detectionA.xMinBox, detectionB.xMinBox);
+  final intersectionMinY = math.max(detectionA.yMinBox, detectionB.yMinBox);
+  final intersectionMaxX = math.min(detectionA.xMaxBox, detectionB.xMaxBox);
+  final intersectionMaxY = math.min(detectionA.yMaxBox, detectionB.yMaxBox);
+
+  final intersectionWidth = intersectionMaxX - intersectionMinX;
+  final intersectionHeight = intersectionMaxY - intersectionMinY;
+
+  if (intersectionWidth < 0 || intersectionHeight < 0) {
+    return 0.0; // If boxes do not overlap, IoU is 0
+  }
+
+  final intersectionArea = intersectionWidth * intersectionHeight;
+
+  final unionArea = areaA + areaB - intersectionArea;
+
+  return intersectionArea / unionArea;
+}
diff --git a/mobile/lib/services/machine_learning/face_ml/face_detection/face_detection_service.dart b/mobile/lib/services/machine_learning/face_ml/face_detection/face_detection_service.dart
index 5364a7b944a..5f2458c6291 100644
--- a/mobile/lib/services/machine_learning/face_ml/face_detection/face_detection_service.dart
+++ b/mobile/lib/services/machine_learning/face_ml/face_detection/face_detection_service.dart
@@ -1,6 +1,5 @@
 import "dart:async";
-import "dart:math" show min;
-import 'dart:typed_data' show Int32List, Uint8List;
+import 'dart:typed_data' show Float32List, Uint8List;
 import 'dart:ui' as ui show Image;
 
 import 'package:logging/logging.dart';
@@ -8,13 +7,15 @@ import "package:onnx_dart/onnx_dart.dart";
 import 'package:onnxruntime/onnxruntime.dart';
 import "package:photos/models/ml/face/dimension.dart";
 import 'package:photos/services/machine_learning/face_ml/face_detection/detection.dart';
+import "package:photos/services/machine_learning/face_ml/face_detection/face_detection_postprocessing.dart";
 import "package:photos/services/machine_learning/ml_model.dart";
+import "package:photos/utils/image_ml_util.dart";
 
 class YOLOFaceInterpreterRunException implements Exception {}
 
 /// This class is responsible for running the face detection model (YOLOv5Face) on ONNX runtime, and can be accessed through the singleton instance [FaceDetectionService.instance].
 class FaceDetectionService extends MlModel {
-  static const kRemoteBucketModelPath = "yolov5s_face_opset18_rgba_opt_nosplits.onnx";
+  static const kRemoteBucketModelPath = "yolov5s_face_640_640_dynamic.onnx";
 
   static const _modelName = "YOLOv5Face";
 
   @override
@@ -54,32 +55,34 @@ class FaceDetectionService extends MlModel {
     );
     final startTime = DateTime.now();
-    final inputShape = [image.height, image.width, 4]; // [H, W, C]
-    final scaledSize = _getScaledSize(image.width, image.height);
+
+    final (inputImageList, scaledSize) = await preprocessImageYoloFace(
+      image,
+      rawRgbaBytes,
+    );
+    final preprocessingTime = DateTime.now();
+    final preprocessingMs =
+        preprocessingTime.difference(startTime).inMilliseconds;
 
     // Run inference
-    List<List<double>>? nestedResults = [];
+    List<List<List<double>>>? nestedResults = [];
     try {
       if (MlModel.usePlatformPlugin) {
-        nestedResults =
-            await _runPlatformPluginPredict(rawRgbaBytes, inputShape);
+        nestedResults = await _runPlatformPluginPredict(inputImageList);
       } else {
         nestedResults = _runFFIBasedPredict(
-          rawRgbaBytes,
-          inputShape,
           sessionAddress,
-        ); // [detections, 16]
+          inputImageList,
+        );
       }
       final inferenceTime = DateTime.now();
+      final inferenceMs =
+          inferenceTime.difference(preprocessingTime).inMilliseconds;
       _logger.info(
-        'Face detection is finished, in ${inferenceTime.difference(startTime).inMilliseconds} ms',
+        'Face detection is finished, in ${inferenceTime.difference(startTime).inMilliseconds} ms (preprocessing: $preprocessingMs ms, inference: $inferenceMs ms)',
       );
     } catch (e, s) {
-      _logger.severe(
-        'Error while running inference (PlatformPlugin: ${MlModel.usePlatformPlugin})',
-        e,
-        s,
-      );
+      _logger.severe('Error while running inference (PlatformPlugin: ${MlModel.usePlatformPlugin})', e, s);
       throw YOLOFaceInterpreterRunException();
     }
     try {
@@ -92,20 +95,26 @@
     }
   }
 
-  static List<List<double>>? _runFFIBasedPredict(
-    Uint8List inputImageList,
-    List<int> inputImageShape,
+  static List<List<List<double>>>? _runFFIBasedPredict(
     int sessionAddress,
+    Float32List inputImageList,
   ) {
+    const inputShape = [
+      1,
+      3,
+      kInputHeight,
+      kInputWidth,
+    ];
     final inputOrt = OrtValueTensor.createTensorWithDataList(
       inputImageList,
-      inputImageShape,
+      inputShape,
     );
     final inputs = {'input': inputOrt};
     final runOptions = OrtRunOptions();
     final session = OrtSession.fromAddress(sessionAddress);
     final List<OrtValue?> outputs = session.run(runOptions, inputs);
-    final result = outputs[0]?.value as List<List<double>>; // [detections, 16]
+    final result =
+        outputs[0]?.value as List<List<List<double>>>; // [1, 25200, 16]
     inputOrt.release();
     runOptions.release();
     for (var element in outputs) {
@@ -115,36 +124,41 @@
     return result;
   }
 
-  static Future<List<List<double>>> _runPlatformPluginPredict(
-    Uint8List inputImageList,
-    List<int> inputImageShape,
+  static Future<List<List<List<double>>>> _runPlatformPluginPredict(
+    Float32List inputImageList,
   ) async {
     final OnnxDart plugin = OnnxDart();
-    final result = await plugin.predictRgba(
+    final result = await plugin.predict(
       inputImageList,
-      Int32List.fromList(inputImageShape),
       _modelName,
     );
 
     final int resultLength = result!.length;
-    assert(resultLength % 16 == 0);
-    final int detections = resultLength ~/ 16;
+    assert(resultLength % 25200 * 16 == 0);
+    const int outerLength = 1;
+    const int middleLength = 25200;
+    const int innerLength = 16;
     return List.generate(
-      detections,
-      (index) => result.sublist(index * 16, (index + 1) * 16).toList(),
+      outerLength,
+      (_) => List.generate(
+        middleLength,
+        (j) => result.sublist(j * innerLength, (j + 1) * innerLength).toList(),
+      ),
     );
   }
 
   static List<FaceDetectionRelative> _yoloPostProcessOutputs(
-    List<List<double>> nestedResults,
+    List<List<List<double>>> nestedResults,
     Dimensions scaledSize,
   ) {
+    final firstResults = nestedResults[0]; // [25200, 16]
+
     // Filter output
-    final relativeDetections = _yoloOnnxFilterExtractDetections(
+    var relativeDetections = _yoloOnnxFilterExtractDetections(
       kMinScoreSigmoidThreshold,
       kInputWidth,
       kInputHeight,
-      results: nestedResults,
+      results: firstResults,
     );
 
     // Account for the fact that the aspect ratio was maintained
@@ -158,15 +172,13 @@
       );
     }
 
-    return relativeDetections;
-  }
-
-  static Dimensions _getScaledSize(int imageWidth, int imageHeight) {
-    final scale = min(kInputWidth / imageWidth, kInputHeight / imageHeight);
-    final scaledWidth = (imageWidth * scale).round().clamp(0, kInputWidth);
-    final scaledHeight = (imageHeight * scale).round().clamp(0, kInputHeight);
+    // Non-maximum suppression to remove duplicate detections
+    relativeDetections = naiveNonMaxSuppression(
+      detections: relativeDetections,
+      iouThreshold: kIouThreshold,
+    );
 
-    return Dimensions(width: scaledWidth, height: scaledHeight);
+    return relativeDetections;
   }
 }
 
@@ -252,4 +264,4 @@ List<FaceDetectionRelative> _yoloOnnxFilterExtractDetections(
   }
 
   return outputDetections;
-}
\ No newline at end of file
+}
diff --git a/mobile/lib/services/machine_learning/ml_model.dart b/mobile/lib/services/machine_learning/ml_model.dart
index 1b97a858b9a..ec91f0a0dfc 100644
--- a/mobile/lib/services/machine_learning/ml_model.dart
+++ b/mobile/lib/services/machine_learning/ml_model.dart
@@ -140,7 +140,6 @@ abstract class MlModel {
     ONNXEnvFFI.instance.initONNX(modelName);
     try {
       final sessionOptions = OrtSessionOptions()
-        ..appendCPUProvider(CPUFlags.useArena)
         ..setInterOpNumThreads(1)
        ..setIntraOpNumThreads(1)
         ..setSessionGraphOptimizationLevel(
diff --git a/mobile/lib/services/machine_learning/semantic_search/clip/clip_image_encoder.dart b/mobile/lib/services/machine_learning/semantic_search/clip/clip_image_encoder.dart
index 1307ee28368..be601a71f28 100644
--- a/mobile/lib/services/machine_learning/semantic_search/clip/clip_image_encoder.dart
+++ b/mobile/lib/services/machine_learning/semantic_search/clip/clip_image_encoder.dart
@@ -1,17 +1,15 @@
-import "dart:typed_data" show Int32List, Uint8List;
+import "dart:typed_data" show Uint8List, Float32List;
 import "dart:ui" show Image;
 
 import "package:logging/logging.dart";
 import "package:onnx_dart/onnx_dart.dart";
 import "package:onnxruntime/onnxruntime.dart";
 import "package:photos/services/machine_learning/ml_model.dart";
+import "package:photos/utils/image_ml_util.dart";
 import "package:photos/utils/ml_util.dart";
 
 class ClipImageEncoder extends MlModel {
-  static const kRemoteBucketModelPath =
-      "mobileclip_s2_image_opset18_rgba_opt.onnx"; // FP32 model
-  // static const kRemoteBucketModelPath =
-  //     "mobileclip_s2_image_opset18_fp16.onnx"; // FP16 model
+  static const kRemoteBucketModelPath = "mobileclip_s2_image.onnx";
 
   static const _modelName = "ClipImageEncoder";
 
   @override
@@ -36,13 +34,16 @@ class ClipImageEncoder extends MlModel {
     int? enteFileID,
   ]) async {
     final startTime = DateTime.now();
-    final inputShape = [image.height, image.width, 4]; // [H, W, C]
+    final inputList = await preprocessImageClip(image, rawRgbaBytes);
+    final preprocessingTime = DateTime.now();
+    final preprocessingMs =
+        preprocessingTime.difference(startTime).inMilliseconds;
     late List<double> result;
     try {
       if (MlModel.usePlatformPlugin) {
-        result = await _runPlatformPluginPredict(rawRgbaBytes, inputShape);
+        result = await _runPlatformPluginPredict(inputList);
       } else {
-        result = _runFFIBasedPredict(rawRgbaBytes, inputShape, sessionAddress);
+        result = _runFFIBasedPredict(inputList, sessionAddress);
       }
     } catch (e, stackTrace) {
       _logger.severe(
@@ -52,22 +53,21 @@
       );
       rethrow;
     }
-    final totalMs = DateTime.now().difference(startTime).inMilliseconds;
+    final inferTime = DateTime.now();
+    final inferenceMs = inferTime.difference(preprocessingTime).inMilliseconds;
+    final totalMs = inferTime.difference(startTime).inMilliseconds;
     _logger.info(
-      "Clip image predict took $totalMs ms${enteFileID != null ? " with fileID $enteFileID" : ""}",
+      "Clip image predict took $totalMs ms${enteFileID != null ? " with fileID $enteFileID" : ""} (inference: $inferenceMs ms, preprocessing: $preprocessingMs ms)",
" with fileID $enteFileID" : ""} (inference: $inferenceMs ms, preprocessing: $preprocessingMs ms)", ); return result; } static List _runFFIBasedPredict( - Uint8List inputImageList, - List inputImageShape, + Float32List inputList, int sessionAddress, ) { - final inputOrt = OrtValueTensor.createTensorWithDataList( - inputImageList, - inputImageShape, - ); + final inputOrt = + OrtValueTensor.createTensorWithDataList(inputList, [1, 3, 256, 256]); final inputs = {'input': inputOrt}; final session = OrtSession.fromAddress(sessionAddress); final runOptions = OrtRunOptions(); @@ -83,13 +83,11 @@ class ClipImageEncoder extends MlModel { } static Future> _runPlatformPluginPredict( - Uint8List inputImageList, - List inputImageShape, + Float32List inputList, ) async { final OnnxDart plugin = OnnxDart(); - final result = await plugin.predictRgba( - inputImageList, - Int32List.fromList(inputImageShape), + final result = await plugin.predict( + inputList, _modelName, ); final List embedding = result!.sublist(0, 512); diff --git a/mobile/lib/services/remote_assets_service.dart b/mobile/lib/services/remote_assets_service.dart index b9d54fe8b75..e14155af1b4 100644 --- a/mobile/lib/services/remote_assets_service.dart +++ b/mobile/lib/services/remote_assets_service.dart @@ -121,11 +121,11 @@ class RemoteAssetsService { const oldModelNames = [ "https://models.ente.io/clip-image-vit-32-float32.onnx", "https://models.ente.io/clip-text-vit-32-uint8.onnx", - "https://models.ente.io/mobileclip_s2_image.onnx", "https://models.ente.io/mobileclip_s2_image_opset18_rgba_sim.onnx", + "https://models.ente.io/mobileclip_s2_image_opset18_rgba_opt.onnx", "https://models.ente.io/mobileclip_s2_text_int32.onnx", - "https://models.ente.io/yolov5s_face_640_640_dynamic.onnx", "https://models.ente.io/yolov5s_face_opset18_rgba_opt.onnx", + "https://models.ente.io/yolov5s_face_opset18_rgba_opt_nosplits.onnx", ]; for (final remotePath in oldModelNames) { diff --git a/mobile/lib/utils/image_ml_util.dart b/mobile/lib/utils/image_ml_util.dart index 4075ddf7656..dd1e4512bcf 100644 --- a/mobile/lib/utils/image_ml_util.dart +++ b/mobile/lib/utils/image_ml_util.dart @@ -1,6 +1,6 @@ import "dart:async"; import "dart:io" show File, Platform; -import "dart:math" show min; +import "dart:math" show exp, max, min, pi; import "dart:typed_data" show Float32List, Uint8List; import "dart:ui"; @@ -9,6 +9,7 @@ import "package:flutter_image_compress/flutter_image_compress.dart"; import "package:logging/logging.dart"; import 'package:ml_linalg/linalg.dart'; import "package:photos/models/ml/face/box.dart"; +import "package:photos/models/ml/face/dimension.dart"; import 'package:photos/services/machine_learning/face_ml/face_alignment/alignment_result.dart'; import 'package:photos/services/machine_learning/face_ml/face_alignment/similarity_transform.dart'; import 'package:photos/services/machine_learning/face_ml/face_detection/detection.dart'; @@ -22,6 +23,15 @@ final _logger = Logger("ImageMlUtil"); /// These are 8 bit unsigned integers in range 0-255 for each RGB channel typedef RGB = (int, int, int); +const gaussianKernelSize = 5; +const gaussianKernelRadius = gaussianKernelSize ~/ 2; +const gaussianSigma = 10.0; +final List> gaussianKernel = + create2DGaussianKernel(gaussianKernelSize, gaussianSigma); + +const maxKernelSize = gaussianKernelSize; +const maxKernelRadius = maxKernelSize ~/ 2; + Future<(Image, Uint8List)> decodeImageFromPath(String imagePath) async { try { final imageData = await File(imagePath).readAsBytes(); @@ -168,6 +178,83 @@ Future> 
   }
 }
 
+Future<(Float32List, Dimensions)> preprocessImageYoloFace(
+  Image image,
+  Uint8List rawRgbaBytes,
+) async {
+  const requiredWidth = 640;
+  const requiredHeight = 640;
+  final scale = min(requiredWidth / image.width, requiredHeight / image.height);
+  final scaledWidth = (image.width * scale).round().clamp(0, requiredWidth);
+  final scaledHeight = (image.height * scale).round().clamp(0, requiredHeight);
+
+  final processedBytes = Float32List(3 * requiredHeight * requiredWidth);
+
+  final buffer = Float32List.view(processedBytes.buffer);
+  int pixelIndex = 0;
+  const int channelOffsetGreen = requiredHeight * requiredWidth;
+  const int channelOffsetBlue = 2 * requiredHeight * requiredWidth;
+  for (var h = 0; h < requiredHeight; h++) {
+    for (var w = 0; w < requiredWidth; w++) {
+      late RGB pixel;
+      if (w >= scaledWidth || h >= scaledHeight) {
+        pixel = const (114, 114, 114);
+      } else {
+        pixel = _getPixelBilinear(
+          w / scale,
+          h / scale,
+          image,
+          rawRgbaBytes,
+        );
+      }
+      buffer[pixelIndex] = pixel.$1 / 255;
+      buffer[pixelIndex + channelOffsetGreen] = pixel.$2 / 255;
+      buffer[pixelIndex + channelOffsetBlue] = pixel.$3 / 255;
+      pixelIndex++;
+    }
+  }
+
+  return (processedBytes, Dimensions(width: scaledWidth, height: scaledHeight));
+}
+
+Future<Float32List> preprocessImageClip(
+  Image image,
+  Uint8List rawRgbaBytes,
+) async {
+  const int requiredWidth = 256;
+  const int requiredHeight = 256;
+  const int requiredSize = 3 * requiredWidth * requiredHeight;
+  final scale = max(requiredWidth / image.width, requiredHeight / image.height);
+  final bool useAntiAlias = scale < 0.8;
+  final scaledWidth = (image.width * scale).round();
+  final scaledHeight = (image.height * scale).round();
+  final widthOffset = max(0, scaledWidth - requiredWidth) / 2;
+  final heightOffset = max(0, scaledHeight - requiredHeight) / 2;
+
+  final processedBytes = Float32List(requiredSize);
+  final buffer = Float32List.view(processedBytes.buffer);
+  int pixelIndex = 0;
+  const int greenOff = requiredHeight * requiredWidth;
+  const int blueOff = 2 * requiredHeight * requiredWidth;
+  for (var h = 0 + heightOffset; h < scaledHeight - heightOffset; h++) {
+    for (var w = 0 + widthOffset; w < scaledWidth - widthOffset; w++) {
+      final RGB pixel = _getPixelBilinear(
+        w / scale,
+        h / scale,
+        image,
+        rawRgbaBytes,
+        antiAlias: useAntiAlias,
+      );
+      buffer[pixelIndex] = pixel.$1 / 255;
+      buffer[pixelIndex + greenOff] = pixel.$2 / 255;
+      buffer[pixelIndex + blueOff] = pixel.$3 / 255;
+      pixelIndex++;
+    }
+  }
+
+  return processedBytes;
+}
+
 Future<(Float32List, List<AlignmentResult>, List<bool>, List<double>, Size)>
     preprocessToMobileFaceNetFloat32List(
   Image image,
@@ -245,7 +332,10 @@ RGB _readPixelColor(
   Uint8List rgbaBytes,
 ) {
   if (y < 0 || y >= image.height || x < 0 || x >= image.width) {
-    if (y < -2 || y >= image.height + 2 || x < -2 || x >= image.width + 2) {
+    if (y < -maxKernelRadius ||
+        y >= image.height + maxKernelRadius ||
+        x < -maxKernelRadius ||
+        x >= image.width + maxKernelRadius) {
       _logger.severe(
         '`readPixelColor`: Invalid pixel coordinates, out of bounds. x: $x, y: $y',
       );
     }
 
+RGB _getPixelBlurred(
+  int x,
+  int y,
+  Image image,
+  Uint8List rgbaBytes,
+) {
+  double r = 0, g = 0, b = 0;
+  for (int ky = 0; ky < gaussianKernelSize; ky++) {
+    for (int kx = 0; kx < gaussianKernelSize; kx++) {
+      final int px = (x - gaussianKernelRadius + kx);
+      final int py = (y - gaussianKernelRadius + ky);
+
+      final RGB pixelRgbTuple = _readPixelColor(px, py, image, rgbaBytes);
+      final double weight = gaussianKernel[ky][kx];
+
+      r += pixelRgbTuple.$1 * weight;
+      g += pixelRgbTuple.$2 * weight;
+      b += pixelRgbTuple.$3 * weight;
+    }
+  }
+  return (r.round(), g.round(), b.round());
+}
+
 List<List<int>> _createGrayscaleIntMatrixFromNormalized2List(
   Float32List imageList,
   int startIndex, {
@@ -406,6 +519,52 @@ Future<Uint8List> _cropAndEncodeCanvas(
   return await _encodeImageToPng(croppedImage);
 }
 
+RGB _getPixelBilinear(
+  num fx,
+  num fy,
+  Image image,
+  Uint8List rawRgbaBytes, {
+  bool antiAlias = false,
+}) {
+  // Clamp to image boundaries
+  fx = fx.clamp(0, image.width - 1);
+  fy = fy.clamp(0, image.height - 1);
+
+  // Get the surrounding coordinates and their weights
+  final int x0 = fx.floor();
+  final int x1 = fx.ceil();
+  final int y0 = fy.floor();
+  final int y1 = fy.ceil();
+  final dx = fx - x0;
+  final dy = fy - y0;
+  final dx1 = 1.0 - dx;
+  final dy1 = 1.0 - dy;
+
+  // Get the original pixels (with gaussian blur if antialias)
+  final RGB Function(int, int, Image, Uint8List) readPixel =
+      antiAlias ? _getPixelBlurred : _readPixelColor;
+  final RGB pixel1 = readPixel(x0, y0, image, rawRgbaBytes);
+  final RGB pixel2 = readPixel(x1, y0, image, rawRgbaBytes);
+  final RGB pixel3 = readPixel(x0, y1, image, rawRgbaBytes);
+  final RGB pixel4 = readPixel(x1, y1, image, rawRgbaBytes);
+
+  int bilinear(
+    num val1,
+    num val2,
+    num val3,
+    num val4,
+  ) =>
+      (val1 * dx1 * dy1 + val2 * dx * dy1 + val3 * dx1 * dy + val4 * dx * dy)
+          .round();
+
+  // Calculate the weighted sum of pixels
+  final int r = bilinear(pixel1.$1, pixel2.$1, pixel3.$1, pixel4.$1);
+  final int g = bilinear(pixel1.$2, pixel2.$2, pixel3.$2, pixel4.$2);
+  final int b = bilinear(pixel1.$3, pixel2.$3, pixel3.$3, pixel4.$3);
+
+  return (r, g, b);
+}
+
 /// Get the pixel value using Bicubic Interpolation. Code taken mainly from https://github.com/brendan-duncan/image/blob/6e407612752ffdb90b28cd5863c7f65856349348/lib/src/image/image.dart#L697
 RGB _getPixelBicubic(num fx, num fy, Image image, Uint8List rawRgbaBytes) {
   fx = fx.clamp(0, image.width - 1);
@@ -497,3 +656,30 @@
 
   return (c0, c1, c2); // (red, green, blue)
 }
+
+List<List<double>> create2DGaussianKernel(int size, double sigma) {
+  final List<List<double>> kernel =
+      List.generate(size, (_) => List.filled(size, 0));
+  double sum = 0.0;
+  final int center = size ~/ 2;
+
+  for (int y = 0; y < size; y++) {
+    for (int x = 0; x < size; x++) {
+      final int dx = x - center;
+      final int dy = y - center;
+      final double g = (1 / (2 * pi * sigma * sigma)) *
+          exp(-(dx * dx + dy * dy) / (2 * sigma * sigma));
+      kernel[y][x] = g;
+      sum += g;
+    }
+  }
+
+  // Normalize the kernel
+  for (int y = 0; y < size; y++) {
+    for (int x = 0; x < size; x++) {
+      kernel[y][x] /= sum;
+    }
+  }
+
+  return kernel;
+}
diff --git a/mobile/plugins/onnx_dart/android/src/main/kotlin/io/ente/photos/onnx_dart/OnnxDartPlugin.kt b/mobile/plugins/onnx_dart/android/src/main/kotlin/io/ente/photos/onnx_dart/OnnxDartPlugin.kt
index 82688af394b..9774d07c67c 100644
--- a/mobile/plugins/onnx_dart/android/src/main/kotlin/io/ente/photos/onnx_dart/OnnxDartPlugin.kt
+++ b/mobile/plugins/onnx_dart/android/src/main/kotlin/io/ente/photos/onnx_dart/OnnxDartPlugin.kt
@@ -172,30 +172,20 @@ class OnnxDartPlugin: FlutterPlugin, MethodCallHandler {
 
         try {
             val env = OrtEnvironment.getEnvironment()
-            var inputTensorShape: LongArray = longArrayOf(1, 112, 112, 3)
+            var inputTensorShape: LongArray = longArrayOf(1, 3, 640, 640)
             when (modelType) {
                 ModelType.MobileFaceNet -> {
                     val totalSize = inputDataFloat!!.size.toLong() / FACENET_SINGLE_INPUT_SIZE
-                    if (totalSize != 1.toLong()) {
                         inputTensorShape = longArrayOf(totalSize, 112, 112, 3)
-                    }
                 }
                 ModelType.ClipImageEncoder -> {
-                    if (inputShapeArray != null) {
-                        inputTensorShape = inputShapeArray.map { it.toLong() }.toLongArray()
-                    } else {
-                        result.error("INVALID_ARGUMENT", "Input shape is missing for clip image input", null)
-                    }
+                    inputTensorShape = longArrayOf(1, 3, 256, 256)
                 }
                 ModelType.ClipTextEncoder -> {
                     inputTensorShape = longArrayOf(1, 77)
                 }
                 ModelType.YOLOv5Face -> {
-                    if (inputShapeArray != null) {
-                        inputTensorShape = inputShapeArray.map { it.toLong() }.toLongArray()
-                    } else {
-                        result.error("INVALID_ARGUMENT", "Input shape is missing for YOLOv5Face input", null)
-                    }
+                    inputTensorShape = longArrayOf(1, 3, 640, 640)
                 }
             }
 
@@ -212,11 +202,20 @@ class OnnxDartPlugin: FlutterPlugin, MethodCallHandler {
                 inputs["input"] = inputTensor
             }
             val outputs = session.run(inputs)
-            val outputTensor = (outputs[0].value as Array<FloatArray>)
-            val flatList = outputTensor.flattenToFloatArray()
-            withContext(Dispatchers.Main) {
-                result.success(flatList)
+            if (modelType == ModelType.YOLOv5Face) {
+                val outputTensor = (outputs[0].value as Array<Array<FloatArray>>).get(0)
+                val flatList = outputTensor.flattenToFloatArray()
+                withContext(Dispatchers.Main) {
+                    result.success(flatList)
+                }
+            } else {
+                val outputTensor = (outputs[0].value as Array<FloatArray>)
+                val flatList = outputTensor.flattenToFloatArray()
+                withContext(Dispatchers.Main) {
+                    result.success(flatList)
+                }
             }
+
             outputs.close()
             inputTensor.close()
         } catch (e: OrtException) {
@@ -234,7 +233,6 @@ class OnnxDartPlugin: FlutterPlugin, MethodCallHandler {
     private fun createSession(env: OrtEnvironment, modalPath: String): OrtSession?
{ val sessionOptions = OrtSession.SessionOptions() - sessionOptions.addCPU(true) sessionOptions.setInterOpNumThreads(1) sessionOptions.setIntraOpNumThreads(1) sessionOptions.setOptimizationLevel(OrtSession.SessionOptions.OptLevel.ALL_OPT)