From c450d007924ebf38b941f3a1f568be895347390e Mon Sep 17 00:00:00 2001 From: Alexander Andreev Date: Thu, 22 Jul 2021 11:54:10 +0300 Subject: [PATCH 1/3] Add UpcastIndex trait --- src/index/flat.rs | 12 +++++++++++- src/index/ivf_flat.rs | 13 ++++++++++++- src/index/mod.rs | 15 ++++++++++++++- src/index/pretransform.rs | 7 ++++++- src/index/refine_flat.rs | 14 +++++++++++++- src/index/scalar_quantizer.rs | 14 +++++++++++++- 6 files changed, 69 insertions(+), 6 deletions(-) diff --git a/src/index/flat.rs b/src/index/flat.rs index 7734868..b7bee5f 100644 --- a/src/index/flat.rs +++ b/src/index/flat.rs @@ -151,11 +151,21 @@ impl_concurrent_index!(FlatIndexImpl); #[cfg(test)] mod tests { use super::FlatIndexImpl; - use crate::index::{index_factory, ConcurrentIndex, FromInnerPtr, Idx, Index, NativeIndex}; + use crate::index::{ + index_factory, ConcurrentIndex, FromInnerPtr, Idx, Index, NativeIndex, UpcastIndex, + }; use crate::metric::MetricType; const D: u32 = 8; + #[test] + fn flat_index_from_upcast() { + let index = FlatIndexImpl::new_l2(D).unwrap(); + + let index_impl = index.upcast(); + assert_eq!(index_impl.d(), D); + } + #[test] fn flat_index_from_cast() { let mut index = index_factory(8, "Flat", MetricType::L2).unwrap(); diff --git a/src/index/ivf_flat.rs b/src/index/ivf_flat.rs index 18b49e3..4291005 100644 --- a/src/index/ivf_flat.rs +++ b/src/index/ivf_flat.rs @@ -166,7 +166,7 @@ mod tests { use super::IVFFlatIndexImpl; use crate::index::flat::FlatIndexImpl; - use crate::index::{index_factory, ConcurrentIndex, Idx, Index}; + use crate::index::{index_factory, ConcurrentIndex, Idx, Index, UpcastIndex}; use crate::MetricType; const D: u32 = 8; @@ -286,4 +286,15 @@ mod tests { assert_eq!(index.is_trained(), true); assert_eq!(index.ntotal(), 5); } + + #[test] + fn index_upcast() { + let q = FlatIndexImpl::new_l2(D).unwrap(); + let index = IVFFlatIndexImpl::new_l2(q, D, 1).unwrap(); + assert_eq!(index.d(), D); + assert_eq!(index.ntotal(), 0); + + let index_impl = index.upcast(); + assert_eq!(index_impl.d(), D); + } } diff --git a/src/index/mod.rs b/src/index/mod.rs index 4db479c..07ae7bd 100644 --- a/src/index/mod.rs +++ b/src/index/mod.rs @@ -17,7 +17,7 @@ use crate::selector::IdSelector; use std::ffi::CString; use std::fmt::{self, Display, Formatter, Write}; use std::os::raw::c_uint; -use std::ptr; +use std::{mem, ptr}; use faiss_sys::*; @@ -467,6 +467,19 @@ impl TryFromInnerPtr for IndexImpl { } } +pub trait UpcastIndex: NativeIndex { + fn upcast(self) -> IndexImpl; +} + +impl UpcastIndex for NI { + fn upcast(self) -> IndexImpl { + let inner_ptr = self.inner_ptr(); + mem::forget(self); + + unsafe { IndexImpl::from_inner_ptr(inner_ptr) } + } +} + impl_native_index!(IndexImpl); impl_native_index_clone!(IndexImpl); diff --git a/src/index/pretransform.rs b/src/index/pretransform.rs index f1d9a03..c33ee29 100644 --- a/src/index/pretransform.rs +++ b/src/index/pretransform.rs @@ -297,7 +297,7 @@ mod tests { const D: u32 = 8; #[test] - fn pre_transform_index_from_cast() { + fn pre_transform_index_from_cast_upcast() { let mut index = index_factory(D, "PCA4,Flat", MetricType::L2).unwrap(); let some_data = &[ @@ -315,6 +315,11 @@ mod tests { assert_eq!(index.is_trained(), true); assert_eq!(index.ntotal(), 5); assert_eq!(index.d(), 8); + + let index_impl = index.upcast(); + assert_eq!(index_impl.is_trained(), true); + assert_eq!(index_impl.ntotal(), 5); + assert_eq!(index_impl.d(), 8); } #[test] diff --git a/src/index/refine_flat.rs b/src/index/refine_flat.rs index 78cc153..90d4716 100644 --- a/src/index/refine_flat.rs +++ b/src/index/refine_flat.rs @@ -270,7 +270,7 @@ where #[cfg(test)] mod tests { use super::RefineFlatIndexImpl; - use crate::index::{flat::FlatIndexImpl, ConcurrentIndex, Idx, Index}; + use crate::index::{flat::FlatIndexImpl, ConcurrentIndex, Idx, Index, UpcastIndex}; const D: u32 = 8; @@ -315,4 +315,16 @@ mod tests { refine.reset().unwrap(); assert_eq!(refine.ntotal(), 0); } + + #[test] + fn refine_flat_index_upcast() { + let index = FlatIndexImpl::new_l2(D).unwrap(); + assert_eq!(index.d(), D); + assert_eq!(index.ntotal(), 0); + + let refine = RefineFlatIndexImpl::new(index).unwrap(); + + let index_impl = refine.upcast(); + assert_eq!(index_impl.d(), D); + } } diff --git a/src/index/scalar_quantizer.rs b/src/index/scalar_quantizer.rs index dbcfd94..5fad99c 100644 --- a/src/index/scalar_quantizer.rs +++ b/src/index/scalar_quantizer.rs @@ -491,7 +491,7 @@ impl IndexImpl { #[cfg(test)] mod tests { use super::{IVFScalarQuantizerIndexImpl, QuantizerType, ScalarQuantizerIndexImpl}; - use crate::index::{flat, index_factory, ConcurrentIndex, Idx, Index}; + use crate::index::{flat, index_factory, ConcurrentIndex, Idx, Index, UpcastIndex}; use crate::metric::MetricType; const D: u32 = 8; @@ -679,4 +679,16 @@ mod tests { assert_eq!(index.is_trained(), true); assert_eq!(index.ntotal(), 5); } + + #[test] + fn ivf_sq_index_upcast() { + let quantizer = flat::FlatIndex::new_l2(D).unwrap(); + let index = + IVFScalarQuantizerIndexImpl::new_l2(quantizer, D, QuantizerType::QT_fp16, 1).unwrap(); + assert_eq!(index.d(), D); + assert_eq!(index.ntotal(), 0); + + let index_impl = index.upcast(); + assert_eq!(index_impl.d(), D); + } } From e240f33f5901173eac601492feb75163b61163ba Mon Sep 17 00:00:00 2001 From: Alexander Andreev Date: Fri, 23 Jul 2021 10:38:03 +0300 Subject: [PATCH 2/3] Add brief documentation for UpcastIndex --- src/index/mod.rs | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/src/index/mod.rs b/src/index/mod.rs index 07ae7bd..128e91d 100644 --- a/src/index/mod.rs +++ b/src/index/mod.rs @@ -467,7 +467,22 @@ impl TryFromInnerPtr for IndexImpl { } } +/// If you need to store several different types of indexes in one collection. +/// You can cast all indexes to one type of IndexImpl. +/// # Examples +/// +/// ``` +/// use faiss::{index::{IndexImpl, UpcastIndex}, FlatIndex, index_factory, MetricType}; +/// let f1 = FlatIndex::new_l2(128).unwrap(); +/// let f2 = index_factory(128, "Flat", MetricType::L2).unwrap(); +/// let v: Vec = vec![ +/// f1.upcast(), +/// f2, +/// ]; +/// ``` +/// pub trait UpcastIndex: NativeIndex { + /// Converting an index to the base type of `IndexImpl` fn upcast(self) -> IndexImpl; } From 6e17af3bdfbd7c61ab3369e724540560feec97a8 Mon Sep 17 00:00:00 2001 From: Alexander Andreev Date: Fri, 23 Jul 2021 13:20:29 +0300 Subject: [PATCH 3/3] Apply suggestions from code review Co-authored-by: Eduardo Pinho --- src/index/mod.rs | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/index/mod.rs b/src/index/mod.rs index 128e91d..06ac41e 100644 --- a/src/index/mod.rs +++ b/src/index/mod.rs @@ -467,12 +467,14 @@ impl TryFromInnerPtr for IndexImpl { } } -/// If you need to store several different types of indexes in one collection. -/// You can cast all indexes to one type of IndexImpl. +/// Index upcast trait. +/// +/// If you need to store several different types of indexes in one collection, +/// you can cast all indexes to the common type `IndexImpl`. /// # Examples /// /// ``` -/// use faiss::{index::{IndexImpl, UpcastIndex}, FlatIndex, index_factory, MetricType}; +/// # use faiss::{index::{IndexImpl, UpcastIndex}, FlatIndex, index_factory, MetricType}; /// let f1 = FlatIndex::new_l2(128).unwrap(); /// let f2 = index_factory(128, "Flat", MetricType::L2).unwrap(); /// let v: Vec = vec![ @@ -482,7 +484,7 @@ impl TryFromInnerPtr for IndexImpl { /// ``` /// pub trait UpcastIndex: NativeIndex { - /// Converting an index to the base type of `IndexImpl` + /// Convert an index to the base `IndexImpl` type fn upcast(self) -> IndexImpl; }