From dd0a16ac15266127b78f9c69aeba68cfd44b8c0b Mon Sep 17 00:00:00 2001 From: hammadb Date: Sat, 9 Mar 2024 14:12:17 -0800 Subject: [PATCH] [ENH] Add uint blockfile key/val --- .../blockstore/arrow_blockfile/block/delta.rs | 22 ++++++++- .../arrow_blockfile/block/iterator.rs | 6 ++- .../blockstore/arrow_blockfile/block/types.rs | 46 ++++++++++++++++++- .../blockstore/arrow_blockfile/blockfile.rs | 39 +++++++++++++++- rust/worker/src/blockstore/types.rs | 29 ++++++++---- rust/worker/src/index/fulltext/types.rs | 14 +++--- 6 files changed, 135 insertions(+), 21 deletions(-) diff --git a/rust/worker/src/blockstore/arrow_blockfile/block/delta.rs b/rust/worker/src/blockstore/arrow_blockfile/block/delta.rs index 52c0ba48daf3..99b7260c88d5 100644 --- a/rust/worker/src/blockstore/arrow_blockfile/block/delta.rs +++ b/rust/worker/src/blockstore/arrow_blockfile/block/delta.rs @@ -202,6 +202,7 @@ impl BlockDeltaInner { Value::Int32ArrayValue(arr) => acc + arr.len(), Value::StringValue(s) => acc + s.len(), Value::RoaringBitmapValue(bitmap) => acc + bitmap.serialized_size(), + Value::UintValue(_) => acc + 1, _ => unimplemented!("Value type not implemented"), }) } @@ -238,6 +239,7 @@ impl BlockDeltaInner { ValueType::Int32Array | ValueType::String | ValueType::RoaringBitmap => { bit_util::round_upto_multiple_of_64((item_count + 1) * 4) } + ValueType::Uint => 0, _ => unimplemented!("Value type not implemented"), } } @@ -245,7 +247,7 @@ impl BlockDeltaInner { fn offset_size_for_key_type(&self, item_count: usize, key_type: KeyType) -> usize { match key_type { KeyType::String => bit_util::round_upto_multiple_of_64((item_count + 1) * 4), - KeyType::Float => 0, + KeyType::Float | KeyType::Uint => 0, _ => unimplemented!("Key type not implemented"), } } @@ -429,7 +431,23 @@ mod test { let size = delta.get_size(); let block_data = BlockData::try_from(&delta).unwrap(); assert_eq!(size, block_data.get_size()); + } - let (split_key, delta) = delta.split(&block_provider); + #[test] + fn test_sizing_uint_key_val() { + let block_provider = ArrowBlockProvider::new(); + let block = block_provider.create_block(KeyType::Uint, ValueType::Uint); + let delta = BlockDelta::from(block.clone()); + + let n = 2000; + for i in 0..n { + let key = BlockfileKey::new("prefix".to_string(), Key::Uint(i as u32)); + let value = Value::UintValue(i as u32); + delta.add(key, value); + } + + let size = delta.get_size(); + let block_data = BlockData::try_from(&delta).unwrap(); + assert_eq!(size, block_data.get_size()); } } diff --git a/rust/worker/src/blockstore/arrow_blockfile/block/iterator.rs b/rust/worker/src/blockstore/arrow_blockfile/block/iterator.rs index 902066bedf41..5a771ea4ed39 100644 --- a/rust/worker/src/blockstore/arrow_blockfile/block/iterator.rs +++ b/rust/worker/src/blockstore/arrow_blockfile/block/iterator.rs @@ -1,6 +1,6 @@ use super::types::Block; use crate::blockstore::types::{BlockfileKey, Key, KeyType, Value, ValueType}; -use arrow::array::{Array, BooleanArray, Int32Array, ListArray, StringArray}; +use arrow::array::{Array, BooleanArray, Int32Array, ListArray, StringArray, UInt32Array}; /// An iterator over the contents of a block. /// This is a simple wrapper around the Arrow array data that is stored in the block. @@ -77,6 +77,10 @@ impl Iterator for BlockIterator { Some(key) => Key::Bool(key.value(self.index)), None => return None, }, + KeyType::Uint => match key.as_any().downcast_ref::() { + Some(key) => Key::Uint(key.value(self.index) as u32), + None => return None, + }, }; let value = match self.value_type { diff --git a/rust/worker/src/blockstore/arrow_blockfile/block/types.rs b/rust/worker/src/blockstore/arrow_blockfile/block/types.rs index 9c9d4f5ee1e3..7c39383d394d 100644 --- a/rust/worker/src/blockstore/arrow_blockfile/block/types.rs +++ b/rust/worker/src/blockstore/arrow_blockfile/block/types.rs @@ -2,7 +2,7 @@ use crate::blockstore::types::{BlockfileKey, Key, KeyType, Value, ValueType}; use crate::errors::{ChromaError, ErrorCodes}; use arrow::array::{ BinaryArray, BinaryBuilder, BooleanArray, BooleanBuilder, Float32Array, Float32Builder, - GenericByteBuilder, + GenericByteBuilder, UInt32Array, UInt32Builder, }; use arrow::{ array::{Array, Int32Array, Int32Builder, ListArray, ListBuilder, StringArray, StringBuilder}, @@ -125,6 +125,11 @@ impl Block { .unwrap() .value(i) } + Key::Uint(inner_key) => { + *inner_key + == key.as_any().downcast_ref::().unwrap().value(i) + as u32 + } }; if key_matches { match self.get_value_type() { @@ -166,6 +171,15 @@ impl Block { Err(_) => return None, } } + ValueType::Uint => { + return Some(Value::UintValue( + value + .as_any() + .downcast_ref::() + .unwrap() + .value(i), + )) + } // TODO: Add support for other types _ => unimplemented!(), } @@ -285,12 +299,14 @@ enum KeyBuilder { StringBuilder(StringBuilder), FloatBuilder(Float32Builder), BoolBuilder(BooleanBuilder), + UintBuilder(UInt32Builder), } enum ValueBuilder { Int32ArrayValueBuilder(ListBuilder), StringValueBuilder(StringBuilder), RoaringBitmapBuilder(BinaryBuilder), + UintValueBuilder(UInt32Builder), } /// BlockDataBuilder is used to build a block. It is used to add data to a block and then build the BlockData once all data has been added. @@ -367,6 +383,9 @@ impl BlockDataBuilder { KeyType::Bool => { KeyBuilder::BoolBuilder(BooleanBuilder::with_capacity(options.item_count)) } + KeyType::Uint => { + KeyBuilder::UintBuilder(UInt32Builder::with_capacity(options.item_count)) + } }; let value_builder = match value_type { ValueType::Int32Array => { @@ -379,6 +398,9 @@ impl BlockDataBuilder { options.item_count, options.total_value_capacity, )), + ValueType::Uint => { + ValueBuilder::UintValueBuilder(UInt32Builder::with_capacity(options.item_count)) + } ValueType::RoaringBitmap => ValueBuilder::RoaringBitmapBuilder( BinaryBuilder::with_capacity(options.item_count, options.total_value_capacity), ), @@ -428,6 +450,12 @@ impl BlockDataBuilder { } _ => unreachable!("Invalid key type for block"), }, + KeyBuilder::UintBuilder(ref mut builder) => match key.key { + Key::Uint(key) => { + builder.append_value(key); + } + _ => unreachable!("Invalid key type for block"), + }, } match self.value_builder { @@ -443,6 +471,12 @@ impl BlockDataBuilder { } _ => unreachable!("Invalid value type for block"), }, + ValueBuilder::UintValueBuilder(ref mut builder) => match value { + Value::UintValue(uint) => { + builder.append_value(uint); + } + _ => unreachable!("Invalid value type for block"), + }, ValueBuilder::RoaringBitmapBuilder(ref mut builder) => match value { Value::RoaringBitmapValue(bitmap) => { let mut bytes = Vec::with_capacity(bitmap.serialized_size()); @@ -481,6 +515,11 @@ impl BlockDataBuilder { let arr = builder.finish(); (&arr as &dyn Array).slice(0, arr.len()) } + KeyBuilder::UintBuilder(ref mut builder) => { + key_field = Field::new("key", DataType::UInt32, true); + let arr = builder.finish(); + (&arr as &dyn Array).slice(0, arr.len()) + } }; let value_field; @@ -499,6 +538,11 @@ impl BlockDataBuilder { let arr = builder.finish(); (&arr as &dyn Array).slice(0, arr.len()) } + ValueBuilder::UintValueBuilder(ref mut builder) => { + value_field = Field::new("value", DataType::UInt32, true); + let arr = builder.finish(); + (&arr as &dyn Array).slice(0, arr.len()) + } ValueBuilder::RoaringBitmapBuilder(ref mut builder) => { value_field = Field::new("value", DataType::Binary, true); let arr = builder.finish(); diff --git a/rust/worker/src/blockstore/arrow_blockfile/blockfile.rs b/rust/worker/src/blockstore/arrow_blockfile/blockfile.rs index 006ef83fcb8e..c1f65e80edd3 100644 --- a/rust/worker/src/blockstore/arrow_blockfile/blockfile.rs +++ b/rust/worker/src/blockstore/arrow_blockfile/blockfile.rs @@ -151,6 +151,11 @@ impl Blockfile for ArrowBlockfile { return Err(Box::new(BlockfileError::InvalidKeyType)); } } + Key::Uint(_) => { + if self.key_type != KeyType::Uint { + return Err(Box::new(BlockfileError::InvalidKeyType)); + } + } } // Validate value type @@ -165,8 +170,13 @@ impl Blockfile for ArrowBlockfile { return Err(Box::new(BlockfileError::InvalidValueType)); } } - Value::Int32Value(_) => { - if self.value_type != ValueType::Int32 { + Value::IntValue(_) => { + if self.value_type != ValueType::Int { + return Err(Box::new(BlockfileError::InvalidValueType)); + } + } + Value::UintValue(_) => { + if self.value_type != ValueType::Uint { return Err(Box::new(BlockfileError::InvalidValueType)); } } @@ -554,4 +564,29 @@ mod tests { } } } + + #[test] + fn test_uint_key_val() { + let block_provider = ArrowBlockProvider::new(); + let mut blockfile = ArrowBlockfile::new(KeyType::Uint, ValueType::Uint, block_provider); + + blockfile.begin_transaction().unwrap(); + let n = 2000; + for i in 0..n { + let key = BlockfileKey::new("key".to_string(), Key::Uint(i as u32)); + blockfile.set(key, Value::UintValue(i as u32)).unwrap(); + } + blockfile.commit_transaction().unwrap(); + + for i in 0..n { + let key = BlockfileKey::new("key".to_string(), Key::Uint(i as u32)); + let res = blockfile.get(key).unwrap(); + match res { + Value::UintValue(val) => { + assert_eq!(val, i as u32); + } + _ => panic!("Unexpected value type"), + } + } + } } diff --git a/rust/worker/src/blockstore/types.rs b/rust/worker/src/blockstore/types.rs index fd1f906e85d7..8c7d9772807f 100644 --- a/rust/worker/src/blockstore/types.rs +++ b/rust/worker/src/blockstore/types.rs @@ -49,6 +49,7 @@ impl Key { Key::String(s) => s.len(), Key::Float(_) => 4, Key::Bool(_) => 1, + Key::Uint(_) => 4, } } } @@ -69,6 +70,7 @@ impl From<&BlockfileKey> for KeyType { Key::String(_) => KeyType::String, Key::Float(_) => KeyType::Float, Key::Bool(_) => KeyType::Bool, + Key::Uint(_) => KeyType::Uint, } } } @@ -78,6 +80,7 @@ pub(crate) enum Key { String(String), Float(f32), Bool(bool), + Uint(u32), } #[derive(Debug, Clone, Copy, PartialEq)] @@ -85,6 +88,7 @@ pub(crate) enum KeyType { String, Float, Bool, + Uint, } impl Display for Key { @@ -93,6 +97,7 @@ impl Display for Key { Key::String(s) => write!(f, "{}", s), Key::Float(fl) => write!(f, "{}", fl), Key::Bool(b) => write!(f, "{}", b), + Key::Uint(u) => write!(f, "{}", u), } } } @@ -146,15 +151,19 @@ impl Ord for BlockfileKey { match self.key { Key::String(ref s1) => match &other.key { Key::String(s2) => s1.cmp(s2), - _ => panic!("Cannot compare string to float or bool"), + _ => panic!("Cannot compare string to float, bool, or uint"), }, Key::Float(f1) => match &other.key { Key::Float(f2) => f1.partial_cmp(f2).unwrap(), - _ => panic!("Cannot compare float to string or bool"), + _ => panic!("Cannot compare float to string, bool, or uint"), }, Key::Bool(b1) => match &other.key { Key::Bool(b2) => b1.cmp(b2), - _ => panic!("Cannot compare bool to string or float"), + _ => panic!("Cannot compare bool to string, float, or uint"), + }, + Key::Uint(u1) => match &other.key { + Key::Uint(u2) => u1.cmp(u2), + _ => panic!("Cannot compare uint to string, float, or bool"), }, } } else { @@ -170,7 +179,8 @@ pub(crate) enum Value { Int32ArrayValue(Int32Array), PositionalPostingListValue(PositionalPostingList), StringValue(String), - Int32Value(i32), + IntValue(i32), + UintValue(u32), RoaringBitmapValue(RoaringBitmap), } @@ -199,7 +209,8 @@ impl Clone for Value { } Value::StringValue(s) => Value::StringValue(s.clone()), Value::RoaringBitmapValue(bitmap) => Value::RoaringBitmapValue(bitmap.clone()), - Value::Int32Value(i) => Value::Int32Value(*i), + Value::IntValue(i) => Value::IntValue(*i), + Value::UintValue(u) => Value::UintValue(*u), } } } @@ -213,7 +224,7 @@ impl Value { } Value::StringValue(s) => s.len(), Value::RoaringBitmapValue(bitmap) => bitmap.serialized_size(), - Value::Int32Value(_) => 4, + Value::IntValue(_) | Value::UintValue(_) => 4, } } } @@ -225,7 +236,8 @@ impl From<&Value> for ValueType { Value::PositionalPostingListValue(_) => ValueType::PositionalPostingList, Value::RoaringBitmapValue(_) => ValueType::RoaringBitmap, Value::StringValue(_) => ValueType::String, - Value::Int32Value(_) => ValueType::Int32, + Value::IntValue(_) => ValueType::Int, + Value::UintValue(_) => ValueType::Uint, } } } @@ -236,7 +248,8 @@ pub(crate) enum ValueType { PositionalPostingList, RoaringBitmap, String, - Int32, + Int, + Uint, } pub(crate) trait Blockfile: BlockfileClone { diff --git a/rust/worker/src/index/fulltext/types.rs b/rust/worker/src/index/fulltext/types.rs index 0319a4b4c65f..08d8d7db6d0f 100644 --- a/rust/worker/src/index/fulltext/types.rs +++ b/rust/worker/src/index/fulltext/types.rs @@ -86,7 +86,7 @@ impl FullTextIndex for BlockfileFullTextIndex { for (key, value) in self.uncommitted_frequencies.drain() { let blockfilekey = BlockfileKey::new("".to_string(), Key::String(key.to_string())); self.frequencies_blockfile - .set(blockfilekey, Value::Int32Value(value)); + .set(blockfilekey, Value::IntValue(value)); } self.posting_lists_blockfile.commit_transaction()?; self.frequencies_blockfile.commit_transaction()?; @@ -135,7 +135,7 @@ impl FullTextIndex for BlockfileFullTextIndex { BlockfileKey::new("".to_string(), Key::String(token.text.to_string())); let value = self.frequencies_blockfile.get(blockfilekey); match value { - Ok(Value::Int32Value(frequency)) => { + Ok(Value::IntValue(frequency)) => { token_frequencies.push((token.text.to_string(), frequency)); } Ok(_) => { @@ -230,7 +230,7 @@ mod tests { .create("pl", KeyType::String, ValueType::PositionalPostingList) .unwrap(); let freq_blockfile = provider - .create("freq", KeyType::String, ValueType::Int32) + .create("freq", KeyType::String, ValueType::Int) .unwrap(); let tokenizer = Box::new(TantivyChromaTokenizer::new(Box::new( NgramTokenizer::new(1, 1, false).unwrap(), @@ -245,7 +245,7 @@ mod tests { .create("pl", KeyType::String, ValueType::PositionalPostingList) .unwrap(); let freq_blockfile = provider - .create("freq", KeyType::String, ValueType::Int32) + .create("freq", KeyType::String, ValueType::Int) .unwrap(); let tokenizer = Box::new(TantivyChromaTokenizer::new(Box::new( NgramTokenizer::new(1, 1, false).unwrap(), @@ -266,7 +266,7 @@ mod tests { .create("pl", KeyType::String, ValueType::PositionalPostingList) .unwrap(); let freq_blockfile = provider - .create("freq", KeyType::String, ValueType::Int32) + .create("freq", KeyType::String, ValueType::Int) .unwrap(); let tokenizer = Box::new(TantivyChromaTokenizer::new(Box::new( NgramTokenizer::new(1, 1, false).unwrap(), @@ -287,7 +287,7 @@ mod tests { .create("pl", KeyType::String, ValueType::PositionalPostingList) .unwrap(); let freq_blockfile = provider - .create("freq", KeyType::String, ValueType::Int32) + .create("freq", KeyType::String, ValueType::Int) .unwrap(); let tokenizer = Box::new(TantivyChromaTokenizer::new(Box::new( NgramTokenizer::new(1, 1, false).unwrap(), @@ -318,7 +318,7 @@ mod tests { .create("pl", KeyType::String, ValueType::PositionalPostingList) .unwrap(); let freq_blockfile = provider - .create("freq", KeyType::String, ValueType::Int32) + .create("freq", KeyType::String, ValueType::Int) .unwrap(); let tokenizer = Box::new(TantivyChromaTokenizer::new(Box::new( NgramTokenizer::new(1, 1, false).unwrap(),