Skip to content

Commit

Permalink
Add ClassifyThreshold
Browse files Browse the repository at this point in the history
Fixes Kagami#9
  • Loading branch information
Kagami committed Jan 13, 2019
1 parent 08b2496 commit 4d258ca
Show file tree
Hide file tree
Showing 6 changed files with 59 additions and 25 deletions.
19 changes: 14 additions & 5 deletions classify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,28 +4,37 @@
int classify(
const std::vector<descriptor>& samples,
const std::unordered_map<int, int>& cats,
const descriptor& test_sample
const descriptor& test_sample,
float tolerance
) {
std::vector<std::pair<int, double>> distances;
if (samples.size() == 0)
return -1;

std::vector<std::pair<int, float>> distances;
distances.reserve(samples.size());
auto dist_func = dlib::squared_euclidean_distance();
int idx = 0;
for (const auto& sample : samples) {
double dist = dist_func(sample, test_sample);
float dist = dist_func(sample, test_sample);
if (dist < tolerance)
continue;
distances.push_back({idx, dist});
idx++;
}

if (distances.size() == 0)
return -1;

std::sort(
distances.begin(), distances.end(),
[](const auto a, const auto b) { return a.second < b.second; }
);

int len = std::min((int)distances.size(), 10);
std::unordered_map<int, std::pair<int, double>> hits_by_cat;
std::unordered_map<int, std::pair<int, float>> hits_by_cat;
for (int i = 0; i < len; i++) {
int idx = distances[i].first;
double dist = distances[i].second;
float dist = distances[i].second;
auto cat = cats.find(idx);
if (cat == cats.end())
continue;
Expand Down
3 changes: 2 additions & 1 deletion classify.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,6 @@ typedef dlib::matrix<float,0,1> descriptor;
int classify(
const std::vector<descriptor>& samples,
const std::unordered_map<int, int>& cats,
const descriptor& test_sample
const descriptor& test_sample,
float tolerance
);
10 changes: 9 additions & 1 deletion face.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,15 @@ func (rec *Recognizer) SetSamples(samples []Descriptor, cats []int32) {
// returned if no match. Thread-safe.
func (rec *Recognizer) Classify(testSample Descriptor) int {
cTestSample := (*C.float)(unsafe.Pointer(&testSample))
return int(C.facerec_classify(rec.ptr, cTestSample))
return int(C.facerec_classify(rec.ptr, cTestSample, -1))
}

// Same as Classify but allows to specify how much distance between
// faces to consider it a match. Start with 0.6 if not sure.
func (rec *Recognizer) ClassifyThreshold(testSample Descriptor, tolerance float32) int {
cTestSample := (*C.float)(unsafe.Pointer(&testSample))
cTolerance := C.float(tolerance)
return int(C.facerec_classify(rec.ptr, cTestSample, cTolerance))
}

// Close frees resources taken by the Recognizer. Safe to call multiple
Expand Down
40 changes: 29 additions & 11 deletions face_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,16 +115,17 @@ func getTrainData(idata *IdolData) (tdata *TrainData) {
return
}

func recognizeAndClassify(fpath string) (catID *int, err error) {
func recognizeAndClassify(fpath string, tolerance float32) (id int, err error) {
id = -1
f, err := rec.RecognizeSingleFile(fpath)
if err != nil || f == nil {
return
}
id := rec.Classify(f.Descriptor)
if id < 0 {
return
if tolerance < 0 {
id = rec.Classify(f.Descriptor)
} else {
id = rec.ClassifyThreshold(f.Descriptor, tolerance)
}
catID = &id
return
}

Expand Down Expand Up @@ -170,8 +171,8 @@ func TestNumFaces(t *testing.T) {
func TestEmptyClassify(t *testing.T) {
var sample face.Descriptor
id := rec.Classify(sample)
if id != -1 {
t.Fatalf("expected -1 but got %d", id)
if id >= 0 {
t.Fatalf("Shouldn't recognize but got %d category", id)
}
}

Expand All @@ -189,15 +190,15 @@ func TestIdols(t *testing.T) {
expectedIname := names[0]
expectedBname := names[1]

catID, err := recognizeAndClassify(getTPath(fname))
catID, err := recognizeAndClassify(getTPath(fname), -1)
if err != nil {
t.Fatal(err)
t.Fatalf("Can't recognize: %v", err)
}
if catID == nil {
if catID < 0 {
t.Errorf("%s: expected “%s” but not recognized", fname, expected)
return
}
idolID := tdata.labels[*catID]
idolID := tdata.labels[catID]
idol := idata.byID[idolID]
actualIname := idol.Name
actualBname := idol.BandName
Expand All @@ -210,6 +211,23 @@ func TestIdols(t *testing.T) {
}
}

func TestClassifyThreshold(t *testing.T) {
id, err := recognizeAndClassify(getTPath("nana.jpg"), 0.8)
if err != nil {
t.Fatalf("Can't recognize: %v", err)
}
if id >= 0 {
t.Fatalf("Shouldn't recognize but got %d category", id)
}
id, err = recognizeAndClassify(getTPath("nana.jpg"), 0.1)
if err != nil {
t.Fatalf("Can't recognize: %v", err)
}
if id < 0 {
t.Fatalf("Should have recognized but got %d category", id)
}
}

func TestClose(t *testing.T) {
rec.Close()
}
10 changes: 4 additions & 6 deletions facerec.cc
Original file line number Diff line number Diff line change
Expand Up @@ -93,11 +93,9 @@ class FaceRec {
cats_ = std::move(cats);
}

int Classify(const descriptor& test_sample) {
int Classify(const descriptor& test_sample, float tolerance) {
std::shared_lock<std::shared_mutex> lock(samples_mutex_);
if (samples_.size() == 0)
return -1;
return classify(samples_, cats_, test_sample);
return classify(samples_, cats_, test_sample, tolerance);
}
private:
std::mutex detector_mutex_;
Expand Down Expand Up @@ -187,10 +185,10 @@ void facerec_set_samples(
cls->SetSamples(std::move(samples), std::move(cats));
}

int facerec_classify(facerec* rec, const float* c_test_sample) {
int facerec_classify(facerec* rec, const float* c_test_sample, float tolerance) {
FaceRec* cls = (FaceRec*)(rec->cls);
descriptor test_sample = mat(c_test_sample, DESCR_LEN, 1);
return cls->Classify(test_sample);
return cls->Classify(test_sample, tolerance);
}

void facerec_free(facerec* rec) {
Expand Down
2 changes: 1 addition & 1 deletion facerec.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ typedef struct faceret {
facerec* facerec_init(const char* model_dir);
faceret* facerec_recognize(facerec* rec, const uint8_t* img_data, int len, int max_faces);
void facerec_set_samples(facerec* rec, const float* descriptors, const int32_t* cats, int len);
int facerec_classify(facerec* rec, const float* descriptor);
int facerec_classify(facerec* rec, const float* descriptor, float tolerance);
void facerec_free(facerec* rec);

#ifdef __cplusplus
Expand Down

0 comments on commit 4d258ca

Please sign in to comment.