Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tiny face #1906

Merged
merged 9 commits into from
May 28, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,15 @@ import "package:ml_linalg/linalg.dart";
/// Returns the cosine distance (`1 - dot product`) between [vector1] and
/// [vector2], computed with the SIMD-backed `dot` of ml_linalg.
///
/// WARNING: No normalization is performed here; the result is only a true
/// cosine distance when both vectors are already normalized!
/// WARNING: In hot loops, writing `1 - a.dot(b)` inline can be even faster.
///
/// Throws an [ArgumentError] when the two vectors differ in length.
@pragma("vm:prefer-inline")
double cosineDistanceSIMD(Vector vector1, Vector vector2) {
  final bool sameLength = vector1.length == vector2.length;
  if (!sameLength) {
    throw ArgumentError('Vectors must be the same length');
  }

  // For unit-length vectors the dot product equals the cosine similarity.
  final similarity = vector1.dot(vector2);
  return 1 - similarity;
}

/// Calculates the cosine distance between two embeddings/vectors using SIMD from ml_linalg
///
/// WARNING: Only use when you're not sure if vectors are normalized. If you're sure they are, use [cosineDistanceSIMD] instead for better performance.
/// WARNING: Only use when you're not sure if vectors are normalized. If you're sure they are, use [cosineDistanceSIMD] instead for better performance, or inline for best performance.
double cosineDistanceSIMDSafe(Vector vector1, Vector vector2) {
if (vector1.length != vector2.length) {
throw ArgumentError('Vectors must be the same length');
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,16 @@ import "package:logging/logging.dart";
import "package:ml_linalg/dtype.dart";
import "package:ml_linalg/vector.dart";
import "package:photos/generated/protos/ente/common/vector.pb.dart";
import 'package:photos/services/machine_learning/face_ml/face_clustering/cosine_distance.dart';
import "package:photos/services/machine_learning/face_ml/face_clustering/face_info_for_clustering.dart";
import "package:photos/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart";
import "package:photos/services/machine_learning/face_ml/face_ml_result.dart";
import "package:simple_cluster/simple_cluster.dart";
import "package:synchronized/synchronized.dart";

class FaceInfo {
final String faceID;
final double? faceScore;
final double? blurValue;
final bool? badFace;
final List<double>? embedding;
final Vector? vEmbedding;
int? clusterId;
String? closestFaceId;
Expand All @@ -33,14 +30,13 @@ class FaceInfo {
this.faceScore,
this.blurValue,
this.badFace,
this.embedding,
this.vEmbedding,
this.clusterId,
this.fileCreationTime,
});
}

enum ClusterOperation { linearIncrementalClustering, dbscanClustering }
enum ClusterOperation { linearIncrementalClustering }

class ClusteringResult {
final Map<String, int> newFaceIdToCluster;
Expand Down Expand Up @@ -129,10 +125,6 @@ class FaceClusteringService {
final result = FaceClusteringService.runLinearClustering(args);
sendPort.send(result);
break;
case ClusterOperation.dbscanClustering:
final result = FaceClusteringService._runDbscanClustering(args);
sendPort.send(result);
break;
}
} catch (e, stackTrace) {
sendPort
Expand Down Expand Up @@ -203,8 +195,6 @@ class FaceClusteringService {
/// Runs the clustering algorithm [runLinearClustering] on the given [input], in an isolate.
///
/// Returns the clustering result, which is a list of clusters, where each cluster is a list of indices of the dataset.
///
/// WARNING: Make sure to always input data in the same ordering, otherwise the clustering can be less deterministic.
Future<ClusteringResult?> predictLinear(
Set<FaceInfoForClustering> input, {
Map<int, int>? fileIDToCreationTime,
Expand Down Expand Up @@ -401,55 +391,6 @@ class FaceClusteringService {
}
}

/// Runs DBSCAN clustering on [input] inside the clustering isolate.
///
/// [input] is keyed by face ID; it is forwarded unchanged to the isolate
/// together with [fileIDToCreationTime], [eps] and [minPts].
///
/// Returns one list of face IDs per cluster. Returns an empty list, without
/// starting any work, when [input] is empty or when [isRunning] is already
/// set.
///
/// Any error raised by the isolate call propagates to the caller; [isRunning]
/// is reset in that case as well.
Future<List<List<String>>> predictDbscan(
  Map<String, Uint8List> input, {
  Map<int, int>? fileIDToCreationTime,
  double eps = 0.3,
  int minPts = 5,
}) async {
  if (input.isEmpty) {
    _logger.warning(
      "DBSCAN Clustering dataset of embeddings is empty, returning empty list.",
    );
    return [];
  }
  if (isRunning) {
    _logger.warning(
      "DBSCAN Clustering is already running, returning empty list.",
    );
    return [];
  }

  isRunning = true;
  try {
    // Clustering inside the isolate
    _logger.info(
      "Start DBSCAN clustering on ${input.length} embeddings inside computer isolate",
    );
    final stopwatchClustering = Stopwatch()..start();
    final List<List<String>> clusterFaceIDs = await _runInIsolate(
      (
        ClusterOperation.dbscanClustering,
        {
          'input': input,
          'fileIDToCreationTime': fileIDToCreationTime,
          'eps': eps,
          'minPts': minPts,
        }
      ),
    );
    _logger.info(
      'DBSCAN Clustering executed in ${stopwatchClustering.elapsed.inSeconds} seconds',
    );

    return clusterFaceIDs;
  } finally {
    // Release the flag even when the isolate call throws; otherwise every
    // later call would bail out early with "already running".
    isRunning = false;
  }
}

static ClusteringResult? runLinearClustering(Map args) {
// final input = args['input'] as Map<String, (int?, Uint8List)>;
final input = args['input'] as Set<FaceInfoForClustering>;
Expand Down Expand Up @@ -562,19 +503,10 @@ class FaceClusteringService {
if (i % 250 == 0) {
log("[ClusterIsolate] ${DateTime.now()} Processed ${offset != null ? i + offset : i} faces");
}
// WARNING: The loop below is now O(n^2) so be very careful with anything you put in there!
for (int j = i - 1; j >= 0; j--) {
late double distance;
if (sortedFaceInfos[i].vEmbedding != null) {
distance = cosineDistanceSIMD(
sortedFaceInfos[i].vEmbedding!,
sortedFaceInfos[j].vEmbedding!,
);
} else {
distance = cosineDistForNormVectors(
sortedFaceInfos[i].embedding!,
sortedFaceInfos[j].embedding!,
);
}
final double distance = 1 -
sortedFaceInfos[i].vEmbedding!.dot(sortedFaceInfos[j].vEmbedding!);
if (distance < closestDistance) {
if (sortedFaceInfos[j].badFace! &&
distance > conservativeDistanceThreshold) {
Expand Down Expand Up @@ -814,10 +746,8 @@ class FaceClusteringService {
double closestDistance = double.infinity;
for (int j = 0; j < totalFaces; j++) {
if (i == j) continue;
final double distance = cosineDistanceSIMD(
faceInfos[i].vEmbedding!,
faceInfos[j].vEmbedding!,
);
final double distance =
1 - faceInfos[i].vEmbedding!.dot(faceInfos[j].vEmbedding!);
if (distance < closestDistance) {
closestDistance = distance;
closestIdx = j;
Expand Down Expand Up @@ -870,10 +800,10 @@ class FaceClusteringService {
for (int i = 0; i < clusterIds.length; i++) {
for (int j = 0; j < clusterIds.length; j++) {
if (i == j) continue;
final double newDistance = cosineDistanceSIMD(
clusterIdToMeanEmbeddingAndWeight[clusterIds[i]]!.$1,
clusterIdToMeanEmbeddingAndWeight[clusterIds[j]]!.$1,
);
final double newDistance = 1 -
clusterIdToMeanEmbeddingAndWeight[clusterIds[i]]!
.$1
.dot(clusterIdToMeanEmbeddingAndWeight[clusterIds[j]]!.$1);
if (newDistance < distance) {
distance = newDistance;
clusterIDsToMerge = (clusterIds[i], clusterIds[j]);
Expand Down Expand Up @@ -944,55 +874,6 @@ class FaceClusteringService {
newClusterIdToFaceIds: clusterIdToFaceIds,
);
}

/// Isolate-side worker that groups face embeddings with DBSCAN.
///
/// Reads the following keys from [args]:
/// * `'input'`: face ID -> serialized embedding buffer (decoded with
///   [EVector.fromBuffer]),
/// * `'fileIDToCreationTime'`: optional file ID -> creation time lookup,
/// * `'eps'` and `'minPts'`: the DBSCAN parameters.
///
/// Returns one list of face IDs per cluster reported by DBSCAN.
static List<List<String>> _runDbscanClustering(Map args) {
  final input = args['input'] as Map<String, Uint8List>;
  final creationTimeByFileID = args['fileIDToCreationTime'] as Map<int, int>?;
  final epsilon = args['eps'] as double;
  final minPoints = args['minPts'] as int;

  log(
    "[ClusterIsolate] ${DateTime.now()} Copied to isolate ${input.length} faces",
  );

  final DBSCAN clusterer = DBSCAN(
    epsilon: epsilon,
    minPoints: minPoints,
    distanceMeasure: cosineDistForNormVectors,
  );

  // Decode every serialized embedding into a FaceInfo record.
  final List<FaceInfo> faces = [
    for (final faceEntry in input.entries)
      FaceInfo(
        faceID: faceEntry.key,
        embedding: EVector.fromBuffer(faceEntry.value).values,
        fileCreationTime:
            creationTimeByFileID?[getFileIdFromFaceId(faceEntry.key)],
      ),
  ];

  // Ordering by creation time is only possible when the lookup was supplied.
  if (creationTimeByFileID != null) {
    _sortFaceInfosOnCreationTime(faces);
  }

  // DBSCAN works on raw embeddings and answers with indices into that list.
  final List<List<double>> rawEmbeddings = [
    for (final face in faces) face.embedding!,
  ];
  final List<List<int>> indexClusters = clusterer.run(rawEmbeddings);

  // Translate the index clusters back into face IDs.
  final List<List<String>> faceIdClusters = [
    for (final indices in indexClusters)
      [for (final idx in indices) faces[idx].faceID],
  ];
  return faceIdClusters;
}
}

/// Sort the faceInfos based on fileCreationTime, in descending order, so newest faces are first
Expand Down
13 changes: 13 additions & 0 deletions mobile/lib/services/machine_learning/face_ml/face_ml_service.dart
Original file line number Diff line number Diff line change
Expand Up @@ -651,6 +651,19 @@ class FaceMlService {
min(offset + bucketSize, allFaceInfoForClustering.length),
);

if (faceInfoForClustering.every((face) => face.clusterId != null)) {
_logger.info('Everything in bucket $bucket is already clustered');
if (offset + bucketSize >= totalFaces) {
_logger.info('All faces clustered');
break;
} else {
_logger.info('Skipping to next bucket');
offset += offsetIncrement;
bucket++;
continue;
}
}

final clusteringResult =
await FaceClusteringService.instance.predictLinear(
faceInfoForClustering.toSet(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ import "package:photos/face/db.dart";
import "package:photos/face/model/person.dart";
import "package:photos/generated/protos/ente/common/vector.pb.dart";
import "package:photos/models/file/file.dart";
import "package:photos/services/machine_learning/face_ml/face_clustering/cosine_distance.dart";
import "package:photos/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart";
import "package:photos/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart";
import "package:photos/services/machine_learning/face_ml/face_ml_result.dart";
Expand Down Expand Up @@ -434,7 +433,9 @@ class ClusterFeedbackService {
distanceThreshold: 0.22,
);

if (clusterResult == null || clusterResult.newClusterIdToFaceIds == null || clusterResult.isEmpty) {
if (clusterResult == null ||
clusterResult.newClusterIdToFaceIds == null ||
clusterResult.isEmpty) {
_logger.warning('No clusters found or something went wrong');
return ClusteringResult(newFaceIdToCluster: {});
}
Expand Down Expand Up @@ -537,8 +538,7 @@ class ClusterFeedbackService {
EVector.fromBuffer(clusterSummary.$1).values,
dtype: DType.float32,
);
final bigClustersMeanDistance =
cosineDistanceSIMD(biggestMean, currentMean);
final bigClustersMeanDistance = 1 - biggestMean.dot(currentMean);
_logger.info(
"Mean distance between biggest cluster and current cluster: $bigClustersMeanDistance",
);
Expand Down Expand Up @@ -595,8 +595,7 @@ class ClusterFeedbackService {
final List<double> trueDistances = [];
for (final biggestEmbedding in biggestSampledEmbeddings) {
for (final currentEmbedding in currentSampledEmbeddings) {
distances
.add(cosineDistanceSIMD(biggestEmbedding, currentEmbedding));
distances.add(1 - biggestEmbedding.dot(currentEmbedding));
trueDistances.add(
biggestEmbedding.distanceTo(
currentEmbedding,
Expand Down Expand Up @@ -789,7 +788,7 @@ class ClusterFeedbackService {
final List<double> distances = [];
for (final otherEmbedding in sampledOtherEmbeddings) {
for (final embedding in sampledEmbeddings) {
distances.add(cosineDistanceSIMD(embedding, otherEmbedding));
distances.add(1 - embedding.dot(otherEmbedding));
}
}
distances.sort();
Expand Down Expand Up @@ -1086,7 +1085,7 @@ class ClusterFeedbackService {
final fileIdToDistanceMap = {};
for (final entry in faceIdToVectorMap.entries) {
fileIdToDistanceMap[getFileIdFromFaceId(entry.key)] =
cosineDistanceSIMD(personAvg, entry.value);
1 - personAvg.dot(entry.value);
}
w?.log('calculated distances for cluster $clusterID');
suggestion.filesInCluster.sort((b, a) {
Expand Down Expand Up @@ -1141,7 +1140,7 @@ List<(int, double)> _calcSuggestionsMean(Map<String, dynamic> args) {
continue;
}
final Vector avg = clusterAvg[personCluster]!;
final distance = cosineDistanceSIMD(avg, otherAvg);
final distance = 1 - avg.dot(otherAvg);
comparisons++;
if (distance < maxClusterDistance) {
if (minDistance == null || distance < minDistance) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class MachineLearningController {

static const kMaximumTemperature = 42; // 42 degree celsius
static const kMinimumBatteryLevel = 20; // 20%
static const kDefaultInteractionTimeout = Duration(seconds: 10);
static const kDefaultInteractionTimeout = Duration(seconds: 15);
static const kUnhealthyStates = ["over_heat", "over_voltage", "dead"];

bool _isDeviceHealthy = true;
Expand Down
6 changes: 5 additions & 1 deletion mobile/lib/ui/viewer/file_details/face_widget.dart
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ class _FaceWidgetState extends State<FaceWidget> {
}
}

Future<Map<String, Uint8List>?> getFaceCrop() async {
Future<Map<String, Uint8List>?> getFaceCrop({int fetchAttempt = 1}) async {
try {
final Uint8List? cachedFace = faceCropCache.get(widget.face.faceID);
if (cachedFace != null) {
Expand Down Expand Up @@ -326,6 +326,10 @@ class _FaceWidgetState extends State<FaceWidget> {
error: e,
stackTrace: s,
);
resetPool(fullFile: true);
if (fetchAttempt <= retryLimit) {
return getFaceCrop(fetchAttempt: fetchAttempt + 1);
}
return null;
}
}
Expand Down
8 changes: 7 additions & 1 deletion mobile/lib/ui/viewer/file_details/faces_item_widget.dart
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,9 @@ class _FacesItemWidgetState extends State<FacesItemWidget> {
}

Future<Map<String, Uint8List>?> getRelevantFaceCrops(
Iterable<Face> faces,
Iterable<Face> faces, {
int fetchAttempt = 1,
}
) async {
try {
final faceIdToCrop = <String, Uint8List>{};
Expand Down Expand Up @@ -223,6 +225,10 @@ class _FacesItemWidgetState extends State<FacesItemWidget> {
error: e,
stackTrace: s,
);
resetPool(fullFile: true);
if(fetchAttempt <= retryLimit) {
return getRelevantFaceCrops(faces, fetchAttempt: fetchAttempt + 1);
}
return null;
}
}
Expand Down
Loading