Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] Add record segment #1855

Closed
wants to merge 17 commits into from
150 changes: 140 additions & 10 deletions rust/worker/src/blockstore/arrow_blockfile/block/delta.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,19 @@
use super::{Block, BlockBuilderOptions, BlockData, BlockDataBuilder};
use crate::blockstore::{
arrow_blockfile::{blockfile::MAX_BLOCK_SIZE, provider::ArrowBlockProvider},
types::{BlockfileKey, KeyType, Value, ValueType},
BlockfileError,
use crate::{
blockstore::{
arrow_blockfile::{blockfile::MAX_BLOCK_SIZE, provider::ArrowBlockProvider},
types::{BlockfileKey, KeyType, Value, ValueType},
BlockfileError,
},
chroma_proto,
};
use arrow::util::bit_util;
use parking_lot::RwLock;
use std::{collections::BTreeMap, sync::Arc};
use prost::Message;
use std::{
collections::{BTreeMap, HashMap},
sync::Arc,
};

/// A block delta tracks a source block and represents the new state of a block. Blocks are
/// immutable, so when a write is made to a block, a new block is created with the new state.
Expand Down Expand Up @@ -117,6 +124,21 @@ impl BlockDelta {
inner.get_value_count()
}

/// Total byte length of all documents stored in this delta's pending data.
/// Delegates to the inner state under a read lock.
fn get_document_size(&self) -> usize {
    self.inner.read().get_document_size()
}

/// Total serialized byte size of all metadata stored in this delta's
/// pending data. Delegates to the inner state under a read lock.
fn get_metadata_size(&self) -> usize {
    self.inner.read().get_metadata_size()
}

/// Total byte length of all user ids stored in this delta's pending data.
/// Delegates to the inner state under a read lock.
fn get_user_id_size(&self) -> usize {
    self.inner.read().get_user_id_size()
}

fn len(&self) -> usize {
let inner = self.inner.read();
inner.new_data.len()
Expand All @@ -125,6 +147,10 @@ impl BlockDelta {

/// Mutable state behind a `BlockDelta`: the pending writes that will form
/// the new block, keyed and sorted by `BlockfileKey` (BTreeMap).
struct BlockDeltaInner {
    // The pending key/value writes for this delta.
    new_data: BTreeMap<BlockfileKey, Value>,
    // A cache of the metadata json size for each blockfile key. This is used to avoid
    // reserializing the metadata json for each blockfile key. It may be heavy on memory
    // but we can easily optimize this later.
    // metadata_json_cache: HashMap<BlockfileKey, usize>,
}

impl BlockDeltaInner {
Expand Down Expand Up @@ -200,12 +226,52 @@ impl BlockDeltaInner {
.fold(0, |acc, (_, value)| acc + value.get_size())
}

/// Sums the byte length of every document stored in the pending data.
/// Values that are not embedding records, or records without a document,
/// contribute nothing to the total.
fn get_document_size(&self) -> usize {
    self.new_data.iter().fold(0, |acc, (_, value)| match value {
        Value::EmbeddingRecordValue(embedding_record) => {
            let len = embedding_record
                .get_document()
                .map_or(0, |document| document.len());
            acc + len
        }
        // BUG FIX: this arm previously returned `0`, which discarded the
        // accumulator and reset the running total whenever a
        // non-embedding value appeared in the map. Propagate `acc` instead.
        _ => acc,
    })
}

/// Sums the protobuf-encoded byte size of every metadata map stored in the
/// pending data. Records without metadata, and non-embedding values,
/// contribute nothing to the total.
fn get_metadata_size(&self) -> usize {
    self.new_data.iter().fold(0, |acc, (_, value)| match value {
        Value::EmbeddingRecordValue(embedding_record) => {
            match &embedding_record.record.metadata {
                Some(metadata) => {
                    // TODO: cache this — re-serializing the metadata on
                    // every size query is wasteful (see the commented-out
                    // metadata_json_cache field on BlockDeltaInner).
                    let as_proto: chroma_proto::UpdateMetadata = metadata.clone().into();
                    acc + as_proto.encoded_len()
                }
                // BUG FIX: was `0`, which discarded the accumulator for
                // every record lacking metadata. Propagate `acc` instead.
                None => acc,
            }
        }
        // BUG FIX: was `0`, which reset the running total on any
        // non-embedding value. Propagate `acc` instead.
        _ => acc,
    })
}

/// Sums the byte length of every record's user id in the pending data.
/// Non-embedding values contribute nothing to the total.
fn get_user_id_size(&self) -> usize {
    self.new_data.iter().fold(0, |acc, (_, value)| match value {
        Value::EmbeddingRecordValue(embedding_record) => acc + embedding_record.record.id.len(),
        // BUG FIX: was `0`, which discarded the accumulator whenever a
        // non-embedding value appeared. Propagate `acc` instead.
        _ => acc,
    })
}

/// Counts the logical number of value elements across the pending data.
/// The meaning of "count" is per value type: array length for Int32 arrays,
/// byte length for strings, serialized size for roaring bitmaps, and one
/// per scalar uint or embedding record.
fn get_value_count(&self) -> usize {
    self.new_data
        .values()
        .map(|value| match value {
            Value::Int32ArrayValue(arr) => arr.len(),
            Value::StringValue(s) => s.len(),
            Value::RoaringBitmapValue(bitmap) => bitmap.serialized_size(),
            // An embedding record spans multiple fields but counts as a
            // single record, same as a scalar uint.
            Value::UintValue(_) | Value::EmbeddingRecordValue(_) => 1,
            _ => unimplemented!("Value type not implemented"),
        })
        .sum()
}
Expand Down Expand Up @@ -243,6 +309,13 @@ impl BlockDeltaInner {
bit_util::round_upto_multiple_of_64((item_count + 1) * 4)
}
ValueType::Uint => 0,
ValueType::EmbeddingRecord => {
// RESUME POINT
let user_id_offset = bit_util::round_upto_multiple_of_64((item_count + 1) * 4);
let string_offset = bit_util::round_upto_multiple_of_64((item_count + 1) * 4);
let metadata_offset = bit_util::round_upto_multiple_of_64((item_count + 1) * 4);
user_id_offset + string_offset + metadata_offset
}
_ => unimplemented!("Value type not implemented"),
}
}
Expand Down Expand Up @@ -317,16 +390,32 @@ impl TryFrom<&BlockDelta> for BlockData {
type Error = super::BlockDataBuildError;

fn try_from(delta: &BlockDelta) -> Result<BlockData, Self::Error> {
let mut builder = BlockDataBuilder::new(
delta.source_block.get_key_type(),
delta.source_block.get_value_type(),
Some(BlockBuilderOptions::new(
let value_options = match delta.source_block.get_value_type() {
ValueType::Int32Array
| ValueType::String
| ValueType::RoaringBitmap
| ValueType::Uint => BlockBuilderOptions::new_flat_value(
delta.len(),
delta.get_prefix_size(),
delta.get_key_size(),
delta.get_value_count(),
delta.get_value_size(),
)),
),
ValueType::EmbeddingRecord => BlockBuilderOptions::new_embedding_value(
delta.len(),
delta.get_prefix_size(),
delta.get_key_size(),
delta.get_user_id_size(),
delta.get_document_size(),
delta.get_metadata_size(),
),
_ => unimplemented!("Value type not implemented"),
};

let mut builder = BlockDataBuilder::new(
delta.source_block.get_key_type(),
delta.source_block.get_value_type(),
value_options,
);
for (key, value) in delta.inner.read().new_data.iter() {
builder.add(key.clone(), value.clone());
Expand Down Expand Up @@ -356,8 +445,13 @@ impl From<Arc<Block>> for BlockDelta {
mod test {
use super::*;
use crate::blockstore::types::{Key, KeyType, ValueType};
use crate::types::{
LogRecord, Operation, OperationRecord, ScalarEncoding, UpdateMetadataValue,
};
use arrow::array::Int32Array;
use rand::{random, Rng};
use std::collections::HashMap;
use std::str::FromStr;

#[test]
fn test_sizing_int_arr_val() {
Expand Down Expand Up @@ -453,4 +547,40 @@ mod test {
let block_data = BlockData::try_from(&delta).unwrap();
assert_eq!(size, block_data.get_size());
}

#[test]
fn test_embedding_record_val() {
    // The delta's self-reported size must match the size of the BlockData
    // actually materialized from it, for embedding-record values.
    let block_provider = ArrowBlockProvider::new();
    let block = block_provider.create_block(KeyType::Uint, ValueType::EmbeddingRecord);
    let delta = BlockDelta::from(block.clone());

    let num_records = 2000;
    for idx in 0..num_records {
        let key = BlockfileKey::new("prefix".to_string(), Key::Uint(idx as u32));

        let mut metadata = HashMap::new();
        metadata.insert(
            "chroma:document".to_string(),
            UpdateMetadataValue::Str("test".to_string()),
        );
        metadata.insert(
            "random_float".to_string(),
            UpdateMetadataValue::Float(random::<f64>()),
        );

        let record = OperationRecord {
            id: "test".to_string(),
            embedding: Some(vec![1.0, 2.0, 3.0]),
            encoding: Some(ScalarEncoding::FLOAT32),
            metadata: Some(metadata),
            operation: Operation::Add,
        };
        delta.add(
            key,
            Value::EmbeddingRecordValue(LogRecord {
                log_offset: 0,
                record,
            }),
        );
    }

    let predicted_size = delta.get_size();
    let built = BlockData::try_from(&delta).unwrap();
    assert_eq!(predicted_size, built.get_size());
}
}
Loading
Loading