Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mob][photos] Revert onnx preprocessing #3818

Merged
merged 1 commit into from
Oct 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,8 @@ class FaceDetectionRelative extends Detection {
// Calculate the scaling
final double scaleX = originalSize.width / newSize.width;
final double scaleY = originalSize.height / newSize.height;
final double translateX = - ((originalSize.width - newSize.width) / 2) / originalSize.width;
final double translateY = - ((originalSize.height - newSize.height) / 2) / originalSize.height;
const double translateX = 0;
const double translateY = 0;

// Transform Box
_transformBox(box, scaleX, scaleY, translateX, translateY);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
import 'dart:math' as math show max, min;

import 'package:photos/services/machine_learning/face_ml/face_detection/detection.dart';

/// Filters raw YOLOv5Face ONNX output rows and converts the survivors into
/// [FaceDetectionRelative] detections.
///
/// [results] is the raw model output with shape [25200, 16]. Each row is
/// `[xCenter, yCenter, width, height, score, kp1x, kp1y, ..., kp5x, kp5y, ...]`
/// in absolute pixel coordinates of the [inputWidth] x [inputHeight] model
/// input. Rows whose score (index 4) is below [minScoreSigmoidThreshold] are
/// discarded; the remaining boxes and five facial keypoints are normalized to
/// relative `[0, 1]` coordinates.
List<FaceDetectionRelative> yoloOnnxFilterExtractDetections(
  double minScoreSigmoidThreshold,
  int inputWidth,
  int inputHeight, {
  required List<List<double>> results, // [25200, 16]
}) {
  final outputDetections = <FaceDetectionRelative>[];
  final output = <List<double>>[];

  // Go through the raw output and keep only rows passing the score threshold.
  for (final result in results) {
    // Filter out raw detections with low scores
    if (result[4] < minScoreSigmoidThreshold) {
      continue;
    }

    // Copy the row so later processing cannot mutate the caller's data.
    output.add(List<double>.from(result));
  }

  // No detection passed the threshold: return an empty list early.
  // (A previous version computed the maximum score here but never used it;
  // that dead code has been removed.)
  if (output.isEmpty) {
    return outputDetections;
  }

  for (final List<double> rawDetection in output) {
    // Convert from center-size to corner format [xMin, yMin, xMax, yMax],
    // https://github.com/deepcam-cn/yolov5-face/blob/eb23d18defe4a76cc06449a61cd51004c59d2697/utils/general.py#L216
    final xMinAbs = rawDetection[0] - rawDetection[2] / 2;
    final yMinAbs = rawDetection[1] - rawDetection[3] / 2;
    final xMaxAbs = rawDetection[0] + rawDetection[2] / 2;
    final yMaxAbs = rawDetection[1] + rawDetection[3] / 2;

    // Normalize the bounding box to relative [xMin, yMin, xMax, yMax].
    final box = [
      xMinAbs / inputWidth,
      yMinAbs / inputHeight,
      xMaxAbs / inputWidth,
      yMaxAbs / inputHeight,
    ];

    // Normalize the five facial keypoints, each in [x, y] format.
    final allKeypoints = <List<double>>[
      [
        rawDetection[5] / inputWidth,
        rawDetection[6] / inputHeight,
      ],
      [
        rawDetection[7] / inputWidth,
        rawDetection[8] / inputHeight,
      ],
      [
        rawDetection[9] / inputWidth,
        rawDetection[10] / inputHeight,
      ],
      [
        rawDetection[11] / inputWidth,
        rawDetection[12] / inputHeight,
      ],
      [
        rawDetection[13] / inputWidth,
        rawDetection[14] / inputHeight,
      ],
    ];

    // Get the score
    final score =
        rawDetection[4]; // Or should it be rawDetection[4]*rawDetection[15]?

    // Create the relative detection
    final detection = FaceDetectionRelative(
      score: score,
      box: box,
      allKeypoints: allKeypoints,
    );

    // Append the relative detection to the output
    outputDetections.add(detection);
  }

  return outputDetections;
}

/// Removes duplicate detections with naive (O(n^2)) non-maximum suppression.
///
/// Sorts [detections] in place by descending score, then, for each detection,
/// drops every lower-scored detection whose IoU with it is at least
/// [iouThreshold]. The input list is mutated and returned.
List<FaceDetectionRelative> naiveNonMaxSuppression({
  required List<FaceDetectionRelative> detections,
  required double iouThreshold,
}) {
  // Highest-scored detections first, so survivors suppress weaker overlaps.
  detections.sort((a, b) => b.score.compareTo(a.score));

  var keepIndex = 0;
  while (keepIndex < detections.length - 1) {
    final keeper = detections[keepIndex];
    var candidateIndex = keepIndex + 1;
    while (candidateIndex < detections.length) {
      // Remove overlapping lower-scored detections; only advance the index
      // when nothing was removed, since removal shifts later elements left.
      if (_calculateIOU(keeper, detections[candidateIndex]) >= iouThreshold) {
        detections.removeAt(candidateIndex);
      } else {
        candidateIndex++;
      }
    }
    keepIndex++;
  }
  return detections;
}

/// Returns the intersection-over-union (IoU) of the bounding boxes of
/// [detectionA] and [detectionB], in the range [0, 1].
///
/// Returns 0.0 when the boxes do not overlap, and also when the union area is
/// not positive (both boxes degenerate to a point or line), which previously
/// produced NaN via a 0/0 division and silently broke threshold comparisons.
double _calculateIOU(
  FaceDetectionRelative detectionA,
  FaceDetectionRelative detectionB,
) {
  final areaA = detectionA.width * detectionA.height;
  final areaB = detectionB.width * detectionB.height;

  // Corners of the intersection rectangle (may be inverted if disjoint).
  final intersectionMinX = math.max(detectionA.xMinBox, detectionB.xMinBox);
  final intersectionMinY = math.max(detectionA.yMinBox, detectionB.yMinBox);
  final intersectionMaxX = math.min(detectionA.xMaxBox, detectionB.xMaxBox);
  final intersectionMaxY = math.min(detectionA.yMaxBox, detectionB.yMaxBox);

  final intersectionWidth = intersectionMaxX - intersectionMinX;
  final intersectionHeight = intersectionMaxY - intersectionMinY;

  if (intersectionWidth < 0 || intersectionHeight < 0) {
    return 0.0; // If boxes do not overlap, IoU is 0
  }

  final intersectionArea = intersectionWidth * intersectionHeight;

  final unionArea = areaA + areaB - intersectionArea;

  // Guard against a 0/0 NaN when both boxes are degenerate (zero area).
  if (unionArea <= 0) {
    return 0.0;
  }

  return intersectionArea / unionArea;
}
Original file line number Diff line number Diff line change
@@ -1,20 +1,21 @@
import "dart:async";
import "dart:math" show min;
import 'dart:typed_data' show Int32List, Uint8List;
import 'dart:typed_data' show Float32List, Uint8List;
import 'dart:ui' as ui show Image;

import 'package:logging/logging.dart';
import "package:onnx_dart/onnx_dart.dart";
import 'package:onnxruntime/onnxruntime.dart';
import "package:photos/models/ml/face/dimension.dart";
import 'package:photos/services/machine_learning/face_ml/face_detection/detection.dart';
import "package:photos/services/machine_learning/face_ml/face_detection/face_detection_postprocessing.dart";
import "package:photos/services/machine_learning/ml_model.dart";
import "package:photos/utils/image_ml_util.dart";

class YOLOFaceInterpreterRunException implements Exception {}

/// This class is responsible for running the face detection model (YOLOv5Face) on ONNX runtime, and can be accessed through the singleton instance [FaceDetectionService.instance].
class FaceDetectionService extends MlModel {
static const kRemoteBucketModelPath = "yolov5s_face_opset18_rgba_opt_nosplits.onnx";
static const kRemoteBucketModelPath = "yolov5s_face_640_640_dynamic.onnx";
static const _modelName = "YOLOv5Face";

@override
Expand Down Expand Up @@ -54,32 +55,34 @@ class FaceDetectionService extends MlModel {
);

final startTime = DateTime.now();
final inputShape = <int>[image.height, image.width, 4]; // [H, W, C]
final scaledSize = _getScaledSize(image.width, image.height);

final (inputImageList, scaledSize) = await preprocessImageYoloFace(
image,
rawRgbaBytes,
);
final preprocessingTime = DateTime.now();
final preprocessingMs =
preprocessingTime.difference(startTime).inMilliseconds;

// Run inference
List<List<double>>? nestedResults = [];
List<List<List<double>>>? nestedResults = [];
try {
if (MlModel.usePlatformPlugin) {
nestedResults =
await _runPlatformPluginPredict(rawRgbaBytes, inputShape);
nestedResults = await _runPlatformPluginPredict(inputImageList);
} else {
nestedResults = _runFFIBasedPredict(
rawRgbaBytes,
inputShape,
sessionAddress,
); // [detections, 16]
inputImageList,
);
}
final inferenceTime = DateTime.now();
final inferenceMs =
inferenceTime.difference(preprocessingTime).inMilliseconds;
_logger.info(
'Face detection is finished, in ${inferenceTime.difference(startTime).inMilliseconds} ms',
'Face detection is finished, in ${inferenceTime.difference(startTime).inMilliseconds} ms (preprocessing: $preprocessingMs ms, inference: $inferenceMs ms)',
);
} catch (e, s) {
_logger.severe(
'Error while running inference (PlatformPlugin: ${MlModel.usePlatformPlugin})',
e,
s,
);
_logger.severe('Error while running inference (PlatformPlugin: ${MlModel.usePlatformPlugin})', e, s);
throw YOLOFaceInterpreterRunException();
}
try {
Expand All @@ -92,20 +95,26 @@ class FaceDetectionService extends MlModel {
}
}

static List<List<double>>? _runFFIBasedPredict(
Uint8List inputImageList,
List<int> inputImageShape,
static List<List<List<double>>>? _runFFIBasedPredict(
int sessionAddress,
Float32List inputImageList,
) {
const inputShape = [
1,
3,
kInputHeight,
kInputWidth,
];
final inputOrt = OrtValueTensor.createTensorWithDataList(
inputImageList,
inputImageShape,
inputShape,
);
final inputs = {'input': inputOrt};
final runOptions = OrtRunOptions();
final session = OrtSession.fromAddress(sessionAddress);
final List<OrtValue?> outputs = session.run(runOptions, inputs);
final result = outputs[0]?.value as List<List<double>>; // [detections, 16]
final result =
outputs[0]?.value as List<List<List<double>>>; // [1, 25200, 16]
inputOrt.release();
runOptions.release();
for (var element in outputs) {
Expand All @@ -115,36 +124,41 @@ class FaceDetectionService extends MlModel {
return result;
}

static Future<List<List<double>>> _runPlatformPluginPredict(
Uint8List inputImageList,
List<int> inputImageShape,
static Future<List<List<List<double>>>> _runPlatformPluginPredict(
Float32List inputImageList,
) async {
final OnnxDart plugin = OnnxDart();
final result = await plugin.predictRgba(
final result = await plugin.predict(
inputImageList,
Int32List.fromList(inputImageShape),
_modelName,
);

final int resultLength = result!.length;
assert(resultLength % 16 == 0);
final int detections = resultLength ~/ 16;
assert(resultLength % 25200 * 16 == 0);
const int outerLength = 1;
const int middleLength = 25200;
const int innerLength = 16;
return List.generate(
detections,
(index) => result.sublist(index * 16, (index + 1) * 16).toList(),
outerLength,
(_) => List.generate(
middleLength,
(j) => result.sublist(j * innerLength, (j + 1) * innerLength).toList(),
),
);
}

static List<FaceDetectionRelative> _yoloPostProcessOutputs(
List<List<double>> nestedResults,
List<List<List<double>>> nestedResults,
Dimensions scaledSize,
) {
final firstResults = nestedResults[0]; // [25200, 16]

// Filter output
final relativeDetections = _yoloOnnxFilterExtractDetections(
var relativeDetections = _yoloOnnxFilterExtractDetections(
kMinScoreSigmoidThreshold,
kInputWidth,
kInputHeight,
results: nestedResults,
results: firstResults,
);

// Account for the fact that the aspect ratio was maintained
Expand All @@ -158,15 +172,13 @@ class FaceDetectionService extends MlModel {
);
}

return relativeDetections;
}

static Dimensions _getScaledSize(int imageWidth, int imageHeight) {
final scale = min(kInputWidth / imageWidth, kInputHeight / imageHeight);
final scaledWidth = (imageWidth * scale).round().clamp(0, kInputWidth);
final scaledHeight = (imageHeight * scale).round().clamp(0, kInputHeight);
// Non-maximum suppression to remove duplicate detections
relativeDetections = naiveNonMaxSuppression(
detections: relativeDetections,
iouThreshold: kIouThreshold,
);

return Dimensions(width: scaledWidth, height: scaledHeight);
return relativeDetections;
}
}

Expand Down Expand Up @@ -252,4 +264,4 @@ List<FaceDetectionRelative> _yoloOnnxFilterExtractDetections(
}

return outputDetections;
}
}
1 change: 0 additions & 1 deletion mobile/lib/services/machine_learning/ml_model.dart
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,6 @@ abstract class MlModel {
ONNXEnvFFI.instance.initONNX(modelName);
try {
final sessionOptions = OrtSessionOptions()
..appendCPUProvider(CPUFlags.useArena)
..setInterOpNumThreads(1)
..setIntraOpNumThreads(1)
..setSessionGraphOptimizationLevel(
Expand Down
Loading