Skip to content

Commit

Permalink
add test for external modules (#3699)
Browse files Browse the repository at this point in the history
Summary:

External modules should work fine as plug-ins to Faiss in the following cases:

* additional objects that can be passed in as callbacks to Faiss

* functions that use `faiss.swig_ptr` to pass in arrays.

The `swig_ptr` functionality does not always work well (also depending on the platform). Therefore this diff adds a small external swig file to test that everything works smoothly on that end.

Differential Revision: D60379753
  • Loading branch information
mdouze authored and facebook-github-bot committed Oct 1, 2024
1 parent 0df5d24 commit f1c75ae
Show file tree
Hide file tree
Showing 2 changed files with 184 additions and 0 deletions.
122 changes: 122 additions & 0 deletions faiss/python/faiss_example_external_module.swig
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@

%module faiss_example_external_module;


// Put C++ includes here
%{

#include <faiss/impl/FaissException.h>
#include <faiss/impl/IDSelector.h>

%}

#pragma SWIG nowarn=322


typedef signed char int8_t ;

// to get uint32_t and friends
%include <stdint.i>


// This means: assume what's declared in these .h files is provided
// by the Faiss module.
%import(module="faiss") "faiss/MetricType.h"
%import(module="faiss") "faiss/impl/IDSelector.h"

// functions to be parsed here

// This is important to release GIL and do Faiss exception handing
%exception {
Py_BEGIN_ALLOW_THREADS
try {
$action
} catch(faiss::FaissException & e) {
PyEval_RestoreThread(_save);

if (PyErr_Occurred()) {
// some previous code already set the error type.
} else {
PyErr_SetString(PyExc_RuntimeError, e.what());
}
SWIG_fail;
} catch(std::bad_alloc & ba) {
PyEval_RestoreThread(_save);
PyErr_SetString(PyExc_MemoryError, "std::bad_alloc");
SWIG_fail;
}
Py_END_ALLOW_THREADS
}


// any class or function declared below will be made available
// in the module.
%inline %{

struct IDSelectorModulo : faiss::IDSelector {
int mod;

IDSelectorModulo(int mod): mod(mod) {}

bool is_member(faiss::idx_t id) const {
return id % mod == 0;
}

~IDSelectorModulo() override {}
};

faiss::idx_t sum_of_idx(size_t n, const faiss::idx_t *tab) {
faiss::idx_t sum = 0;
for(size_t i = 0; i < n; i++) {
sum += tab[i];
}
return sum;
}

float sum_of_float32(size_t n, const float *tab) {
float sum = 0;
for(size_t i = 0; i < n; i++) {
sum += tab[i];
}
return sum;
}

double sum_of_float64(size_t n, const double *tab) {
double sum = 0;
for(size_t i = 0; i < n; i++) {
sum += tab[i];
}
return sum;
}

%}

/**********************************************
* To test if passing a swig_ptr on all array types works
**********************************************/

%define SUM_OF_TYPE(ty)

%inline %{

ty##_t sum_of_##ty (size_t n, const ty##_t * tab) {
ty##_t sum = 0;
for(size_t i = 0; i < n; i++) {
sum += tab[i];
}
return sum;
}

%}

%enddef

SUM_OF_TYPE(uint8);
SUM_OF_TYPE(uint16);
SUM_OF_TYPE(uint32);
SUM_OF_TYPE(uint64);

SUM_OF_TYPE(int8);
SUM_OF_TYPE(int16);
SUM_OF_TYPE(int32);
SUM_OF_TYPE(int64);
62 changes: 62 additions & 0 deletions tests/test_external_module.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@

import unittest

import numpy as np

import faiss
import faiss_example_external_module as external_module


class TestCustomIDSelector(unittest.TestCase):
""" test if we can construct a custom IDSelector """

def test_IDSelector(self):
ids = external_module.IDSelectorModulo(3)
self.assertFalse(ids.is_member(1))
self.assertTrue(ids.is_member(3))


class TestArrayConversions(unittest.TestCase):

def test_idx_array(self):
tab = np.arange(10).astype('int64')
new_sum = external_module.sum_of_idx(len(tab), faiss.swig_ptr(tab))
self.assertEqual(new_sum, tab.sum())

def do_array_test(self, ty):
tab = np.arange(10).astype(ty)
func = getattr(external_module, 'sum_of_' + ty)
print("perceived type", faiss.swig_ptr(tab))
new_sum = func(len(tab), faiss.swig_ptr(tab))
self.assertEqual(new_sum, tab.sum())

def test_sum_uint8(self):
self.do_array_test('uint8')

def test_sum_uint16(self):
self.do_array_test('uint16')

def test_sum_uint32(self):
self.do_array_test('uint32')

def test_sum_uint64(self):
self.do_array_test('uint64')

# this conversion does not work
def test_sum_int8(self):
self.do_array_test('int8')

def test_sum_int16(self):
self.do_array_test('int16')

def test_sum_int32(self):
self.do_array_test('int32')

def test_sum_int64(self):
self.do_array_test('int64')

def test_sum_float32(self):
self.do_array_test('float32')

def test_sum_float64(self):
self.do_array_test('float64')

0 comments on commit f1c75ae

Please sign in to comment.