Skip to content

Commit

Permalink
Merge pull request facebookresearch#2 from justinaustin/remove_ids
Browse files Browse the repository at this point in the history
Adding remove_ids method to Index trait
  • Loading branch information
Enet4 authored Dec 9, 2018
2 parents a31410e + 8f64449 commit 137bc35
Show file tree
Hide file tree
Showing 6 changed files with 100 additions and 1 deletion.
31 changes: 31 additions & 0 deletions src/index/id_map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ use index::{
AssignSearchResult, ConcurrentIndex, CpuIndex, FromInnerPtr, Idx, Index, NativeIndex, RangeSearchResult,
SearchResult,
};
use selector::IdSelector;

use std::marker::PhantomData;
use std::mem;
Expand Down Expand Up @@ -251,6 +252,18 @@ impl<I> Index for IdMap<I> {
Ok(())
}
}

fn remove_ids(&mut self, sel: &IdSelector) -> Result<(i64)> {
unsafe {
let mut n_removed = 0;
faiss_try!(faiss_Index_remove_ids(
self.inner_ptr(),
sel.inner_ptr(),
&mut n_removed
));
Ok(n_removed)
}
}
}

impl<I> ConcurrentIndex for IdMap<I>
Expand Down Expand Up @@ -308,6 +321,7 @@ where
mod tests {
use super::IdMap;
use index::{index_factory, Index};
use selector::IdSelector;
use MetricType;

#[test]
Expand Down Expand Up @@ -340,4 +354,21 @@ mod tests {
assert_eq!(result.labels, vec![9, 6, 3, 12, 15, 12, 15, 3, 6, 9]);
assert!(result.distances.iter().all(|x| *x > 0.));
}

#[test]
fn index_remove_ids() {
let mut index = index_factory(4, "Flat", MetricType::L2).unwrap();
let mut id_index = IdMap::new(index).unwrap();
let some_data = &[2.3_f32, 0.0, -1., 1., 1., 1., 1., 4.5, 2.3, 7.6, 1., 2.2];

let ids = &[4, 8, 12];

id_index.add_with_ids(some_data, ids).unwrap();
assert_eq!(id_index.ntotal(), 3);

let id_sel = IdSelector::batch(&[4, 12]).ok().unwrap();

id_index.remove_ids(&id_sel).unwrap();
assert_eq!(id_index.ntotal(), 1);
}
}
1 change: 1 addition & 0 deletions src/index/lsh.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use super::{AssignSearchResult, ConcurrentIndex, CpuIndex, FromInnerPtr, Idx, In
NativeIndex, RangeSearchResult, SearchResult};
use error::{Error, Result};
use faiss_sys::*;
use selector::IdSelector;
use std::mem;
use std::ptr;

Expand Down
6 changes: 5 additions & 1 deletion src/index/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

use error::{Error, Result};
use metric::MetricType;
use selector::IdSelector;
use std::ffi::CString;
use std::os::raw::c_uint;
use std::ptr;
Expand All @@ -36,7 +37,7 @@ pub type Idx = idx_t;
/// Although all methods appear to be available for all index implementations,
/// some methods may not be supported. For instance, a [`FlatIndex`] stores
/// vectors sequentially, and so does not support `add_with_ids` nor
/// `remove_with_ids`. Users are advised to read the Faiss wiki pages in order
/// `remove_ids`. Users are advised to read the Faiss wiki pages in order
/// to understand which index algorithms support which operations.
///
/// [`FlatIndex`]: flat/struct.FlatIndex.html
Expand Down Expand Up @@ -79,6 +80,9 @@ pub trait Index {

/// Clear the entire index.
fn reset(&mut self) -> Result<()>;

/// Remove data vectors represented by IDs.
fn remove_ids(&mut self, sel: &IdSelector) -> Result<i64>;
}

/// Sub-trait for native implementations of a Faiss index.
Expand Down
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ pub mod cluster;
pub mod error;
pub mod index;
pub mod metric;
pub mod selector;

#[cfg(feature = "gpu")]
pub mod gpu;
Expand Down
12 changes: 12 additions & 0 deletions src/macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,18 @@ macro_rules! impl_native_index {
Ok(())
}
}

fn remove_ids(&mut self, sel: &IdSelector) -> Result<i64> {
unsafe {
let mut n_removed = 0;
faiss_try!(faiss_Index_remove_ids(
self.inner_ptr(),
sel.inner_ptr(),
&mut n_removed
));
Ok(n_removed)
}
}
}
};
}
Expand Down
50 changes: 50 additions & 0 deletions src/selector.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
//! Abstract Faiss ID selector
use error::Result;
use faiss_sys::*;
use index::Idx;
use std::os::raw::c_long;
use std::ptr;

/// Abstraction over IDSelectorRange and IDSelectorBatch
#[derive(Debug)]
pub struct IdSelector {
inner: *mut FaissIDSelector,
}

impl IdSelector {
/// Create new range selector
pub fn range(min: Idx, max: Idx) -> Result<Self> {
let mut p_sel = ptr::null_mut();
unsafe {
faiss_try!(faiss_IDSelectorRange_new(&mut p_sel, min, max));
};
Ok(IdSelector { inner: p_sel as *mut _})
}

/// Create new batch selector
pub fn batch(indices: &[Idx]) -> Result<Self> {
let n = indices.len() as c_long;
let mut p_sel = ptr::null_mut();
unsafe {
faiss_try!(faiss_IDSelectorBatch_new(&mut p_sel, n, indices.as_ptr()));
};
Ok(IdSelector { inner: p_sel as *mut _})
}

/// Return the inner pointer
pub fn inner_ptr(&self) -> *mut FaissIDSelector {
self.inner
}

}

impl Drop for IdSelector {
fn drop(&mut self) {
unsafe {
faiss_IDSelector_free(self.inner);
}
}
}

unsafe impl Send for IdSelector {}
unsafe impl Sync for IdSelector {}

0 comments on commit 137bc35

Please sign in to comment.