Skip to content

Commit

Permalink
Merge pull request #43 from ava57r/index-upcast
Browse files Browse the repository at this point in the history
Add UpcastIndex trait
  • Loading branch information
Enet4 authored Jul 23, 2021
2 parents 4aba00d + 6e17af3 commit 97753a9
Show file tree
Hide file tree
Showing 6 changed files with 86 additions and 6 deletions.
12 changes: 11 additions & 1 deletion src/index/flat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,11 +151,21 @@ impl_concurrent_index!(FlatIndexImpl);
#[cfg(test)]
mod tests {
use super::FlatIndexImpl;
use crate::index::{index_factory, ConcurrentIndex, FromInnerPtr, Idx, Index, NativeIndex};
use crate::index::{
index_factory, ConcurrentIndex, FromInnerPtr, Idx, Index, NativeIndex, UpcastIndex,
};
use crate::metric::MetricType;

const D: u32 = 8;

#[test]
fn flat_index_from_upcast() {
let index = FlatIndexImpl::new_l2(D).unwrap();

let index_impl = index.upcast();
assert_eq!(index_impl.d(), D);
}

#[test]
fn flat_index_from_cast() {
let mut index = index_factory(8, "Flat", MetricType::L2).unwrap();
Expand Down
13 changes: 12 additions & 1 deletion src/index/ivf_flat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ mod tests {

use super::IVFFlatIndexImpl;
use crate::index::flat::FlatIndexImpl;
use crate::index::{index_factory, ConcurrentIndex, Idx, Index};
use crate::index::{index_factory, ConcurrentIndex, Idx, Index, UpcastIndex};
use crate::MetricType;

const D: u32 = 8;
Expand Down Expand Up @@ -286,4 +286,15 @@ mod tests {
assert_eq!(index.is_trained(), true);
assert_eq!(index.ntotal(), 5);
}

#[test]
fn index_upcast() {
let q = FlatIndexImpl::new_l2(D).unwrap();
let index = IVFFlatIndexImpl::new_l2(q, D, 1).unwrap();
assert_eq!(index.d(), D);
assert_eq!(index.ntotal(), 0);

let index_impl = index.upcast();
assert_eq!(index_impl.d(), D);
}
}
32 changes: 31 additions & 1 deletion src/index/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use crate::selector::IdSelector;
use std::ffi::CString;
use std::fmt::{self, Display, Formatter, Write};
use std::os::raw::c_uint;
use std::ptr;
use std::{mem, ptr};

use faiss_sys::*;

Expand Down Expand Up @@ -467,6 +467,36 @@ impl TryFromInnerPtr for IndexImpl {
}
}

/// Index upcast trait.
///
/// If you need to store several different types of indexes in one collection,
/// you can cast all indexes to the common type `IndexImpl`.
/// # Examples
///
/// ```
/// # use faiss::{index::{IndexImpl, UpcastIndex}, FlatIndex, index_factory, MetricType};
/// let f1 = FlatIndex::new_l2(128).unwrap();
/// let f2 = index_factory(128, "Flat", MetricType::L2).unwrap();
/// let v: Vec<IndexImpl> = vec![
/// f1.upcast(),
/// f2,
/// ];
/// ```
///
pub trait UpcastIndex: NativeIndex {
/// Convert an index to the base `IndexImpl` type
fn upcast(self) -> IndexImpl;
}

impl<NI: NativeIndex> UpcastIndex for NI {
fn upcast(self) -> IndexImpl {
let inner_ptr = self.inner_ptr();
mem::forget(self);

unsafe { IndexImpl::from_inner_ptr(inner_ptr) }
}
}

impl_native_index!(IndexImpl);

impl_native_index_clone!(IndexImpl);
Expand Down
7 changes: 6 additions & 1 deletion src/index/pretransform.rs
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ mod tests {
const D: u32 = 8;

#[test]
fn pre_transform_index_from_cast() {
fn pre_transform_index_from_cast_upcast() {
let mut index = index_factory(D, "PCA4,Flat", MetricType::L2).unwrap();

let some_data = &[
Expand All @@ -315,6 +315,11 @@ mod tests {
assert_eq!(index.is_trained(), true);
assert_eq!(index.ntotal(), 5);
assert_eq!(index.d(), 8);

let index_impl = index.upcast();
assert_eq!(index_impl.is_trained(), true);
assert_eq!(index_impl.ntotal(), 5);
assert_eq!(index_impl.d(), 8);
}

#[test]
Expand Down
14 changes: 13 additions & 1 deletion src/index/refine_flat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ where
#[cfg(test)]
mod tests {
use super::RefineFlatIndexImpl;
use crate::index::{flat::FlatIndexImpl, ConcurrentIndex, Idx, Index};
use crate::index::{flat::FlatIndexImpl, ConcurrentIndex, Idx, Index, UpcastIndex};

const D: u32 = 8;

Expand Down Expand Up @@ -315,4 +315,16 @@ mod tests {
refine.reset().unwrap();
assert_eq!(refine.ntotal(), 0);
}

#[test]
fn refine_flat_index_upcast() {
let index = FlatIndexImpl::new_l2(D).unwrap();
assert_eq!(index.d(), D);
assert_eq!(index.ntotal(), 0);

let refine = RefineFlatIndexImpl::new(index).unwrap();

let index_impl = refine.upcast();
assert_eq!(index_impl.d(), D);
}
}
14 changes: 13 additions & 1 deletion src/index/scalar_quantizer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -491,7 +491,7 @@ impl IndexImpl {
#[cfg(test)]
mod tests {
use super::{IVFScalarQuantizerIndexImpl, QuantizerType, ScalarQuantizerIndexImpl};
use crate::index::{flat, index_factory, ConcurrentIndex, Idx, Index};
use crate::index::{flat, index_factory, ConcurrentIndex, Idx, Index, UpcastIndex};
use crate::metric::MetricType;

const D: u32 = 8;
Expand Down Expand Up @@ -679,4 +679,16 @@ mod tests {
assert_eq!(index.is_trained(), true);
assert_eq!(index.ntotal(), 5);
}

#[test]
fn ivf_sq_index_upcast() {
let quantizer = flat::FlatIndex::new_l2(D).unwrap();
let index =
IVFScalarQuantizerIndexImpl::new_l2(quantizer, D, QuantizerType::QT_fp16, 1).unwrap();
assert_eq!(index.d(), D);
assert_eq!(index.ntotal(), 0);

let index_impl = index.upcast();
assert_eq!(index_impl.d(), D);
}
}

0 comments on commit 97753a9

Please sign in to comment.