From 4d258ca248c2e7c45284836bf8e032b9defaae19 Mon Sep 17 00:00:00 2001 From: Kagami Hiiragi Date: Mon, 14 Jan 2019 00:11:13 +0300 Subject: [PATCH] Add ClassifyThreshold Fixes #9 --- classify.cc | 19 ++++++++++++++----- classify.h | 3 ++- face.go | 10 +++++++++- face_test.go | 40 +++++++++++++++++++++++++++++----------- facerec.cc | 10 ++++------ facerec.h | 2 +- 6 files changed, 59 insertions(+), 25 deletions(-) diff --git a/classify.cc b/classify.cc index 75fb9cf..3c4034a 100644 --- a/classify.cc +++ b/classify.cc @@ -4,28 +4,37 @@ int classify( const std::vector& samples, const std::unordered_map& cats, - const descriptor& test_sample + const descriptor& test_sample, + float tolerance ) { - std::vector> distances; + if (samples.size() == 0) + return -1; + + std::vector> distances; distances.reserve(samples.size()); auto dist_func = dlib::squared_euclidean_distance(); int idx = 0; for (const auto& sample : samples) { - double dist = dist_func(sample, test_sample); + float dist = dist_func(sample, test_sample); + if (dist < tolerance) + continue; distances.push_back({idx, dist}); idx++; } + if (distances.size() == 0) + return -1; + std::sort( distances.begin(), distances.end(), [](const auto a, const auto b) { return a.second < b.second; } ); int len = std::min((int)distances.size(), 10); - std::unordered_map> hits_by_cat; + std::unordered_map> hits_by_cat; for (int i = 0; i < len; i++) { int idx = distances[i].first; - double dist = distances[i].second; + float dist = distances[i].second; auto cat = cats.find(idx); if (cat == cats.end()) continue; diff --git a/classify.h b/classify.h index 668379a..e21fbc0 100644 --- a/classify.h +++ b/classify.h @@ -7,5 +7,6 @@ typedef dlib::matrix descriptor; int classify( const std::vector& samples, const std::unordered_map& cats, - const descriptor& test_sample + const descriptor& test_sample, + float tolerance ); diff --git a/face.go b/face.go index cf65ea8..1370392 100644 --- a/face.go +++ b/face.go @@ -168,7 +168,15 @@ func 
(rec *Recognizer) SetSamples(samples []Descriptor, cats []int32) { // returned if no match. Thread-safe. func (rec *Recognizer) Classify(testSample Descriptor) int { cTestSample := (*C.float)(unsafe.Pointer(&testSample)) - return int(C.facerec_classify(rec.ptr, cTestSample)) + return int(C.facerec_classify(rec.ptr, cTestSample, -1)) +} + +// ClassifyThreshold is the same as Classify but lets the caller specify how +// much distance between faces counts as a match. Start with 0.6 if unsure. +func (rec *Recognizer) ClassifyThreshold(testSample Descriptor, tolerance float32) int { + cTestSample := (*C.float)(unsafe.Pointer(&testSample)) + cTolerance := C.float(tolerance) + return int(C.facerec_classify(rec.ptr, cTestSample, cTolerance)) } // Close frees resources taken by the Recognizer. Safe to call multiple diff --git a/face_test.go b/face_test.go index 1f112d6..bb24020 100644 --- a/face_test.go +++ b/face_test.go @@ -115,16 +115,17 @@ func getTrainData(idata *IdolData) (tdata *TrainData) { return } -func recognizeAndClassify(fpath string) (catID *int, err error) { +func recognizeAndClassify(fpath string, tolerance float32) (id int, err error) { + id = -1 f, err := rec.RecognizeSingleFile(fpath) if err != nil || f == nil { return } - id := rec.Classify(f.Descriptor) - if id < 0 { - return + if tolerance < 0 { + id = rec.Classify(f.Descriptor) + } else { + id = rec.ClassifyThreshold(f.Descriptor, tolerance) } - catID = &id return } @@ -170,8 +171,8 @@ func TestNumFaces(t *testing.T) { func TestEmptyClassify(t *testing.T) { var sample face.Descriptor id := rec.Classify(sample) - if id != -1 { - t.Fatalf("expected -1 but got %d", id) + if id >= 0 { + t.Fatalf("Shouldn't recognize but got %d category", id) } } @@ -189,15 +190,15 @@ func TestIdols(t *testing.T) { expectedIname := names[0] expectedBname := names[1] - catID, err := recognizeAndClassify(getTPath(fname)) + catID, err := recognizeAndClassify(getTPath(fname), -1) if err != nil { - t.Fatal(err) + t.Fatalf("Can't recognize: %v", err) }
- if catID == nil { + if catID < 0 { t.Errorf("%s: expected “%s” but not recognized", fname, expected) return } - idolID := tdata.labels[*catID] + idolID := tdata.labels[catID] idol := idata.byID[idolID] actualIname := idol.Name actualBname := idol.BandName @@ -210,6 +211,23 @@ func TestIdols(t *testing.T) { } } +func TestClassifyThreshold(t *testing.T) { + id, err := recognizeAndClassify(getTPath("nana.jpg"), 0.8) + if err != nil { + t.Fatalf("Can't recognize: %v", err) + } + if id >= 0 { + t.Fatalf("Shouldn't recognize but got %d category", id) + } + id, err = recognizeAndClassify(getTPath("nana.jpg"), 0.1) + if err != nil { + t.Fatalf("Can't recognize: %v", err) + } + if id < 0 { + t.Fatalf("Should have recognized but got %d category", id) + } +} + func TestClose(t *testing.T) { rec.Close() } diff --git a/facerec.cc b/facerec.cc index d9e625b..144070b 100644 --- a/facerec.cc +++ b/facerec.cc @@ -93,11 +93,9 @@ class FaceRec { cats_ = std::move(cats); } - int Classify(const descriptor& test_sample) { + int Classify(const descriptor& test_sample, float tolerance) { std::shared_lock lock(samples_mutex_); - if (samples_.size() == 0) - return -1; - return classify(samples_, cats_, test_sample); + return classify(samples_, cats_, test_sample, tolerance); } private: std::mutex detector_mutex_; @@ -187,10 +185,10 @@ void facerec_set_samples( cls->SetSamples(std::move(samples), std::move(cats)); } -int facerec_classify(facerec* rec, const float* c_test_sample) { +int facerec_classify(facerec* rec, const float* c_test_sample, float tolerance) { FaceRec* cls = (FaceRec*)(rec->cls); descriptor test_sample = mat(c_test_sample, DESCR_LEN, 1); - return cls->Classify(test_sample); + return cls->Classify(test_sample, tolerance); } void facerec_free(facerec* rec) { diff --git a/facerec.h b/facerec.h index b5b7c9b..0674f7c 100644 --- a/facerec.h +++ b/facerec.h @@ -27,7 +27,7 @@ typedef struct faceret { facerec* facerec_init(const char* model_dir); faceret*
facerec_recognize(facerec* rec, const uint8_t* img_data, int len, int max_faces); void facerec_set_samples(facerec* rec, const float* descriptors, const int32_t* cats, int len); -int facerec_classify(facerec* rec, const float* descriptor); +int facerec_classify(facerec* rec, const float* descriptor, float tolerance); void facerec_free(facerec* rec); #ifdef __cplusplus