-
Notifications
You must be signed in to change notification settings - Fork 3.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add test for external modules (#3699)
Summary: External modules should work fine as plug-ins to Faiss in the following cases: * additional objects that can be passed in as callbacks to Faiss * functions that use `faiss.swig_ptr` to pass in arrays. The `swig_ptr` functionality does not always work well (also depending on the platform). Therefore this diff adds a small external swig file to test that everything works smoothly on that end. Differential Revision: D60379753
- Loading branch information
1 parent
0df5d24
commit f1c75ae
Showing
2 changed files
with
184 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,122 @@ | ||
|
||
%module faiss_example_external_module; | ||
|
||
|
||
// Put C++ includes here | ||
%{ | ||
|
||
#include <faiss/impl/FaissException.h> | ||
#include <faiss/impl/IDSelector.h> | ||
|
||
%} | ||
|
||
#pragma SWIG nowarn=322 | ||
|
||
|
||
typedef signed char int8_t ; | ||
|
||
// to get uint32_t and friends | ||
%include <stdint.i> | ||
|
||
|
||
// This means: assume what's declared in these .h files is provided | ||
// by the Faiss module. | ||
%import(module="faiss") "faiss/MetricType.h" | ||
%import(module="faiss") "faiss/impl/IDSelector.h" | ||
|
||
// functions to be parsed here | ||
|
||
// This is important to release GIL and do Faiss exception handing | ||
%exception { | ||
Py_BEGIN_ALLOW_THREADS | ||
try { | ||
$action | ||
} catch(faiss::FaissException & e) { | ||
PyEval_RestoreThread(_save); | ||
|
||
if (PyErr_Occurred()) { | ||
// some previous code already set the error type. | ||
} else { | ||
PyErr_SetString(PyExc_RuntimeError, e.what()); | ||
} | ||
SWIG_fail; | ||
} catch(std::bad_alloc & ba) { | ||
PyEval_RestoreThread(_save); | ||
PyErr_SetString(PyExc_MemoryError, "std::bad_alloc"); | ||
SWIG_fail; | ||
} | ||
Py_END_ALLOW_THREADS | ||
} | ||
|
||
|
||
// any class or function declared below will be made available | ||
// in the module. | ||
%inline %{ | ||
|
||
struct IDSelectorModulo : faiss::IDSelector { | ||
int mod; | ||
|
||
IDSelectorModulo(int mod): mod(mod) {} | ||
|
||
bool is_member(faiss::idx_t id) const { | ||
return id % mod == 0; | ||
} | ||
|
||
~IDSelectorModulo() override {} | ||
}; | ||
|
||
faiss::idx_t sum_of_idx(size_t n, const faiss::idx_t *tab) { | ||
faiss::idx_t sum = 0; | ||
for(size_t i = 0; i < n; i++) { | ||
sum += tab[i]; | ||
} | ||
return sum; | ||
} | ||
|
||
float sum_of_float32(size_t n, const float *tab) { | ||
float sum = 0; | ||
for(size_t i = 0; i < n; i++) { | ||
sum += tab[i]; | ||
} | ||
return sum; | ||
} | ||
|
||
double sum_of_float64(size_t n, const double *tab) { | ||
double sum = 0; | ||
for(size_t i = 0; i < n; i++) { | ||
sum += tab[i]; | ||
} | ||
return sum; | ||
} | ||
|
||
%} | ||
|
||
/********************************************** | ||
* To test if passing a swig_ptr on all array types works | ||
**********************************************/ | ||
|
||
%define SUM_OF_TYPE(ty) | ||
|
||
%inline %{ | ||
|
||
ty##_t sum_of_##ty (size_t n, const ty##_t * tab) { | ||
ty##_t sum = 0; | ||
for(size_t i = 0; i < n; i++) { | ||
sum += tab[i]; | ||
} | ||
return sum; | ||
} | ||
|
||
%} | ||
|
||
%enddef | ||
|
||
SUM_OF_TYPE(uint8); | ||
SUM_OF_TYPE(uint16); | ||
SUM_OF_TYPE(uint32); | ||
SUM_OF_TYPE(uint64); | ||
|
||
SUM_OF_TYPE(int8); | ||
SUM_OF_TYPE(int16); | ||
SUM_OF_TYPE(int32); | ||
SUM_OF_TYPE(int64); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
|
||
import unittest | ||
|
||
import numpy as np | ||
|
||
import faiss | ||
import faiss_example_external_module as external_module | ||
|
||
|
||
class TestCustomIDSelector(unittest.TestCase): | ||
""" test if we can construct a custom IDSelector """ | ||
|
||
def test_IDSelector(self): | ||
ids = external_module.IDSelectorModulo(3) | ||
self.assertFalse(ids.is_member(1)) | ||
self.assertTrue(ids.is_member(3)) | ||
|
||
|
||
class TestArrayConversions(unittest.TestCase): | ||
|
||
def test_idx_array(self): | ||
tab = np.arange(10).astype('int64') | ||
new_sum = external_module.sum_of_idx(len(tab), faiss.swig_ptr(tab)) | ||
self.assertEqual(new_sum, tab.sum()) | ||
|
||
def do_array_test(self, ty): | ||
tab = np.arange(10).astype(ty) | ||
func = getattr(external_module, 'sum_of_' + ty) | ||
print("perceived type", faiss.swig_ptr(tab)) | ||
new_sum = func(len(tab), faiss.swig_ptr(tab)) | ||
self.assertEqual(new_sum, tab.sum()) | ||
|
||
def test_sum_uint8(self): | ||
self.do_array_test('uint8') | ||
|
||
def test_sum_uint16(self): | ||
self.do_array_test('uint16') | ||
|
||
def test_sum_uint32(self): | ||
self.do_array_test('uint32') | ||
|
||
def test_sum_uint64(self): | ||
self.do_array_test('uint64') | ||
|
||
# this conversion does not work | ||
def test_sum_int8(self): | ||
self.do_array_test('int8') | ||
|
||
def test_sum_int16(self): | ||
self.do_array_test('int16') | ||
|
||
def test_sum_int32(self): | ||
self.do_array_test('int32') | ||
|
||
def test_sum_int64(self): | ||
self.do_array_test('int64') | ||
|
||
def test_sum_float32(self): | ||
self.do_array_test('float32') | ||
|
||
def test_sum_float64(self): | ||
self.do_array_test('float64') |