Skip to content

Commit

Permalink
refactor: do not leak precomputed partition map into IVF (#1819)
Browse files Browse the repository at this point in the history
Separate shuffling logic from IVF
  • Loading branch information
eddyxu authored Jan 12, 2024
1 parent 788999b commit 9e5d1d6
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 54 deletions.
4 changes: 2 additions & 2 deletions rust/lance-index/benches/find_partitions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,15 @@ fn bench_partitions(c: &mut Criterion) {
let matrix = MatrixView::<Float32Type>::new(centroids.clone(), DIMENSION);

for k in &[1, 10, 50] {
let ivf = IvfImpl::new(matrix.clone(), MetricType::L2, vec![], None, None);
let ivf = IvfImpl::new(matrix.clone(), MetricType::L2, vec![], None);

c.bench_function(format!("IVF{},k={},L2", num_centroids, k).as_str(), |b| {
b.iter(|| {
let _ = ivf.find_partitions(&query, *k);
})
});

let ivf = IvfImpl::new(matrix.clone(), MetricType::Cosine, vec![], None, None);
let ivf = IvfImpl::new(matrix.clone(), MetricType::Cosine, vec![], None);
c.bench_function(
format!("IVF{},k={},Cosine", num_centroids, k).as_str(),
|b| {
Expand Down
72 changes: 26 additions & 46 deletions rust/lance-index/src/vector/ivf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,10 @@

//! IVF - Inverted File Index

use std::collections::HashMap;
use std::ops::Range;
use std::sync::Arc;

use arrow_array::builder::UInt32Builder;
use arrow_array::types::{Float16Type, Float32Type, Float64Type, UInt64Type};
use arrow_array::types::{Float16Type, Float32Type, Float64Type};
use arrow_array::{
cast::AsArray, types::UInt32Type, Array, FixedSizeListArray, RecordBatch, UInt32Array,
};
Expand All @@ -28,12 +26,12 @@ use arrow_select::take::take;
use async_trait::async_trait;
use futures::{stream, StreamExt};
use lance_arrow::*;
use lance_core::{Error, Result, ROW_ID};
use lance_core::{Error, Result};
use lance_linalg::{
distance::{Cosine, Dot, MetricType, L2},
MatrixView,
};
use log::{debug, info};
use log::info;
use snafu::{location, Location};
use tracing::{instrument, Instrument};

Expand All @@ -55,16 +53,9 @@ fn new_ivf_impl<T: ArrowFloatType + Dot + Cosine + L2 + 'static>(
metric_type: MetricType,
transforms: Vec<Arc<dyn Transformer>>,
range: Option<Range<u32>>,
precomputed_partitions: Option<HashMap<u64, u32>>,
) -> Arc<dyn Ivf> {
let mat = MatrixView::<T>::new(Arc::new(centroids.clone()), dimension);
Arc::new(IvfImpl::<T>::new(
mat,
metric_type,
transforms,
range,
precomputed_partitions,
))
Arc::new(IvfImpl::<T>::new(mat, metric_type, transforms, range))
}

/// Create an IVF from the flatten centroids.
Expand All @@ -82,7 +73,6 @@ pub fn new_ivf(
metric_type: MetricType,
transforms: Vec<Arc<dyn Transformer>>,
range: Option<Range<u32>>,
precomputed_partitions: Option<HashMap<u64, u32>>,
) -> Result<Arc<dyn Ivf>> {
match centroids.data_type() {
DataType::Float16 => Ok(new_ivf_impl::<Float16Type>(
Expand All @@ -91,23 +81,20 @@ pub fn new_ivf(
metric_type,
transforms,
range,
precomputed_partitions,
)),
DataType::Float32 => Ok(new_ivf_impl::<Float32Type>(
centroids.as_primitive(),
dimension,
metric_type,
transforms,
range,
precomputed_partitions,
)),
DataType::Float64 => Ok(new_ivf_impl::<Float64Type>(
centroids.as_primitive(),
dimension,
metric_type,
transforms,
range,
precomputed_partitions,
)),
_ => Err(Error::Index {
message: format!(
Expand All @@ -126,7 +113,6 @@ fn new_ivf_with_pq_impl<T: ArrowFloatType + Dot + Cosine + L2 + 'static>(
vector_column: &str,
pq: Arc<dyn ProductQuantizer>,
range: Option<Range<u32>>,
precomputed_partitions: Option<HashMap<u64, u32>>,
) -> Arc<dyn Ivf> {
let mat = MatrixView::<T>::new(Arc::new(centroids.clone()), dimension);
Arc::new(IvfImpl::<T>::new_with_pq(
Expand All @@ -135,7 +121,6 @@ fn new_ivf_with_pq_impl<T: ArrowFloatType + Dot + Cosine + L2 + 'static>(
vector_column,
pq,
range,
precomputed_partitions,
))
}

Expand All @@ -146,7 +131,6 @@ pub fn new_ivf_with_pq(
vector_column: &str,
pq: Arc<dyn ProductQuantizer>,
range: Option<Range<u32>>,
precomputed_partitions: Option<HashMap<u64, u32>>,
) -> Result<Arc<dyn Ivf>> {
match centroids.data_type() {
DataType::Float16 => Ok(new_ivf_with_pq_impl::<Float16Type>(
Expand All @@ -156,7 +140,6 @@ pub fn new_ivf_with_pq(
vector_column,
pq,
range,
precomputed_partitions,
)),
DataType::Float32 => Ok(new_ivf_with_pq_impl::<Float32Type>(
centroids.as_primitive(),
Expand All @@ -165,7 +148,6 @@ pub fn new_ivf_with_pq(
vector_column,
pq,
range,
precomputed_partitions,
)),
DataType::Float64 => Ok(new_ivf_with_pq_impl::<Float64Type>(
centroids.as_primitive(),
Expand All @@ -174,7 +156,6 @@ pub fn new_ivf_with_pq(
vector_column,
pq,
range,
precomputed_partitions,
)),
_ => Err(Error::Index {
message: format!(
Expand Down Expand Up @@ -227,10 +208,20 @@ pub trait Ivf: Send + Sync + std::fmt::Debug {
/// It transform a [RecordBatch] that contains one vector column into a record batch with
/// schema `(PART_ID_COLUMN, ...)`, where [PART_ID_COLUMN] has the partition id for each vector.
///
/// Parameters
/// ----------
/// - *batch*: input [RecordBatch]
/// - *column: the name of the vector column to be partitioned and transformed.
/// - *partion_ids*: optional precomputed partition IDs for each vector.
/// Note that the vector column might be transformed by the `transforms` in the IVF.
///
/// **Warning**: unstable API.
async fn partition_transform(&self, batch: &RecordBatch, column: &str) -> Result<RecordBatch>;
async fn partition_transform(
&self,
batch: &RecordBatch,
column: &str,
partition_ids: Option<UInt32Array>,
) -> Result<RecordBatch>;
}

/// IVF - IVF file partition
Expand All @@ -251,8 +242,6 @@ pub struct IvfImpl<T: ArrowFloatType + Dot + L2 + Cosine> {

/// Only covers a range of partitions.
partition_range: Option<Range<u32>>,

precomputed_partitions: Option<HashMap<u64, u32>>,
}

impl<T: ArrowFloatType + Dot + L2 + Cosine + 'static> IvfImpl<T> {
Expand All @@ -261,14 +250,12 @@ impl<T: ArrowFloatType + Dot + L2 + Cosine + 'static> IvfImpl<T> {
metric_type: MetricType,
transforms: Vec<Arc<dyn Transformer>>,
range: Option<Range<u32>>,
precomputed_partitions: Option<HashMap<u64, u32>>,
) -> Self {
Self {
centroids,
metric_type,
transforms,
partition_range: range,
precomputed_partitions,
}
}

Expand All @@ -278,7 +265,6 @@ impl<T: ArrowFloatType + Dot + L2 + Cosine + 'static> IvfImpl<T> {
vector_column: &str,
pq: Arc<dyn ProductQuantizer>,
range: Option<Range<u32>>,
precomputed_partitions: Option<HashMap<u64, u32>>,
) -> Self {
let transforms: Vec<Arc<dyn Transformer>> = if pq.use_residual() {
vec![
Expand All @@ -305,7 +291,6 @@ impl<T: ArrowFloatType + Dot + L2 + Cosine + 'static> IvfImpl<T> {
metric_type,
transforms,
partition_range: range,
precomputed_partitions,
}
}

Expand Down Expand Up @@ -435,7 +420,7 @@ impl<T: ArrowFloatType + Dot + L2 + Cosine + 'static> Ivf for IvfImpl<T> {
),
location: Default::default(),
})?;
// TODO: hold kmeans in this struct.
// todo: hold kmeans in this struct.
let kmeans = KMeans::<T>::with_centroids(
self.centroids.data().clone(),
self.dimension(),
Expand All @@ -444,7 +429,12 @@ impl<T: ArrowFloatType + Dot + L2 + Cosine + 'static> Ivf for IvfImpl<T> {
Ok(kmeans.find_partitions(query.as_slice(), nprobes)?)
}

async fn partition_transform(&self, batch: &RecordBatch, column: &str) -> Result<RecordBatch> {
async fn partition_transform(
&self,
batch: &RecordBatch,
column: &str,
partition_ids: Option<UInt32Array>,
) -> Result<RecordBatch> {
let vector_arr = batch.column_by_name(column).ok_or(Error::Index {
message: format!("Column {} does not exist.", column),
location: location!(),
Expand All @@ -458,20 +448,10 @@ impl<T: ArrowFloatType + Dot + L2 + Cosine + 'static> Ivf for IvfImpl<T> {
location: location!(),
})?;

let part_ids = match (&self.precomputed_partitions, batch.column_by_name(ROW_ID)) {
(Some(partitions), Some(row_ids)) => {
debug!("Using precomputed partitions for partitions");
let mut builder = UInt32Builder::new();
for row in row_ids.as_primitive::<UInt64Type>().values().iter() {
if let Some(part_id) = partitions.get(row) {
builder.append_value(*part_id);
} else {
builder.append_null();
}
}
builder.finish()
}
_ => self.compute_partitions(data).await?,
let part_ids = if let Some(part_ids) = partition_ids {
part_ids
} else {
self.compute_partitions(data).await?
};

let (part_ids, batch) = if let Some(part_range) = self.partition_range.as_ref() {
Expand Down
33 changes: 31 additions & 2 deletions rust/lance-index/src/vector/ivf/shuffler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,10 @@
//! 1. while groupby column will stay the same, we may want to include extra data columns in the future
//! 2. shuffling into memory is fast but we should add disk buffer to support bigger datasets

use std::collections::HashMap;
use std::sync::Arc;

use arrow_array::types::UInt64Type;
use arrow_array::{
cast::AsArray, FixedSizeListArray, RecordBatch, UInt32Array, UInt64Array, UInt8Array,
};
Expand Down Expand Up @@ -70,25 +72,52 @@ fn get_temp_dir() -> Result<Path> {
/// of shuffled partitioned data. Each stream corresponds to a partition and
/// is sorted within the stream. Consumer of these streams is expected to merge
/// the streams into a single stream by k-list mergo algo.
///
#[allow(clippy::too_many_arguments)]
pub async fn shuffle_dataset(
data: impl RecordBatchStream + Unpin + 'static,
column: &str,
ivf: Arc<dyn crate::vector::ivf::Ivf>,
precomputed_partitions: Option<HashMap<u64, u32>>,
num_partitions: u32,
num_sub_vectors: usize,
shuffle_partition_batches: usize,
shuffle_partition_concurrency: usize,
) -> Result<Vec<impl Stream<Item = Result<RecordBatch>>>> {
let column: Arc<str> = column.into();
let precomputed_partitions = precomputed_partitions.map(Arc::new);
let stream = data
.zip(repeat_with(move || ivf.clone()))
.map(move |(b, ivf)| {
let col_ref = column.clone();

// If precomputed_partitions map is provided, use it
// for fast partitions.
let partition_map = precomputed_partitions
.as_ref()
.cloned()
.unwrap_or(Arc::new(HashMap::new()));

tokio::task::spawn(async move {
let batch = b?;
ivf.partition_transform(&batch, col_ref.as_ref()).await

let part_ids = if !partition_map.is_empty() {
let row_ids = batch.column_by_name(ROW_ID).ok_or(Error::Index {
message: "column does not exist".to_string(),
location: location!(),
})?;
let part_ids = row_ids
.as_primitive::<UInt64Type>()
.values()
.iter()
.filter_map(|row_id| partition_map.get(row_id).copied())
.collect::<Vec<_>>();
Some(UInt32Array::from(part_ids))
} else {
None
};

ivf.partition_transform(&batch, col_ref.as_ref(), part_ids)
.await
})
})
.buffer_unordered(num_cpus::get())
Expand Down
4 changes: 1 addition & 3 deletions rust/lance/src/index/vector/ivf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -209,13 +209,13 @@ impl IVFIndex {
column,
pq_index.pq.clone(),
None,
None,
)?;

let shuffled = shuffle_dataset(
data,
column,
ivf,
None,
self.ivf.num_partitions() as u32,
pq_index.pq.num_sub_vectors(),
10000,
Expand Down Expand Up @@ -523,7 +523,6 @@ impl Ivf {
metric_type,
vec![],
None,
None,
)?;
internal.find_partitions(query, nprobes)
}
Expand Down Expand Up @@ -773,7 +772,6 @@ pub async fn build_ivf_pq_index(
metric_type,
vec![],
None,
None,
)?;

info!(
Expand Down
2 changes: 1 addition & 1 deletion rust/lance/src/index/vector/ivf/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,13 +64,13 @@ pub(super) async fn build_partitions(
column,
pq.clone(),
Some(part_range),
precomputed_partitons,
)?;

let stream = shuffle_dataset(
data,
column,
ivf_model,
precomputed_partitons,
ivf.num_partitions() as u32,
pq.num_sub_vectors(),
shuffle_partition_batches,
Expand Down

0 comments on commit 9e5d1d6

Please sign in to comment.