From 2cf026e007c7ea9cd655c279705bd45d8411e2cb Mon Sep 17 00:00:00 2001 From: Matthijs van Otterdijk Date: Thu, 21 Mar 2024 12:39:46 +0100 Subject: [PATCH 1/8] ground work for derived domains --- vectorlink/src/domain.rs | 32 +++++++++++++++++++++----------- 1 file changed, 21 insertions(+), 11 deletions(-) diff --git a/vectorlink/src/domain.rs b/vectorlink/src/domain.rs index e625c60..e0cdb83 100644 --- a/vectorlink/src/domain.rs +++ b/vectorlink/src/domain.rs @@ -1,11 +1,13 @@ use std::{ any::Any, + collections::HashMap, io, ops::{Deref, DerefMut, Range}, path::Path, sync::{Arc, RwLock}, }; +use parallel_hnsw::pq::HnswQuantizer; use urlencoding::encode; use crate::store::{ImmutableVectorFile, LoadedVectorRange, SequentialVectorLoader, VectorFile}; @@ -22,9 +24,27 @@ pub fn downcast_generic_domain( .expect("Could not downcast domain to expected embedding size") } +pub struct PqDerivedDomain< + const SIZE: usize, + const CENTROID_SIZE: usize, + const QUANTIZED_SIZE: usize, + C, +> { + name: String, + file: RwLock>, + quantizer: HnswQuantizer, +} + +pub trait Deriver { + type From: Copy; + + fn append_derived(&self, vecs: &[Self::From]) -> io::Result<()>; +} + pub struct Domain { name: String, file: RwLock>, + derived_domains: Vec + Send + Sync>>, } impl GenericDomain for Domain { @@ -55,6 +75,7 @@ impl Domain { Ok(Domain { name: name.to_string(), + derived_domains: Vec::new(), file, }) } @@ -71,17 +92,6 @@ impl Domain { self.file().as_immutable() } - fn add_vecs<'a, I: Iterator>(&self, vecs: I) -> io::Result<(usize, usize)> - where - T: 'a, - { - let mut vector_file = self.file_mut(); - let old_len = vector_file.num_vecs(); - let count = vector_file.append_vectors(vecs)?; - - Ok((old_len, count)) - } - pub fn concatenate_file>(&self, path: P) -> io::Result<(usize, usize)> { let read_vector_file = VectorFile::open(path, true)?; let old_size = self.num_vecs(); From ab698eb5b882e41790bb84c3dd871dda30b4a254 Mon Sep 17 00:00:00 2001 From: Matthijs van Otterdijk Date: Thu, 21 Mar 2024 13:53:17 +0100 Subject: [PATCH 2/8] derived domain work --- vectorlink/src/batch.rs | 34 ++--------------- vectorlink/src/domain.rs | 81 +++++++++++++++++++++++++++++++++++----- vectorlink/src/store.rs | 26 +------------ 3 files changed, 77 insertions(+), 64 deletions(-) diff --git a/vectorlink/src/batch.rs b/vectorlink/src/batch.rs index b87be13..b3bbef9 100644 --- a/vectorlink/src/batch.rs +++ b/vectorlink/src/batch.rs @@ -26,6 +26,7 @@ use crate::{ indexer::{create_index_name, index_serialization_path}, openai::{embeddings_for, EmbeddingError, Model}, server::Operation, + store::VectorFile, vecmath::{Embedding, CENTROID_16_LENGTH, EMBEDDING_LENGTH, QUANTIZED_16_EMBEDDING_LENGTH}, vectors::VectorStore, }; @@ -56,36 +57,13 @@ pub enum VectorizationError { Io(#[from] io::Error), } -async fn save_embeddings( - vec_file: &mut File, - offset: usize, - embeddings: &[Embedding], -) -> Result<(), VectorizationError> { - let transmuted = unsafe { - std::slice::from_raw_parts( - embeddings.as_ptr() as *const u8, - std::mem::size_of_val(embeddings), - ) - }; - vec_file - .seek(SeekFrom::Start( - (offset * std::mem::size_of::()) as u64, - )) - .await?; - vec_file.write_all(transmuted).await?; - vec_file.flush().await?; - vec_file.sync_data().await?; - - Ok(()) -} - pub async fn vectorize_from_operations< S: Stream>, P: AsRef + Unpin, >( api_key: &str, model: Model, - vec_file: &mut File, + vec_file: &mut VectorFile, op_stream: S, progress_file_path: P, ) -> Result { @@ -122,7 +100,7 @@ pub async fn vectorize_from_operations< let (embeddings, chunk_failures) = embeds.unwrap()?; eprintln!("retrieved embeddings"); - save_embeddings(vec_file, offset as usize, &embeddings).await?; + vec_file.append_vector_range(&embeddings)?; eprintln!("saved embeddings"); failures += chunk_failures; offset += embeddings.len() as u64; @@ -303,11 +281,7 @@ pub async fn index_from_operations_file>( let mut vector_path = staging_path.clone(); vector_path.push("vectors"); - let mut vec_file = OpenOptions::new() - .create(true) - .write(true) - .open(&vector_path) - .await?; + let mut vec_file = VectorFile::open_create(&vector_path, true)?; let mut progress_file_path = staging_path.clone(); progress_file_path.push("progress"); diff --git a/vectorlink/src/domain.rs b/vectorlink/src/domain.rs index e0cdb83..29d2200 100644 --- a/vectorlink/src/domain.rs +++ b/vectorlink/src/domain.rs @@ -7,10 +7,16 @@ use std::{ sync::{Arc, RwLock}, }; -use parallel_hnsw::pq::HnswQuantizer; +use parallel_hnsw::{ + pq::{HnswQuantizer, Quantizer}, + Comparator, +}; use urlencoding::encode; -use crate::store::{ImmutableVectorFile, LoadedVectorRange, SequentialVectorLoader, VectorFile}; +use crate::{ + store::{ImmutableVectorFile, LoadedVectorRange, SequentialVectorLoader, VectorFile}, + vecmath::Embedding, +}; pub trait GenericDomain: 'static + Any + Send + Sync { fn name(&self) -> &str; @@ -24,6 +30,15 @@ pub fn downcast_generic_domain( .expect("Could not downcast domain to expected embedding size") } +pub trait Deriver: Any { + type From: Copy; + + fn concatenate_derived(&self, loader: SequentialVectorLoader) -> io::Result<()>; + fn chunk_size(&self) -> usize { + 1_000 + } +} + pub struct PqDerivedDomain< const SIZE: usize, const CENTROID_SIZE: usize, @@ -35,16 +50,35 @@ pub struct PqDerivedDomain< quantizer: HnswQuantizer, } -pub trait Deriver { - type From: Copy; - - fn append_derived(&self, vecs: &[Self::From]) -> io::Result<()>; +impl< + const SIZE: usize, + const CENTROID_SIZE: usize, + const QUANTIZED_SIZE: usize, + C: 'static + Comparator, + > Deriver for PqDerivedDomain +{ + type From = [f32; SIZE]; + + fn concatenate_derived(&self, loader: SequentialVectorLoader) -> io::Result<()> { + for chunk in loader { + let chunk = chunk?; + let mut result = Vec::with_capacity(chunk.len()); + for vec in chunk.iter() { + let quantized = self.quantizer.quantize(vec); + result.push(quantized); + } + let mut file = self.file.write().unwrap(); + file.append_vector_range(&result)?; + } + + Ok(()) + } } pub struct Domain { name: String, file: RwLock>, - derived_domains: Vec + Send + Sync>>, + derived_domains: RwLock + Send + Sync>>>, } impl GenericDomain for Domain { @@ -58,7 +92,7 @@ impl GenericDomain for Domain { } #[allow(unused)] -impl Domain { +impl Domain { pub fn name(&self) -> &str { &self.name } @@ -75,7 +109,7 @@ impl Domain { Ok(Domain { name: name.to_string(), - derived_domains: Vec::new(), + derived_domains: RwLock::new(HashMap::new()), file, }) } @@ -95,6 +129,11 @@ impl Domain { pub fn concatenate_file>(&self, path: P) -> io::Result<(usize, usize)> { let read_vector_file = VectorFile::open(path, true)?; let old_size = self.num_vecs(); + let derived_domains = self.derived_domains.read().unwrap(); + for derived in derived_domains.values() { + let chunk_size = derived.chunk_size(); + derived.concatenate_derived(read_vector_file.vector_chunks(chunk_size)?)?; + } Ok(( old_size, self.file_mut().append_vector_file(&read_vector_file)?, @@ -116,4 +155,28 @@ impl Domain { pub fn vector_chunks(&self, chunk_size: usize) -> io::Result> { self.file().vector_chunks(chunk_size) } + + pub fn create_derived + 'static + Send + Sync>( + &self, + name: String, + deriver: D, + ) { + let mut derived_domains = self.derived_domains.write().unwrap(); + assert!( + !derived_domains.contains_key(&name), + "tried to create derived domain that already exists" + ); + + derived_domains.insert(name, Arc::new(deriver)); + } + + pub fn derived_domain<'a, T2: Deriver + Send + Sync>( + &'a self, + name: &str, + ) -> Option + 'a> { + let derived_domains = self.derived_domains.read().unwrap(); + let derived = derived_domains.get(name)?; + + Some(Arc::downcast::(derived.clone()).expect("derived domain was not of expected type")) + } } diff --git a/vectorlink/src/store.rs b/vectorlink/src/store.rs index a994b66..2b2f8f2 100644 --- a/vectorlink/src/store.rs +++ b/vectorlink/src/store.rs @@ -277,34 +277,10 @@ impl VectorFile { (self.num_vecs * std::mem::size_of::()) as u64, )?; self.num_vecs = self.num_vecs + vectors.len(); - self.file.sync_data()?; + self.file.sync_data()?; // TODO probably don't do it here cause we might want to append multiple ranges Ok(vectors.len()) } - pub fn append_vectors<'a, I: Iterator>(&mut self, vectors: I) -> io::Result - where - T: 'a, - { - // wouldn't it be more straightforward to just use the file as a cursor? - let mut offset = (self.num_vecs * std::mem::size_of::()) as u64; - let mut count = 0; - for vector in vectors { - let bytes = unsafe { - std::slice::from_raw_parts( - vector as *const T as *const u8, - std::mem::size_of::(), - ) - }; - self.file.write_all_at(bytes, offset)?; - self.num_vecs += 1; - offset += std::mem::size_of::() as u64; - count += 1; - } - - self.file.sync_data()?; - - Ok(count) - } pub fn append_vector_file(&mut self, file: &VectorFile) -> io::Result { let mut read_offset = 0; From b7fdbbd7bdcfff47f301d874038485c00596b207 Mon Sep 17 00:00:00 2001 From: Matthijs van Otterdijk Date: Thu, 21 Mar 2024 15:15:31 +0100 Subject: [PATCH 3/8] more groundwork --- Cargo.lock | 3 + vectorlink/Cargo.toml | 3 + vectorlink/src/domain.rs | 159 ++++++++++++++++++++++++++++++++++++--- vectorlink/src/store.rs | 4 + 4 files changed, 157 insertions(+), 12 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index b34975b..d6619ed 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1821,7 +1821,10 @@ dependencies = [ "itertools 0.10.5", "lazy_static", "libc", + "linfa", + "linfa-clustering", "lru", + "ndarray", "parallel-hnsw", "rand", "rand_pcg", diff --git a/vectorlink/Cargo.toml b/vectorlink/Cargo.toml index b8e9181..b4226e2 100644 --- a/vectorlink/Cargo.toml +++ b/vectorlink/Cargo.toml @@ -33,6 +33,9 @@ itertools = "0.10" chrono = "0.4.26" rayon = "1.8.0" libc = "0.2.153" +linfa = "0.7.0" +linfa-clustering = "0.7.0" +ndarray = "0.15.6" [dev-dependencies] assert_float_eq = "1.1.3" diff --git a/vectorlink/src/domain.rs b/vectorlink/src/domain.rs index 29d2200..6b94f63 100644 --- a/vectorlink/src/domain.rs +++ b/vectorlink/src/domain.rs @@ -1,22 +1,25 @@ use std::{ any::Any, - collections::HashMap, + collections::{HashMap, HashSet}, + error::Error, io, + marker::PhantomData, ops::{Deref, DerefMut, Range}, - path::Path, + path::{Path, PathBuf}, sync::{Arc, RwLock}, }; +use linfa::{traits::Fit, DatasetBase}; +use linfa_clustering::KMeans; +use ndarray::{Array, Array2}; use parallel_hnsw::{ - pq::{HnswQuantizer, Quantizer}, - Comparator, + pq::{CentroidComparatorConstructor, HnswQuantizer, Quantizer}, + Comparator, Hnsw, Serializable, VectorId, }; +use rand::{distributions::Uniform, rngs::StdRng, thread_rng, Rng, SeedableRng}; use urlencoding::encode; -use crate::{ - store::{ImmutableVectorFile, LoadedVectorRange, SequentialVectorLoader, VectorFile}, - vecmath::Embedding, -}; +use crate::store::{ImmutableVectorFile, LoadedVectorRange, SequentialVectorLoader, VectorFile}; pub trait GenericDomain: 'static + Any + Send + Sync { fn name(&self) -> &str; @@ -39,13 +42,23 @@ pub trait Deriver: Any { } } +pub trait NewDeriver { + type T: Copy; + type Deriver: Deriver; + + fn new( + &self, + path: PathBuf, + vectors: &VectorFile, + ) -> Result>; +} + pub struct PqDerivedDomain< const SIZE: usize, const CENTROID_SIZE: usize, const QUANTIZED_SIZE: usize, C, > { - name: String, file: RwLock>, quantizer: HnswQuantizer, } @@ -75,6 +88,116 @@ impl< } } +#[derive(Default)] +struct PqDerivedDomainInitializer< + const SIZE: usize, + const CENTROID_SIZE: usize, + const QUANTIZED_SIZE: usize, + C, +> { + _x: PhantomData, +} + +impl< + const SIZE: usize, + const CENTROID_SIZE: usize, + const QUANTIZED_SIZE: usize, + C: 'static + + Comparator + + CentroidComparatorConstructor + + Serializable, + > NewDeriver for PqDerivedDomainInitializer +{ + type T = [f32; SIZE]; + type Deriver = PqDerivedDomain; + + fn new( + &self, + path: PathBuf, + vectors: &VectorFile<[f32; SIZE]>, + ) -> Result> { + // TODO do something else for sizes close to number of vecs + const NUMBER_OF_CENTROIDS: usize = 10_000; + const SAMPLE_SIZE: usize = NUMBER_OF_CENTROIDS / 10; + let selection = if SAMPLE_SIZE >= vectors.num_vecs() { + vectors.all_vectors().unwrap().clone().into_vec() + } else { + let mut rng = thread_rng(); + let mut set = HashSet::new(); + let range = Uniform::from(0_usize..vectors.num_vecs()); + while set.len() != SAMPLE_SIZE { + let candidate = rng.sample(&range); + set.insert(candidate); + } + + set.into_iter() + .map(|index| vectors.vec(index).unwrap()) + .collect() + }; + + // Linfa + let data: Vec = selection.into_iter().flat_map(|v| v.into_iter()).collect(); + let sub_length = data.len() / CENTROID_SIZE; + let sub_arrays = Array::from_shape_vec((sub_length, CENTROID_SIZE), data).unwrap(); + eprintln!("sub_arrays: {sub_arrays:?}"); + let observations = DatasetBase::from(sub_arrays); + // TODO review this number + let number_of_clusters = usize::min(sub_length, 1_000); + let prng = StdRng::seed_from_u64(42); + eprintln!("Running kmeans"); + let model = KMeans::params_with_rng(number_of_clusters, prng.clone()) + .tolerance(1e-2) + .fit(&observations) + .expect("KMeans fitted"); + let centroid_array: Array2 = model.centroids().clone(); + centroid_array.len(); + let centroid_flat: Vec = centroid_array + .into_shape(number_of_clusters * CENTROID_SIZE) + .unwrap() + .to_vec(); + eprintln!("centroid flat len: {}", centroid_flat.len()); + let centroids: Vec<[f32; CENTROID_SIZE]> = centroid_flat + .chunks(CENTROID_SIZE) + .map(|v| { + let mut array = [0.0; CENTROID_SIZE]; + array.copy_from_slice(v); + array + }) + .collect(); + // + eprintln!("Number of centroids: {}", centroids.len()); + + let vector_ids = (0..centroids.len()).map(VectorId).collect(); + let centroid_comparator = C::new(centroids); + let centroid_m = 24; + let centroid_m0 = 48; + let centroid_order = 12; + let mut centroid_hnsw: Hnsw = Hnsw::generate( + centroid_comparator, + vector_ids, + centroid_m, + centroid_m0, + centroid_order, + ); + //centroid_hnsw.improve_index(); + centroid_hnsw.improve_neighbors(0.01, 1.0); + + let centroid_quantizer: HnswQuantizer = + HnswQuantizer::new(centroid_hnsw); + + let quantizer_path = path.join("quantizer"); + centroid_quantizer.serialize(quantizer_path)?; + + let quantized_path = path.join("quantized.vecs"); + let quantized_file = VectorFile::create(quantized_path, true)?; + + Ok(PqDerivedDomain { + file: RwLock::new(quantized_file), + quantizer: centroid_quantizer, + }) + } +} + pub struct Domain { name: String, file: RwLock>, @@ -156,18 +279,30 @@ impl Domain { self.file().vector_chunks(chunk_size) } - pub fn create_derived + 'static + Send + Sync>( + pub fn create_derived< + N: NewDeriver, + D: Deriver + 'static + Send + Sync, + >( &self, name: String, - deriver: D, - ) { + deriver: N, + ) -> Result<(), Box> { let mut derived_domains = self.derived_domains.write().unwrap(); assert!( !derived_domains.contains_key(&name), "tried to create derived domain that already exists" ); + let file = self.file(); + let mut path = file.path().clone(); + path.set_extension("derived"); + path.push(&name); + std::fs::create_dir_all(&path)?; + + let deriver = deriver.new(path, &*file)?; derived_domains.insert(name, Arc::new(deriver)); + + Ok(()) } pub fn derived_domain<'a, T2: Deriver + Send + Sync>( diff --git a/vectorlink/src/store.rs b/vectorlink/src/store.rs index 2b2f8f2..7c0997e 100644 --- a/vectorlink/src/store.rs +++ b/vectorlink/src/store.rs @@ -337,6 +337,10 @@ impl VectorFile { _x: PhantomData, }) } + + pub fn path(&self) -> &PathBuf { + &self.path + } } pub struct ImmutableVectorFile(VectorFile); From f6e52e40d628089378c1cc9b048e1e97ea8a940c Mon Sep 17 00:00:00 2001 From: Matthijs van Otterdijk Date: Thu, 21 Mar 2024 17:28:39 +0100 Subject: [PATCH 4/8] creating, serialization and deserialization of derived domains --- vectorlink/src/comparator.rs | 17 ++++- vectorlink/src/domain.rs | 136 +++++++++++++++++++++++++++++++++-- vectorlink/src/vectors.rs | 9 ++- 3 files changed, 154 insertions(+), 8 deletions(-) diff --git a/vectorlink/src/comparator.rs b/vectorlink/src/comparator.rs index 4f04ac9..acb08d4 100644 --- a/vectorlink/src/comparator.rs +++ b/vectorlink/src/comparator.rs @@ -1,5 +1,5 @@ use parallel_hnsw::pq::{ - CentroidComparatorConstructor, PartialDistance, QuantizedComparatorConstructor, + CentroidComparatorConstructor, HnswQuantizer, PartialDistance, QuantizedComparatorConstructor, }; use rand::distributions::Uniform; use rand::{thread_rng, Rng}; @@ -16,7 +16,7 @@ use parallel_hnsw::{pq, Comparator, Serializable, SerializationError, VectorId}; use crate::store::{ImmutableVectorFile, LoadedVectorRange, VectorFile}; use crate::vecmath::{ self, EuclideanDistance16, EuclideanDistance32, Quantized16Embedding, Quantized32Embedding, - CENTROID_16_LENGTH, CENTROID_32_LENGTH, QUANTIZED_16_EMBEDDING_LENGTH, + CENTROID_16_LENGTH, CENTROID_32_LENGTH, EMBEDDING_LENGTH, QUANTIZED_16_EMBEDDING_LENGTH, QUANTIZED_32_EMBEDDING_LENGTH, }; use crate::{ @@ -541,6 +541,19 @@ impl pq::VectorStore for Quantized16Comparator { } } +pub type HnswQuantizer16 = HnswQuantizer< + EMBEDDING_LENGTH, + CENTROID_16_LENGTH, + QUANTIZED_16_EMBEDDING_LENGTH, + Centroid16Comparator, +>; +pub type HnswQuantizer32 = HnswQuantizer< + EMBEDDING_LENGTH, + CENTROID_32_LENGTH, + QUANTIZED_32_EMBEDDING_LENGTH, + Centroid32Comparator, +>; + #[cfg(test)] mod tests { use std::sync::{Arc, RwLock}; diff --git a/vectorlink/src/domain.rs b/vectorlink/src/domain.rs index 6b94f63..d2c7b2d 100644 --- a/vectorlink/src/domain.rs +++ b/vectorlink/src/domain.rs @@ -1,5 +1,5 @@ use std::{ - any::Any, + any::{Any, TypeId}, collections::{HashMap, HashSet}, error::Error, io, @@ -17,9 +17,17 @@ use parallel_hnsw::{ Comparator, Hnsw, Serializable, VectorId, }; use rand::{distributions::Uniform, rngs::StdRng, thread_rng, Rng, SeedableRng}; +use serde::{Deserialize, Serialize}; use urlencoding::encode; -use crate::store::{ImmutableVectorFile, LoadedVectorRange, SequentialVectorLoader, VectorFile}; +use crate::{ + comparator::{Centroid16Comparator, Centroid32Comparator, HnswQuantizer16, HnswQuantizer32}, + store::{ImmutableVectorFile, LoadedVectorRange, SequentialVectorLoader, VectorFile}, + vecmath::{ + Embedding, CENTROID_16_LENGTH, CENTROID_32_LENGTH, EMBEDDING_LENGTH, + QUANTIZED_16_EMBEDDING_LENGTH, QUANTIZED_32_EMBEDDING_LENGTH, + }, +}; pub trait GenericDomain: 'static + Any + Send + Sync { fn name(&self) -> &str; @@ -37,9 +45,12 @@ pub trait Deriver: Any { type From: Copy; fn concatenate_derived(&self, loader: SequentialVectorLoader) -> io::Result<()>; + fn configuration(&self) -> DerivedDomainConfiguration; fn chunk_size(&self) -> usize { 1_000 } + + //fn try_cast(&self) -> Option> } pub trait NewDeriver { @@ -63,6 +74,31 @@ pub struct PqDerivedDomain< quantizer: HnswQuantizer, } +impl< + const SIZE: usize, + const CENTROID_SIZE: usize, + const QUANTIZED_SIZE: usize, + C: 'static + Comparator, + > PqDerivedDomain +{ + fn as_arc( + self, + ) -> Option + Send + Sync + 'static>> { + let expected_type_id = TypeId::of::<[f32; SIZE]>(); + let actual_type_id = TypeId::of::(); + if expected_type_id == actual_type_id { + let result = Arc::new(self) as Arc>; + // this should be safe as we asserted at runtime that these types are the same + let transmuted: Arc + Send + Sync + 'static> = + unsafe { std::mem::transmute(result) }; + + Some(transmuted) + } else { + None + } + } +} + impl< const SIZE: usize, const CENTROID_SIZE: usize, @@ -86,6 +122,73 @@ impl< Ok(()) } + + fn configuration(&self) -> DerivedDomainConfiguration { + match (SIZE, CENTROID_SIZE, QUANTIZED_SIZE) { + (EMBEDDING_LENGTH, CENTROID_16_LENGTH, QUANTIZED_16_EMBEDDING_LENGTH) => { + DerivedDomainConfiguration::SmallPq + } + (EMBEDDING_LENGTH, CENTROID_32_LENGTH, QUANTIZED_32_EMBEDDING_LENGTH) => { + DerivedDomainConfiguration::LargePq + } + _ => panic!("unserializable pq derived domain"), + } + } +} + +pub type PqDerivedDomain16 = PqDerivedDomain< + EMBEDDING_LENGTH, + CENTROID_16_LENGTH, + QUANTIZED_16_EMBEDDING_LENGTH, + Centroid16Comparator, +>; +pub type PqDerivedDomain32 = PqDerivedDomain< + EMBEDDING_LENGTH, + CENTROID_32_LENGTH, + QUANTIZED_32_EMBEDDING_LENGTH, + Centroid32Comparator, +>; + +#[derive(Serialize, Deserialize)] +pub enum DerivedDomainConfiguration { + SmallPq, + LargePq, +} + +impl DerivedDomainConfiguration { + pub fn new>( + &self, + path: P, + ) -> Result + Send + Sync + 'static>, io::Error> { + match self { + Self::SmallPq => { + let file = RwLock::new(VectorFile::open( + path.as_ref().join("quantized.vecs"), + true, + )?); + // panic here if T is not what we expect + let quantizer: HnswQuantizer16 = + HnswQuantizer::deserialize(path, ()).expect("TODO"); + + let domain: PqDerivedDomain16 = PqDerivedDomain { file, quantizer }; + + Ok(domain.as_arc::().unwrap()) + } + Self::LargePq => { + let file = RwLock::new(VectorFile::open( + path.as_ref().join("quantized.vecs"), + true, + )?); + // panic here if T is not what we expect + let quantizer: HnswQuantizer32 = + HnswQuantizer::deserialize(path, ()).expect("TODO"); + + let domain: PqDerivedDomain32 = PqDerivedDomain { file, quantizer }; + + Ok(domain.as_arc::().unwrap()) + } + } + } } #[derive(Default)] @@ -224,15 +327,36 @@ impl Domain { self.file().num_vecs() } - pub fn open>(dir: P, name: &str) -> io::Result { + pub fn open>(dir: P, name: &str) -> Result { let mut path = dir.as_ref().to_path_buf(); let encoded_name = encode(name); path.push(format!("{encoded_name}.vecs")); let file = RwLock::new(VectorFile::open_create(&path, true)?); + // load derived domains + let mut derived_path = path.clone(); + derived_path.set_extension("derived"); + let mut derived_domains = HashMap::new(); + if derived_path.exists() { + for file in std::fs::read_dir(derived_path)? { + let derived = file?; + // now we have to discover what kind of derived domain this is + // the options are hardcoded. + let name = derived.file_name().into_string().unwrap(); + let config_file = derived.path().join("config.json"); + if config_file.exists() { + let mut file = std::fs::File::open(config_file)?; + let config: DerivedDomainConfiguration = serde_json::from_reader(file)?; + let derived_domain = config.new::(derived.path()).expect("TODO"); + + derived_domains.insert(name, derived_domain); + } + } + } + Ok(Domain { name: name.to_string(), - derived_domains: RwLock::new(HashMap::new()), + derived_domains: RwLock::new(derived_domains), file, }) } @@ -299,7 +423,11 @@ impl Domain { path.push(&name); std::fs::create_dir_all(&path)?; + let config_path = path.join("config.json"); let deriver = deriver.new(path, &*file)?; + let config = deriver.configuration(); + let config_string = serde_json::to_string(&config).unwrap(); + std::fs::write(config_path, config_string)?; derived_domains.insert(name, Arc::new(deriver)); Ok(()) diff --git a/vectorlink/src/vectors.rs b/vectorlink/src/vectors.rs index d798f2a..de93875 100644 --- a/vectorlink/src/vectors.rs +++ b/vectorlink/src/vectors.rs @@ -3,6 +3,7 @@ use std::any::Any; use std::cmp::Ordering; use std::collections::{HashMap, HashSet}; +use std::error::Error; use std::fmt; use std::fs::{File, OpenOptions}; use std::io::{self, Seek, SeekFrom, Write}; @@ -45,7 +46,8 @@ impl VectorStore { } } - pub fn get_domain(&self, name: &str) -> io::Result>> { + // TODO better error + pub fn get_domain(&self, name: &str) -> Result>, io::Error> { let domains = self.domains.read().unwrap(); if let Some(domain) = domains.get(name) { Ok(downcast_generic_domain(domain.clone())) @@ -55,7 +57,10 @@ impl VectorStore { if let Some(domain) = domains.get(name) { Ok(downcast_generic_domain(domain.clone())) } else { - let domain = Arc::new(Domain::open(&self.dir, name)?); + let domain = Arc::new( + Domain::open(&self.dir, name) + .map_err(|e| io::Error::new(io::ErrorKind::Other, e.to_string()))?, + ); domains.insert(name.to_string(), domain.clone()); Ok(domain) From 3bba64201e4880fdc6575d4e6b4802221c85e2b9 Mon Sep 17 00:00:00 2001 From: Matthijs van Otterdijk Date: Fri, 22 Mar 2024 12:55:43 +0100 Subject: [PATCH 5/8] nitialize domain on creation --- vectorlink/src/domain.rs | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/vectorlink/src/domain.rs b/vectorlink/src/domain.rs index d2c7b2d..9beb76d 100644 --- a/vectorlink/src/domain.rs +++ b/vectorlink/src/domain.rs @@ -45,6 +45,11 @@ pub trait Deriver: Any { type From: Copy; fn concatenate_derived(&self, loader: SequentialVectorLoader) -> io::Result<()>; + fn concatenate_file(&self, file: &VectorFile) -> io::Result<()> { + self.concatenate_derived(file.vector_chunks(self.chunk_size())?); + + Ok(()) + } fn configuration(&self) -> DerivedDomainConfiguration; fn chunk_size(&self) -> usize { 1_000 @@ -378,8 +383,7 @@ impl Domain { let old_size = self.num_vecs(); let derived_domains = self.derived_domains.read().unwrap(); for derived in derived_domains.values() { - let chunk_size = derived.chunk_size(); - derived.concatenate_derived(read_vector_file.vector_chunks(chunk_size)?)?; + derived.concatenate_file(&read_vector_file)?; } Ok(( old_size, @@ -411,23 +415,34 @@ impl Domain { name: String, deriver: N, ) -> Result<(), Box> { + // first, let's take a read lock on the internal file to stop + // others from doing things to this domain. + // Makes deadlocks less likely as the only hold-and-wait + // pattern then remaining has to involve both file and derived + // domains. + let file = self.file(); let mut derived_domains = self.derived_domains.write().unwrap(); assert!( !derived_domains.contains_key(&name), "tried to create derived domain that already exists" ); - let file = self.file(); + // create a directory for this derived domain let mut path = file.path().clone(); path.set_extension("derived"); path.push(&name); std::fs::create_dir_all(&path)?; + // write a config so we can recognize later on what this domain is let config_path = path.join("config.json"); let deriver = deriver.new(path, &*file)?; let config = deriver.configuration(); let config_string = serde_json::to_string(&config).unwrap(); std::fs::write(config_path, config_string)?; + + // convert all already-existing vectors to this domain + deriver.concatenate_file(&*file)?; + derived_domains.insert(name, Arc::new(deriver)); Ok(()) From c52456967d0dd2eed8a287e3cddef1877d7c10ce Mon Sep 17 00:00:00 2001 From: Matthijs van Otterdijk Date: Fri, 22 Mar 2024 14:59:16 +0100 Subject: [PATCH 6/8] (untested) quantize implementation --- vectorlink/src/domain.rs | 120 +++++++++++++++++++++++++++++---------- vectorlink/src/main.rs | 25 +++++++- 2 files changed, 114 insertions(+), 31 deletions(-) diff --git a/vectorlink/src/domain.rs b/vectorlink/src/domain.rs index 9beb76d..c6528f5 100644 --- a/vectorlink/src/domain.rs +++ b/vectorlink/src/domain.rs @@ -9,6 +9,7 @@ use std::{ sync::{Arc, RwLock}, }; +use clap::ValueEnum; use linfa::{traits::Fit, DatasetBase}; use linfa_clustering::KMeans; use ndarray::{Array, Array2}; @@ -46,7 +47,7 @@ pub trait Deriver: Any { fn concatenate_derived(&self, loader: SequentialVectorLoader) -> io::Result<()>; fn concatenate_file(&self, file: &VectorFile) -> io::Result<()> { - self.concatenate_derived(file.vector_chunks(self.chunk_size())?); + self.concatenate_derived(file.vector_chunks(self.chunk_size())?)?; Ok(()) } @@ -54,19 +55,27 @@ pub trait Deriver: Any { fn chunk_size(&self) -> usize { 1_000 } - - //fn try_cast(&self) -> Option> } -pub trait NewDeriver { - type T: Copy; - type Deriver: Deriver; +pub trait DerivedDomainInitializer { + fn initialize( + &self, + path: PathBuf, + vectors: &VectorFile, + ) -> Result + Send + Sync>, Box>; +} - fn new( +// interestingly, we're required to provide our own trait object implementation. Rust is not able to derive it for us. +impl DerivedDomainInitializer + for Box + Send + Sync> +{ + fn initialize( &self, path: PathBuf, - vectors: &VectorFile, - ) -> Result>; + vectors: &VectorFile, + ) -> Result + Send + Sync>, Box> { + (**self).initialize(path, vectors) + } } pub struct PqDerivedDomain< @@ -154,7 +163,7 @@ pub type PqDerivedDomain32 = PqDerivedDomain< Centroid32Comparator, >; -#[derive(Serialize, Deserialize)] +#[derive(Serialize, Deserialize, ValueEnum, Debug, Clone, Copy)] pub enum DerivedDomainConfiguration { SmallPq, LargePq, @@ -194,10 +203,43 @@ impl DerivedDomainConfiguration { } } } + + pub fn initializer( + &self, + ) -> Box + 'static + Send + Sync> { + assert_eq!(TypeId::of::(), TypeId::of::()); + match self { + DerivedDomainConfiguration::SmallPq => { + let initializer = PqDerivedDomainInitializer::< + EMBEDDING_LENGTH, + CENTROID_16_LENGTH, + QUANTIZED_16_EMBEDDING_LENGTH, + Centroid16Comparator, + >::default(); + + let boxed: Box + 'static + Send + Sync> = + Box::new(initializer); + + unsafe { std::mem::transmute(boxed) } + } + DerivedDomainConfiguration::LargePq => { + let initializer = PqDerivedDomainInitializer::< + EMBEDDING_LENGTH, + CENTROID_32_LENGTH, + QUANTIZED_32_EMBEDDING_LENGTH, + Centroid32Comparator, + >::default(); + + let boxed: Box + 'static + Send + Sync> = + Box::new(initializer); + + unsafe { std::mem::transmute(boxed) } + } + } + } } -#[derive(Default)] -struct PqDerivedDomainInitializer< +pub struct PqDerivedDomainInitializer< const SIZE: usize, const CENTROID_SIZE: usize, const QUANTIZED_SIZE: usize, @@ -205,7 +247,6 @@ struct PqDerivedDomainInitializer< > { _x: PhantomData, } - impl< const SIZE: usize, const CENTROID_SIZE: usize, @@ -213,17 +254,37 @@ impl< C: 'static + Comparator + CentroidComparatorConstructor - + Serializable, - > NewDeriver for PqDerivedDomainInitializer + + Serializable + + Send, + > PqDerivedDomainInitializer { - type T = [f32; SIZE]; - type Deriver = PqDerivedDomain; +} - fn new( +impl Default + for PqDerivedDomainInitializer +{ + fn default() -> Self { + Self { _x: PhantomData } + } +} + +impl< + const SIZE: usize, + const CENTROID_SIZE: usize, + const QUANTIZED_SIZE: usize, + C: 'static + + Comparator + + CentroidComparatorConstructor + + Serializable + + Send, + > DerivedDomainInitializer<[f32; SIZE]> + for PqDerivedDomainInitializer +{ + fn initialize( &self, path: PathBuf, vectors: &VectorFile<[f32; SIZE]>, - ) -> Result> { + ) -> Result + Send + Sync>, Box> { // TODO do something else for sizes close to number of vecs const NUMBER_OF_CENTROIDS: usize = 10_000; const SAMPLE_SIZE: usize = NUMBER_OF_CENTROIDS / 10; @@ -297,12 +358,14 @@ impl< centroid_quantizer.serialize(quantizer_path)?; let quantized_path = path.join("quantized.vecs"); - let quantized_file = VectorFile::create(quantized_path, true)?; + let quantized_file: VectorFile<[u16; QUANTIZED_SIZE]> = + VectorFile::create(quantized_path, true)?; - Ok(PqDerivedDomain { + let deriver = PqDerivedDomain { file: RwLock::new(quantized_file), quantizer: centroid_quantizer, - }) + }; + Ok(Arc::new(deriver)) } } @@ -407,13 +470,10 @@ impl Domain { self.file().vector_chunks(chunk_size) } - pub fn create_derived< - N: NewDeriver, - D: Deriver + 'static + Send + Sync, - >( + pub fn create_derived>( &self, name: String, - deriver: N, + derived_domain_initializer: N, ) -> Result<(), Box> { // first, let's take a read lock on the internal file to stop // others from doing things to this domain. @@ -435,7 +495,7 @@ impl Domain { // write a config so we can recognize later on what this domain is let config_path = path.join("config.json"); - let deriver = deriver.new(path, &*file)?; + let deriver = derived_domain_initializer.initialize(path, &*file)?; let config = deriver.configuration(); let config_string = serde_json::to_string(&config).unwrap(); std::fs::write(config_path, config_string)?; @@ -443,12 +503,12 @@ impl Domain { // convert all already-existing vectors to this domain deriver.concatenate_file(&*file)?; - derived_domains.insert(name, Arc::new(deriver)); + derived_domains.insert(name, deriver); Ok(()) } - pub fn derived_domain<'a, T2: Deriver + Send + Sync>( + pub fn get_derived<'a, T2: Deriver + Send + Sync>( &'a self, name: &str, ) -> Option + 'a> { diff --git a/vectorlink/src/main.rs b/vectorlink/src/main.rs index f2323e3..2f0c334 100644 --- a/vectorlink/src/main.rs +++ b/vectorlink/src/main.rs @@ -9,13 +9,13 @@ use std::sync::Arc; mod batch; mod comparator; mod configuration; +mod domain; mod indexer; mod openai; mod server; mod store; mod vecmath; mod vectors; -mod domain; mod search_server; @@ -23,6 +23,7 @@ use batch::index_from_operations_file; use clap::CommandFactory; use clap::{Parser, Subcommand, ValueEnum}; use configuration::HnswConfiguration; +use domain::DerivedDomainConfiguration; //use hnsw::Hnsw; use openai::Model; use parallel_hnsw::pq::Quantizer; @@ -32,6 +33,7 @@ use parallel_hnsw::Comparator; use parallel_hnsw::Serializable; use std::fs::File; use std::io; +use vecmath::Embedding; use vecmath::Quantized32Embedding; use vecmath::EMBEDDING_BYTE_LENGTH; use vecmath::EMBEDDING_LENGTH; @@ -212,6 +214,16 @@ enum Commands { #[arg(short, long)] key: Option, }, + Quantize { + #[arg(short, long)] + directory: String, + #[arg(short, long)] + domain: String, + #[arg(short, long)] + derived: String, + #[arg(short, long, value_enum, default_value_t = DerivedDomainConfiguration::SmallPq)] + method: DerivedDomainConfiguration, + }, } #[derive(Clone, Copy, Debug, ValueEnum)] @@ -636,6 +648,17 @@ async fn main() -> Result<(), Box> { .await .unwrap() } + Commands::Quantize { + directory, + domain, + derived, + method, + } => { + let store = VectorStore::new(directory, 10_000); // num bufs is actually obsolete now. + let domain = store.get_domain(&domain).unwrap(); + let initializer = method.initializer::(); + domain.create_derived(derived.clone(), initializer).unwrap(); + } } Ok(()) From 5da0bf65452b67172d80cc7481bd99dceb094b98 Mon Sep 17 00:00:00 2001 From: Matthijs van Otterdijk Date: Fri, 22 Mar 2024 15:10:30 +0100 Subject: [PATCH 7/8] fix loading of derived domains --- vectorlink/src/domain.rs | 22 ++++++++-------------- 1 file changed, 8 insertions(+), 14 deletions(-) diff --git a/vectorlink/src/domain.rs b/vectorlink/src/domain.rs index c6528f5..e5df6cf 100644 --- a/vectorlink/src/domain.rs +++ b/vectorlink/src/domain.rs @@ -174,28 +174,22 @@ impl DerivedDomainConfiguration { &self, path: P, ) -> Result + Send + Sync + 'static>, io::Error> { + let vecs_path = path.as_ref().join("quantized.vecs"); + let quantizer_path = path.as_ref().join("quantizer"); match self { Self::SmallPq => { - let file = RwLock::new(VectorFile::open( - path.as_ref().join("quantized.vecs"), - true, - )?); - // panic here if T is not what we expect - let quantizer: HnswQuantizer16 = - HnswQuantizer::deserialize(path, ()).expect("TODO"); + let file = RwLock::new(VectorFile::open(&vecs_path, true)?); + let quantizer: HnswQuantizer16 = HnswQuantizer::deserialize(&quantizer_path, ()) + .expect("hnsw deserialization failed (small)"); let domain: PqDerivedDomain16 = PqDerivedDomain { file, quantizer }; Ok(domain.as_arc::().unwrap()) } Self::LargePq => { - let file = RwLock::new(VectorFile::open( - path.as_ref().join("quantized.vecs"), - true, - )?); - // panic here if T is not what we expect - let quantizer: HnswQuantizer32 = - HnswQuantizer::deserialize(path, ()).expect("TODO"); + let file = RwLock::new(VectorFile::open(&vecs_path, true)?); + let quantizer: HnswQuantizer32 = HnswQuantizer::deserialize(&quantizer_path, ()) + .expect("hnsw deserialization failed (large)"); let domain: PqDerivedDomain32 = PqDerivedDomain { file, quantizer }; From 8c4ad7844baf4816620e6bf5fd6dc891936de889 Mon Sep 17 00:00:00 2001 From: Matthijs van Otterdijk Date: Tue, 26 Mar 2024 08:50:37 +0100 Subject: [PATCH 8/8] assorted broken work --- vectorlink/src/batch.rs | 29 +- vectorlink/src/comparator.rs | 485 +++++++++++++++++--------------- vectorlink/src/configuration.rs | 42 ++- vectorlink/src/domain.rs | 141 +++++++--- vectorlink/src/main.rs | 1 - vectorlink/src/server.rs | 2 +- vectorlink/src/vecmath.rs | 4 +- 7 files changed, 411 insertions(+), 293 deletions(-) diff --git a/vectorlink/src/batch.rs b/vectorlink/src/batch.rs index b3bbef9..f878851 100644 --- a/vectorlink/src/batch.rs +++ b/vectorlink/src/batch.rs @@ -20,9 +20,11 @@ use urlencoding::encode; use crate::{ comparator::{ - Centroid16Comparator, DiskOpenAIComparator, OpenAIComparator, Quantized16Comparator, + Centroid16Comparator, DiskOpenAIComparator, DomainQuantizer, HnswQuantizer16, + OpenAIComparator, Quantized16Comparator, }, configuration::HnswConfiguration, + domain::{PqDerivedDomainInfo16, PqDerivedDomainInitializer16}, indexer::{create_index_name, index_serialization_path}, openai::{embeddings_for, EmbeddingError, Model}, server::Operation, @@ -168,7 +170,7 @@ pub async fn index_using_operations_and_vectors< op_file_path: P2, size: usize, id_offset: u64, - quantize_hnsw: bool, + quantize_hnsw: Option<&str>, model: Model, ) -> Result<(), IndexingError> { // Start at last hnsw offset @@ -235,20 +237,37 @@ pub async fn index_using_operations_and_vectors< .collect(); eprintln!("ready to generate hnsw"); - let hnsw = if quantize_hnsw { + let hnsw = if let Some(pq_name) = quantize_hnsw { let number_of_vectors = NUMBER_OF_CENTROIDS / 10; let c = DiskOpenAIComparator::new( domain_obj.name().to_owned(), Arc::new(domain_obj.immutable_file()), ); + + let derived_domain_info = domain_obj.get_derived_domain_info(pq_name); + if derived_domain_info.is_none() { + eprintln!("pq derived domain ({pq_name}) doesn't exist yet. constructing now"); + domain_obj + .create_derived(pq_name.to_string(), PqDerivedDomainInitializer16::default()) + .unwrap(); // TODO + } + // lazy - we just look it up again and now it should exist + let derived_domain_info: PqDerivedDomainInfo16 = + domain_obj.get_derived_domain_info(pq_name).unwrap(); + + let quantizer = derived_domain_info.quantizer.clone()); + + let quantized_comparator = + Quantized16Comparator::load(&vs, domain.to_string(), pq_name.to_string())?; + let hnsw: QuantizedHnsw< EMBEDDING_LENGTH, CENTROID_16_LENGTH, QUANTIZED_16_EMBEDDING_LENGTH, - Centroid16Comparator, Quantized16Comparator, DiskOpenAIComparator, - > = QuantizedHnsw::new(number_of_vectors, c); + Arc, + > = QuantizedHnsw::generate(quantizer, quantized_comparator, c, vecs); HnswConfiguration::SmallQuantizedOpenAi(model, hnsw) } else { let hnsw = Hnsw::generate(comparator, vecs, 24, 48, 12); diff --git a/vectorlink/src/comparator.rs b/vectorlink/src/comparator.rs index acb08d4..97e9a5b 100644 --- a/vectorlink/src/comparator.rs +++ b/vectorlink/src/comparator.rs @@ -1,23 +1,17 @@ -use parallel_hnsw::pq::{ - CentroidComparatorConstructor, HnswQuantizer, PartialDistance, QuantizedComparatorConstructor, -}; -use rand::distributions::Uniform; -use rand::{thread_rng, Rng}; +use parallel_hnsw::pq::{HnswQuantizer, PartialDistance, Quantizer}; use serde::{Deserialize, Serialize}; -use std::collections::HashSet; use std::fs::OpenOptions; -use std::io::{Read, Write}; +use std::io::{self, BufReader, Read, Write}; use std::marker::PhantomData; -use std::path::PathBuf; use std::{path::Path, sync::Arc}; -use parallel_hnsw::{pq, Comparator, Serializable, SerializationError, VectorId}; +use parallel_hnsw::{Comparator, Serializable, SerializationError, VectorId}; +use crate::domain::PqDerivedDomainInfo; use crate::store::{ImmutableVectorFile, LoadedVectorRange, VectorFile}; use crate::vecmath::{ - self, EuclideanDistance16, EuclideanDistance32, Quantized16Embedding, Quantized32Embedding, - CENTROID_16_LENGTH, CENTROID_32_LENGTH, EMBEDDING_LENGTH, QUANTIZED_16_EMBEDDING_LENGTH, - QUANTIZED_32_EMBEDDING_LENGTH, + self, EuclideanDistance16, EuclideanDistance32, CENTROID_16_LENGTH, CENTROID_32_LENGTH, + EMBEDDING_LENGTH, QUANTIZED_16_EMBEDDING_LENGTH, QUANTIZED_32_EMBEDDING_LENGTH, }; use crate::{ vecmath::{normalized_cosine_distance, Embedding}, @@ -72,7 +66,7 @@ impl Serializable for DiskOpenAIComparator { fn deserialize>( path: P, - store: Arc, + store: &Arc, ) -> Result { let mut comparator_file = OpenOptions::new().read(true).open(path)?; let mut contents = String::new(); @@ -86,35 +80,6 @@ impl Serializable for DiskOpenAIComparator { } } -impl pq::VectorSelector for DiskOpenAIComparator { - type T = Embedding; - - fn selection(&self, size: usize) -> Vec { - // TODO do something else for sizes close to number of vecs - if size >= self.vectors.num_vecs() { - return self.vectors.all_vectors().unwrap().clone().into_vec(); - } - let mut rng = thread_rng(); - let mut set = HashSet::new(); - let range = Uniform::from(0_usize..self.vectors.num_vecs()); - while set.len() != size { - let candidate = rng.sample(&range); - set.insert(candidate); - } - - set.into_iter() - .map(|index| self.vectors.vec(index).unwrap()) - .collect() - } - - fn vector_chunks(&self) -> impl Iterator> { - self.vectors - .vector_chunks(1_000_000) - .unwrap() - .map(|x| x.unwrap()) - } -} - #[derive(Clone)] pub struct OpenAIComparator { domain_name: String, @@ -166,7 +131,7 @@ impl Serializable for OpenAIComparator { fn deserialize>( path: P, - store: Arc, + store: &Arc, ) -> Result { let mut comparator_file = OpenOptions::new().read(true).open(path)?; let mut contents = String::new(); @@ -230,6 +195,17 @@ pub struct ArrayCentroidComparator { calculator: PhantomData, } +impl + Default> ArrayCentroidComparator { + pub fn new(centroids: Vec<[f32; N]>) -> Self { + let len = centroids.len(); + Self { + distances: Arc::new(MemoizedPartialDistances::new(C::default(), ¢roids)), + centroids: Arc::new(LoadedVectorRange::new(centroids, 0..len)), + calculator: PhantomData, + } + } +} + impl Clone for ArrayCentroidComparator { fn clone(&self) -> Self { Self { @@ -244,19 +220,6 @@ unsafe impl Sync for ArrayCentroidComparator {} pub type Centroid16Comparator = ArrayCentroidComparator; pub type Centroid32Comparator = ArrayCentroidComparator; -impl + Default> - CentroidComparatorConstructor for ArrayCentroidComparator -{ - fn new(centroids: Vec) -> Self { - let len = centroids.len(); - Self { - distances: Arc::new(MemoizedPartialDistances::new(C::default(), ¢roids)), - centroids: Arc::new(LoadedVectorRange::new(centroids, 0..len)), - calculator: PhantomData, - } - } -} - impl + Default> Comparator for ArrayCentroidComparator { @@ -294,7 +257,7 @@ impl + Default> Serializable fn deserialize>( path: P, - _params: Self::Params, + _params: &Self::Params, ) -> Result { let vector_file: VectorFile<[f32; N]> = VectorFile::open(path, true)?; let centroids = Arc::new(vector_file.all_vectors()?); @@ -310,174 +273,165 @@ impl + Default> Serializable } } -#[derive(Clone)] -pub struct Quantized32Comparator { - pub cc: Centroid32Comparator, - pub data: Arc>, +pub struct QuantizedDomainComparator< + const SIZE: usize, + const CENTROID_SIZE: usize, + const QUANTIZED_SIZE: usize, + C, +> { + domain: String, + subdomain: String, + cc: ArrayCentroidComparator, + data: Arc>, } -impl QuantizedComparatorConstructor for Quantized32Comparator { - type CentroidComparator = Centroid32Comparator; - - fn new(cc: &Self::CentroidComparator) -> Self { +impl Clone + for QuantizedDomainComparator +{ + fn clone(&self) -> Self { Self { - cc: cc.clone(), - data: Default::default(), + domain: self.domain.clone(), + subdomain: self.subdomain.clone(), + cc: self.cc.clone(), + data: self.data.clone(), } } } -#[derive(Clone)] -pub struct Quantized16Comparator { - pub cc: Centroid16Comparator, - pub data: Arc>, -} - -impl QuantizedComparatorConstructor for Quantized16Comparator { - type CentroidComparator = Centroid16Comparator; +pub type Quantized16Comparator = QuantizedDomainComparator< + EMBEDDING_LENGTH, + CENTROID_16_LENGTH, + QUANTIZED_16_EMBEDDING_LENGTH, + EuclideanDistance16, +>; +pub type Quantized32Comparator = QuantizedDomainComparator< + EMBEDDING_LENGTH, + CENTROID_32_LENGTH, + QUANTIZED_32_EMBEDDING_LENGTH, + EuclideanDistance32, +>; - fn new(cc: &Self::CentroidComparator) -> Self { - Self { - cc: cc.clone(), - data: Default::default(), - } - } +#[derive(Serialize, Deserialize)] +struct QuantizedDomainComparatorMeta { + domain: String, + subdomain: String, } -impl PartialDistance for Quantized32Comparator { - fn partial_distance(&self, i: u16, j: u16) -> f32 { - self.cc.partial_distance(i, j) +impl< + const SIZE: usize, + const CENTROID_SIZE: usize, + const QUANTIZED_SIZE: usize, + C: 'static + DistanceCalculator, + > QuantizedDomainComparator +where + ArrayCentroidComparator: 'static + Comparator, +{ + pub fn load(store: &VectorStore, domain: String, subdomain: String) -> io::Result { + assert_eq!(SIZE, CENTROID_SIZE * QUANTIZED_SIZE); // TODO compile-time macro check this + let domain_info = store.get_domain(&domain)?; + let derived_domain_info: PqDerivedDomainInfo = + domain_info + .get_derived_domain_info(&subdomain) + .expect("pq subdomain not found"); + + Ok(Self { + domain, + subdomain, + cc: derived_domain_info.quantizer.quantizer.comparator().clone(), + data: Arc::new(derived_domain_info.file.all_vectors()?), + }) } } - -impl PartialDistance for Quantized16Comparator { +impl< + const SIZE: usize, + const CENTROID_SIZE: usize, + const QUANTIZED_SIZE: usize, + C: 'static + DistanceCalculator, + > PartialDistance for QuantizedDomainComparator +where + ArrayCentroidComparator: 'static + Comparator, +{ fn partial_distance(&self, i: u16, j: u16) -> f32 { self.cc.partial_distance(i, j) } } -impl Comparator for Quantized32Comparator +impl< + const SIZE: usize, + const CENTROID_SIZE: usize, + const QUANTIZED_SIZE: usize, + C: 'static + DistanceCalculator, + > Serializable for QuantizedDomainComparator where - Quantized32Comparator: PartialDistance, + ArrayCentroidComparator: 'static + Comparator, { - type T = Quantized32Embedding; - - type Borrowable<'a> = &'a Quantized32Embedding; - - fn lookup(&self, v: VectorId) -> Self::Borrowable<'_> { - &self.data[v.0] - } - - fn compare_raw(&self, v1: &Self::T, v2: &Self::T) -> f32 { - let mut partial_distances = [0.0_f32; QUANTIZED_32_EMBEDDING_LENGTH]; - for ix in 0..QUANTIZED_32_EMBEDDING_LENGTH { - let partial_1 = v1[ix]; - let partial_2 = v2[ix]; - let partial_distance = self.cc.partial_distance(partial_1, partial_2); - partial_distances[ix] = partial_distance; - } - - vecmath::sum_48(&partial_distances).sqrt() - } -} - -impl Serializable for Quantized32Comparator { - type Params = (); + type Params = Arc; fn serialize>(&self, path: P) -> Result<(), SerializationError> { - let path_buf: PathBuf = path.as_ref().into(); - std::fs::create_dir_all(&path_buf)?; - - let index_path = path_buf.join("index"); - self.cc.serialize(index_path)?; + let meta = QuantizedDomainComparatorMeta { + domain: self.domain.clone(), + subdomain: self.subdomain.clone(), + }; + let meta_string = serde_json::to_string(&meta)?; + std::fs::write(path, meta_string)?; - let vector_path = path_buf.join("vectors"); - let mut vector_file = VectorFile::open(vector_path, true)?; - vector_file.append_vector_range(self.data.vecs())?; Ok(()) } fn deserialize>( path: P, - _params: Self::Params, + params: &Self::Params, ) -> Result { - let path_buf: PathBuf = path.as_ref().into(); - let index_path = path_buf.join("index"); - let cc = Centroid32Comparator::deserialize(index_path, ())?; - - let vector_path = path_buf.join("vectors"); - let vector_file = VectorFile::open(vector_path, true)?; - let range = vector_file.all_vectors()?; + let comparator_file = OpenOptions::new().read(true).open(path)?; + let QuantizedDomainComparatorMeta { domain, subdomain } = + serde_json::from_reader(BufReader::new(comparator_file))?; - let data = Arc::new(range); - Ok(Self { cc, data }) + Ok(Self::load(¶ms, domain, subdomain)?) } } -impl pq::VectorStore for Quantized32Comparator { - type T = ::T; - - fn store(&mut self, i: Box>) -> Vec { - // this is p retty stupid, but then, these comparators should not be storing in the first place - let mut new_contents: Vec = Vec::with_capacity(self.data.len() + i.size_hint().0); - new_contents.extend(self.data.vecs().iter()); - let vid = self.data.len(); - let mut vectors: Vec = Vec::new(); - new_contents.extend(i.enumerate().map(|(i, v)| { - vectors.push(VectorId(vid + i)); - v - })); - let end = new_contents.len(); - - let data = LoadedVectorRange::new(new_contents, 0..end); - self.data = Arc::new(data); - - vectors - } -} - -impl pq::VectorSelector for OpenAIComparator { - type T = Embedding; - - fn selection(&self, size: usize) -> Vec { - // TODO do something else for sizes close to number of vecs - let mut rng = thread_rng(); - let mut set = HashSet::new(); - let range = Uniform::from(0_usize..size); - while set.len() != size { - let candidate = rng.sample(&range); - set.insert(candidate); - } - - set.into_iter() - .map(|index| *self.range.vec(index)) - .collect() - } +pub type QuantizedDomainComparator16 = QuantizedDomainComparator< + EMBEDDING_LENGTH, + CENTROID_16_LENGTH, + QUANTIZED_16_EMBEDDING_LENGTH, + Centroid16Comparator, +>; +pub type QuantizedDomainComparator32 = QuantizedDomainComparator< + EMBEDDING_LENGTH, + CENTROID_32_LENGTH, + QUANTIZED_32_EMBEDDING_LENGTH, + Centroid32Comparator, +>; - fn vector_chunks(&self) -> impl Iterator> { - // low quality make better - self.range.vecs().chunks(1_000_000).map(|c| c.to_vec()) - } +pub struct QuantizedEmbeddingSizeCombination< + const CENTROID_SIZE: usize, + const QUANTIZED_SIZE: usize, +>; +pub trait ImplementedQuantizedEmbeddingSizeCombination< + const CENTROID_SIZE: usize, + const QUANTIZED_SIZE: usize, +> +{ + fn compare_quantized( + comparator: &C, + v1: &[u16; QUANTIZED_SIZE], + v2: &[u16; QUANTIZED_SIZE], + ) -> f32; } -impl Comparator for Quantized16Comparator -where - Quantized16Comparator: PartialDistance, +impl ImplementedQuantizedEmbeddingSizeCombination + for QuantizedEmbeddingSizeCombination { - type T = Quantized16Embedding; - - type Borrowable<'a> = &'a Self::T; - - fn lookup(&self, v: VectorId) -> Self::Borrowable<'_> { - self.data.vec(v.0) - } - - fn compare_raw(&self, v1: &Self::T, v2: &Self::T) -> f32 { + fn compare_quantized( + comparator: &C, + v1: &[u16; QUANTIZED_16_EMBEDDING_LENGTH], + v2: &[u16; QUANTIZED_16_EMBEDDING_LENGTH], + ) -> f32 { let mut partial_distances = [0.0_f32; QUANTIZED_16_EMBEDDING_LENGTH]; for ix in 0..QUANTIZED_16_EMBEDDING_LENGTH { let partial_1 = v1[ix]; let partial_2 = v2[ix]; - let partial_distance = self.cc.partial_distance(partial_1, partial_2); + let partial_distance = comparator.partial_distance(partial_1, partial_2); partial_distances[ix] = partial_distance; } @@ -485,59 +439,44 @@ where } } -impl Serializable for Quantized16Comparator { - type Params = (); - - fn serialize>(&self, path: P) -> Result<(), SerializationError> { - let path_buf: PathBuf = path.as_ref().into(); - std::fs::create_dir_all(&path_buf)?; - - let index_path = path_buf.join("index"); - self.cc.serialize(index_path)?; - - let vector_path = path_buf.join("vectors"); - let mut vector_file = VectorFile::create(vector_path, true)?; - vector_file.append_vector_range(self.data.vecs())?; - Ok(()) - } - - fn deserialize>( - path: P, - _params: Self::Params, - ) -> Result { - let path_buf: PathBuf = path.as_ref().into(); - let index_path = path_buf.join("index"); - let cc = Centroid16Comparator::deserialize(index_path, ())?; - - let vector_path = path_buf.join("vectors"); - let vector_file = VectorFile::open(vector_path, true)?; - let range = vector_file.all_vectors()?; +impl ImplementedQuantizedEmbeddingSizeCombination + for QuantizedEmbeddingSizeCombination +{ + fn compare_quantized( + comparator: &C, + v1: &[u16; QUANTIZED_32_EMBEDDING_LENGTH], + v2: &[u16; QUANTIZED_32_EMBEDDING_LENGTH], + ) -> f32 { + let mut partial_distances = [0.0_f32; QUANTIZED_32_EMBEDDING_LENGTH]; + for ix in 0..QUANTIZED_32_EMBEDDING_LENGTH { + let partial_1 = v1[ix]; + let partial_2 = v2[ix]; + let partial_distance = comparator.partial_distance(partial_1, partial_2); + partial_distances[ix] = partial_distance; + } - let data = Arc::new(range); - Ok(Self { cc, data }) + vecmath::sum_48(&partial_distances).sqrt() } } -impl pq::VectorStore for Quantized16Comparator { - type T = ::T; - - fn store(&mut self, i: Box>) -> Vec { - // this is p retty stupid, but then, these comparators should not be storing in the first place - let mut new_contents: Vec = Vec::with_capacity(self.data.len() + i.size_hint().0); - new_contents.extend(self.data.vecs().iter()); - let vid = self.data.len(); - let mut vectors: Vec = Vec::new(); - new_contents.extend(i.enumerate().map(|(i, v)| { - vectors.push(VectorId(vid + i)); - v - })); +impl + Comparator for QuantizedDomainComparator +where + QuantizedEmbeddingSizeCombination: + ImplementedQuantizedEmbeddingSizeCombination, +{ + type T = [u16; QUANTIZED_SIZE]; - let end = new_contents.len(); + type Borrowable<'a> = &'a [u16; QUANTIZED_SIZE]; - let data = LoadedVectorRange::new(new_contents, 0..end); - self.data = Arc::new(data); + fn lookup(&self, v: VectorId) -> Self::Borrowable<'_> { + &self.data[v.0] + } - vectors + fn compare_raw(&self, v1: &Self::T, v2: &Self::T) -> f32 { + QuantizedEmbeddingSizeCombination::::compare_quantized( + &self.cc, v1, v2, + ) } } @@ -554,6 +493,90 @@ pub type HnswQuantizer32 = HnswQuantizer< Centroid32Comparator, >; +pub struct DomainQuantizer< + const SIZE: usize, + const CENTROID_SIZE: usize, + const QUANTIZED_SIZE: usize, + C, +> { + domain: String, + derived_domain: String, + quantizer: Arc< + HnswQuantizer< + SIZE, + CENTROID_SIZE, + QUANTIZED_SIZE, + ArrayCentroidComparator, + >, + >, +} + +impl Clone + for DomainQuantizer +{ + fn clone(&self) -> Self { + Self { + domain: self.domain.clone(), + derived_domain: self.derived_domain.clone(), + quantizer: self.quantizer.clone(), + } + } +} + +#[derive(Serialize, Deserialize)] +pub struct DomainQuantizerMeta { + domain: String, + derived_domain: String, +} + +impl + Quantizer for DomainQuantizer +where + ArrayCentroidComparator: Comparator, +{ + fn quantize(&self, vec: &[f32; SIZE]) -> [u16; QUANTIZED_SIZE] { + self.quantizer.quantize(vec) + } + + fn reconstruct(&self, qvec: &[u16; QUANTIZED_SIZE]) -> [f32; SIZE] { + self.quantizer.reconstruct(qvec) + } +} + +impl + Serializable for DomainQuantizer +{ + type Params = Arc; + + fn serialize>(&self, path: P) -> Result<(), SerializationError> { + let meta = DomainQuantizerMeta { + domain: self.domain.clone(), + derived_domain: self.derived_domain.clone(), + }; + let data = serde_json::to_string(&meta)?; + std::fs::write(path, data)?; + + Ok(()) + } + + fn deserialize>( + path: P, + params: &Self::Params, + ) -> Result { + let DomainQuantizerMeta { + domain, + derived_domain, + } = serde_json::from_reader(BufReader::new(std::fs::File::open(path)?))?; + + let d = params.get_domain(&domain).expect("domain not found"); + let dd: PqDerivedDomainInfo = d + .get_derived_domain_info(&derived_domain) + .expect("derived domain not found"); + + Ok(dd.quantizer.clone()) + } +} + #[cfg(test)] mod tests { use std::sync::{Arc, RwLock}; diff --git a/vectorlink/src/configuration.rs b/vectorlink/src/configuration.rs index dd2509d..d3ca518 100644 --- a/vectorlink/src/configuration.rs +++ b/vectorlink/src/configuration.rs @@ -1,19 +1,23 @@ use std::{fs::OpenOptions, path::PathBuf, sync::Arc}; use itertools::Either; -use parallel_hnsw::{pq::QuantizedHnsw, AbstractVector, Hnsw, Serializable, VectorId}; +use parallel_hnsw::{ + pq::{HnswQuantizer, QuantizedHnsw}, + AbstractVector, Hnsw, Serializable, VectorId, +}; use rayon::iter::IndexedParallelIterator; use serde::{Deserialize, Serialize}; use crate::{ comparator::{ - Centroid16Comparator, Centroid32Comparator, DiskOpenAIComparator, OpenAIComparator, - Quantized16Comparator, Quantized32Comparator, + Centroid16Comparator, Centroid32Comparator, DiskOpenAIComparator, DomainQuantizer, + OpenAIComparator, Quantized16Comparator, Quantized32Comparator, }, openai::Model, vecmath::{ - Embedding, CENTROID_16_LENGTH, CENTROID_32_LENGTH, EMBEDDING_LENGTH, - QUANTIZED_16_EMBEDDING_LENGTH, QUANTIZED_32_EMBEDDING_LENGTH, + Embedding, EuclideanDistance16, EuclideanDistance32, CENTROID_16_LENGTH, + CENTROID_32_LENGTH, EMBEDDING_LENGTH, QUANTIZED_16_EMBEDDING_LENGTH, + QUANTIZED_32_EMBEDDING_LENGTH, }, vectors::VectorStore, }; @@ -42,9 +46,14 @@ pub enum HnswConfiguration { EMBEDDING_LENGTH, CENTROID_32_LENGTH, QUANTIZED_32_EMBEDDING_LENGTH, - Centroid32Comparator, Quantized32Comparator, DiskOpenAIComparator, + DomainQuantizer< + EMBEDDING_LENGTH, + CENTROID_32_LENGTH, + QUANTIZED_32_EMBEDDING_LENGTH, + EuclideanDistance32, + >, >, ), SmallQuantizedOpenAi( @@ -53,9 +62,14 @@ pub enum HnswConfiguration { EMBEDDING_LENGTH, CENTROID_16_LENGTH, QUANTIZED_16_EMBEDDING_LENGTH, - Centroid16Comparator, Quantized16Comparator, DiskOpenAIComparator, + DomainQuantizer< + EMBEDDING_LENGTH, + CENTROID_16_LENGTH, + QUANTIZED_16_EMBEDDING_LENGTH, + EuclideanDistance16, + >, >, ), UnquantizedOpenAi(Model, OpenAIHnsw), @@ -173,12 +187,12 @@ impl Serializable for HnswConfiguration { path: P, ) -> Result<(), parallel_hnsw::SerializationError> { match self { - HnswConfiguration::QuantizedOpenAi(_, hnsw) => { - hnsw.serialize(&path)?; - } - HnswConfiguration::UnquantizedOpenAi(_, qhnsw) => { + HnswConfiguration::QuantizedOpenAi(_, qhnsw) => { qhnsw.serialize(&path)?; } + HnswConfiguration::UnquantizedOpenAi(_, hnsw) => { + hnsw.serialize(&path)?; + } HnswConfiguration::SmallQuantizedOpenAi(_, qhnsw) => { qhnsw.serialize(&path)?; } @@ -196,7 +210,7 @@ impl Serializable for HnswConfiguration { fn deserialize>( path: P, - params: Self::Params, + params: &Self::Params, ) -> Result { let state_path: PathBuf = path.as_ref().join("state.json"); let mut state_file = OpenOptions::new() @@ -209,14 +223,14 @@ impl Serializable for HnswConfiguration { Ok(match state.typ { HnswConfigurationType::QuantizedOpenAi => HnswConfiguration::QuantizedOpenAi( state.model, - QuantizedHnsw::deserialize(path, params)?, + QuantizedHnsw::deserialize(path, &(params.clone(), params.clone()))?, ), HnswConfigurationType::UnquantizedOpenAi => { HnswConfiguration::UnquantizedOpenAi(state.model, Hnsw::deserialize(path, params)?) } HnswConfigurationType::SmallQuantizedOpenAi => HnswConfiguration::SmallQuantizedOpenAi( state.model, - QuantizedHnsw::deserialize(path, params)?, + QuantizedHnsw::deserialize(path, &(params.clone(), params.clone()))?, ), }) } diff --git a/vectorlink/src/domain.rs b/vectorlink/src/domain.rs index e5df6cf..0616c51 100644 --- a/vectorlink/src/domain.rs +++ b/vectorlink/src/domain.rs @@ -14,7 +14,7 @@ use linfa::{traits::Fit, DatasetBase}; use linfa_clustering::KMeans; use ndarray::{Array, Array2}; use parallel_hnsw::{ - pq::{CentroidComparatorConstructor, HnswQuantizer, Quantizer}, + pq::{HnswQuantizer, Quantizer}, Comparator, Hnsw, Serializable, VectorId, }; use rand::{distributions::Uniform, rngs::StdRng, thread_rng, Rng, SeedableRng}; @@ -22,11 +22,15 @@ use serde::{Deserialize, Serialize}; use urlencoding::encode; use crate::{ - comparator::{Centroid16Comparator, Centroid32Comparator, HnswQuantizer16, HnswQuantizer32}, + comparator::{ + ArrayCentroidComparator, DistanceCalculator, DomainQuantizer, HnswQuantizer16, + HnswQuantizer32, + }, store::{ImmutableVectorFile, LoadedVectorRange, SequentialVectorLoader, VectorFile}, vecmath::{ - Embedding, CENTROID_16_LENGTH, CENTROID_32_LENGTH, EMBEDDING_LENGTH, - QUANTIZED_16_EMBEDDING_LENGTH, QUANTIZED_32_EMBEDDING_LENGTH, + Embedding, EuclideanDistance16, EuclideanDistance32, CENTROID_16_LENGTH, + CENTROID_32_LENGTH, EMBEDDING_LENGTH, QUANTIZED_16_EMBEDDING_LENGTH, + QUANTIZED_32_EMBEDDING_LENGTH, }, }; @@ -46,17 +50,20 @@ pub trait Deriver: Any { type From: Copy; fn concatenate_derived(&self, loader: SequentialVectorLoader) -> io::Result<()>; + fn configuration(&self) -> DerivedDomainConfiguration; + fn get_derived_domain_info(&self) -> Box; fn concatenate_file(&self, file: &VectorFile) -> io::Result<()> { self.concatenate_derived(file.vector_chunks(self.chunk_size())?)?; Ok(()) } - fn configuration(&self) -> DerivedDomainConfiguration; fn chunk_size(&self) -> usize { 1_000 } } +pub trait DerivedDomainInfo: Any {} + pub trait DerivedDomainInitializer { fn initialize( &self, @@ -85,15 +92,17 @@ pub struct PqDerivedDomain< C, > { file: RwLock>, - quantizer: HnswQuantizer, + quantizer: DomainQuantizer, } impl< const SIZE: usize, const CENTROID_SIZE: usize, const QUANTIZED_SIZE: usize, - C: 'static + Comparator, + C: 'static + DistanceCalculator, > PqDerivedDomain +where + ArrayCentroidComparator: 'static + Comparator, { fn as_arc( self, @@ -117,8 +126,10 @@ impl< const SIZE: usize, const CENTROID_SIZE: usize, const QUANTIZED_SIZE: usize, - C: 'static + Comparator, + C: 'static + DistanceCalculator, > Deriver for PqDerivedDomain +where + ArrayCentroidComparator: 'static + Comparator, { type From = [f32; SIZE]; @@ -148,19 +159,66 @@ impl< _ => panic!("unserializable pq derived domain"), } } + + fn get_derived_domain_info(&self) -> Box { + let info = PqDerivedDomainInfo { + file: self.file.read().unwrap().as_immutable(), + quantizer: self.quantizer.clone(), + }; + Box::new(info) + } +} + +pub struct PqDerivedDomainInfo< + const SIZE: usize, + const CENTROID_SIZE: usize, + const QUANTIZED_SIZE: usize, + C, +> { + pub file: ImmutableVectorFile<[u16; QUANTIZED_SIZE]>, + pub quantizer: DomainQuantizer, } +impl + DerivedDomainInfo for PqDerivedDomainInfo +{ +} + +pub type PqDerivedDomainInfo16 = PqDerivedDomainInfo< + EMBEDDING_LENGTH, + CENTROID_16_LENGTH, + QUANTIZED_16_EMBEDDING_LENGTH, + EuclideanDistance16, +>; +pub type PqDerivedDomainInfo32 = PqDerivedDomainInfo< + EMBEDDING_LENGTH, + CENTROID_32_LENGTH, + QUANTIZED_32_EMBEDDING_LENGTH, + EuclideanDistance32, +>; pub type PqDerivedDomain16 = PqDerivedDomain< EMBEDDING_LENGTH, CENTROID_16_LENGTH, QUANTIZED_16_EMBEDDING_LENGTH, - Centroid16Comparator, + EuclideanDistance16, >; pub type PqDerivedDomain32 = PqDerivedDomain< EMBEDDING_LENGTH, CENTROID_32_LENGTH, QUANTIZED_32_EMBEDDING_LENGTH, - Centroid32Comparator, + EuclideanDistance32, +>; +pub type PqDerivedDomainInitializer16 = PqDerivedDomainInitializer< + EMBEDDING_LENGTH, + CENTROID_16_LENGTH, + QUANTIZED_16_EMBEDDING_LENGTH, + EuclideanDistance16, +>; +pub type PqDerivedDomainInitializer32 = PqDerivedDomainInitializer< + EMBEDDING_LENGTH, + CENTROID_32_LENGTH, + QUANTIZED_32_EMBEDDING_LENGTH, + EuclideanDistance32, >; #[derive(Serialize, Deserialize, ValueEnum, Debug, Clone, Copy)] @@ -179,19 +237,25 @@ impl DerivedDomainConfiguration { match self { Self::SmallPq => { let file = RwLock::new(VectorFile::open(&vecs_path, true)?); - let quantizer: HnswQuantizer16 = HnswQuantizer::deserialize(&quantizer_path, ()) - .expect("hnsw deserialization failed (small)"); - let domain: PqDerivedDomain16 = PqDerivedDomain { file, quantizer }; + let quantizer: HnswQuantizer16 = HnswQuantizer::deserialize(&quantizer_path, &()) + .expect("hnsw deserialization failed (small)"); + let domain: PqDerivedDomain16 = PqDerivedDomain { + file, + quantizer: quantizer, + }; Ok(domain.as_arc::().unwrap()) } Self::LargePq => { let file = RwLock::new(VectorFile::open(&vecs_path, true)?); - let quantizer: HnswQuantizer32 = HnswQuantizer::deserialize(&quantizer_path, ()) + let quantizer: HnswQuantizer32 = HnswQuantizer::deserialize(&quantizer_path, &()) .expect("hnsw deserialization failed (large)"); - let domain: PqDerivedDomain32 = PqDerivedDomain { file, quantizer }; + let domain: PqDerivedDomain32 = PqDerivedDomain { + file, + quantizer: quantizer, + }; Ok(domain.as_arc::().unwrap()) } @@ -208,7 +272,7 @@ impl DerivedDomainConfiguration { EMBEDDING_LENGTH, CENTROID_16_LENGTH, QUANTIZED_16_EMBEDDING_LENGTH, - Centroid16Comparator, + EuclideanDistance16, >::default(); let boxed: Box + 'static + Send + Sync> = @@ -221,7 +285,7 @@ impl DerivedDomainConfiguration { EMBEDDING_LENGTH, CENTROID_32_LENGTH, QUANTIZED_32_EMBEDDING_LENGTH, - Centroid32Comparator, + EuclideanDistance32, >::default(); let boxed: Box + 'static + Send + Sync> = @@ -245,11 +309,7 @@ impl< const SIZE: usize, const CENTROID_SIZE: usize, const QUANTIZED_SIZE: usize, - C: 'static - + Comparator - + CentroidComparatorConstructor - + Serializable - + Send, + C: 'static + DistanceCalculator, > PqDerivedDomainInitializer { } @@ -266,13 +326,11 @@ impl< const SIZE: usize, const CENTROID_SIZE: usize, const QUANTIZED_SIZE: usize, - C: 'static - + Comparator - + CentroidComparatorConstructor - + Serializable - + Send, + C: 'static + DistanceCalculator + Default + Send, > DerivedDomainInitializer<[f32; SIZE]> for PqDerivedDomainInitializer +where + ArrayCentroidComparator: 'static + Comparator, { fn initialize( &self, @@ -331,11 +389,11 @@ impl< eprintln!("Number of centroids: {}", centroids.len()); let vector_ids = (0..centroids.len()).map(VectorId).collect(); - let centroid_comparator = C::new(centroids); + let centroid_comparator = ArrayCentroidComparator::new(centroids); let centroid_m = 24; let centroid_m0 = 48; let centroid_order = 12; - let mut centroid_hnsw: Hnsw = Hnsw::generate( + let mut centroid_hnsw: Hnsw> = Hnsw::generate( centroid_comparator, vector_ids, centroid_m, @@ -345,8 +403,12 @@ impl< //centroid_hnsw.improve_index(); centroid_hnsw.improve_neighbors(0.01, 1.0); - let centroid_quantizer: HnswQuantizer = - HnswQuantizer::new(centroid_hnsw); + let centroid_quantizer: HnswQuantizer< + SIZE, + CENTROID_SIZE, + QUANTIZED_SIZE, + ArrayCentroidComparator, + > = HnswQuantizer::new(centroid_hnsw); let quantizer_path = path.join("quantizer"); centroid_quantizer.serialize(quantizer_path)?; @@ -357,7 +419,7 @@ impl< let deriver = PqDerivedDomain { file: RwLock::new(quantized_file), - quantizer: centroid_quantizer, + quantizer: Arc::new(centroid_quantizer), }; Ok(Arc::new(deriver)) } @@ -502,13 +564,14 @@ impl Domain { Ok(()) } - pub fn get_derived<'a, T2: Deriver + Send + Sync>( - &'a self, - name: &str, - ) -> Option + 'a> { - let derived_domains = self.derived_domains.read().unwrap(); - let derived = derived_domains.get(name)?; + pub fn get_derived_domain_info(&self, name: &str) -> Option { + let domains = self.derived_domains.read().unwrap(); + let deriver = domains.get(name)?; + let info = deriver.get_derived_domain_info() as Box; + let downcast_info: Box = info + .downcast() + .expect("derived domain info not of expected type"); - Some(Arc::downcast::(derived.clone()).expect("derived domain was not of expected type")) + Some(*downcast_info) } } diff --git a/vectorlink/src/main.rs b/vectorlink/src/main.rs index 2f0c334..a759740 100644 --- a/vectorlink/src/main.rs +++ b/vectorlink/src/main.rs @@ -27,7 +27,6 @@ use domain::DerivedDomainConfiguration; //use hnsw::Hnsw; use openai::Model; use parallel_hnsw::pq::Quantizer; -use parallel_hnsw::pq::VectorSelector; use parallel_hnsw::AbstractVector; use parallel_hnsw::Comparator; use parallel_hnsw::Serializable; diff --git a/vectorlink/src/server.rs b/vectorlink/src/server.rs index 181cb71..7dac9d0 100644 --- a/vectorlink/src/server.rs +++ b/vectorlink/src/server.rs @@ -460,7 +460,7 @@ impl Service { let index_path = index_serialization_path(path, index_id); Ok(Arc::new(OpenAIHnsw::deserialize( index_path, - self.vector_store.clone(), + &self.vector_store, )?)) } } diff --git a/vectorlink/src/vecmath.rs b/vectorlink/src/vecmath.rs index 4e2db8f..518458f 100644 --- a/vectorlink/src/vecmath.rs +++ b/vectorlink/src/vecmath.rs @@ -128,7 +128,7 @@ pub fn cosine_partial_distance_32(v1: &Centroid32, v2: &Centroid32) -> f32 { simd::cosine_partial_distance_32_simd(v1, v2) } -#[derive(Default)] +#[derive(Default, Clone)] pub struct EuclideanDistance32; impl DistanceCalculator for EuclideanDistance32 { type T = Centroid32; @@ -148,7 +148,7 @@ impl DistanceCalculator for EuclideanDistance32 { } } -#[derive(Default)] +#[derive(Default, Clone)] pub struct EuclideanDistance16; impl DistanceCalculator for EuclideanDistance16 { type T = Centroid16;