Skip to content

Commit

Permalink
feat(vector): remove simsimd and use nalgebra instead (#5027)
Browse files Browse the repository at this point in the history
* feat(vector): remove `simsimd` and use `nalgebra` instead

Signed-off-by: Zhenchi <[email protected]>

* keep thing simple

Signed-off-by: Zhenchi <[email protected]>

---------

Signed-off-by: Zhenchi <[email protected]>
  • Loading branch information
zhongzc authored Nov 20, 2024
1 parent 55ced9a commit db345c9
Show file tree
Hide file tree
Showing 9 changed files with 429 additions and 150 deletions.
58 changes: 45 additions & 13 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion src/common/function/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ geo-types = { version = "0.7", optional = true }
geohash = { version = "0.13", optional = true }
h3o = { version = "0.6", optional = true }
jsonb.workspace = true
nalgebra = "0.33"
num = "0.4"
num-traits = "0.2"
once_cell.workspace = true
Expand All @@ -41,7 +42,6 @@ s2 = { version = "0.0.12", optional = true }
serde.workspace = true
serde_json.workspace = true
session.workspace = true
simsimd = "4"
snafu.workspace = true
sql.workspace = true
statrs = "0.16"
Expand All @@ -50,6 +50,7 @@ table.workspace = true
wkt = { version = "0.11", optional = true }

[dev-dependencies]
approx = "0.5"
ron = "0.7"
serde = { version = "1.0", features = ["derive"] }
tokio.workspace = true
49 changes: 31 additions & 18 deletions src/common/function/src/scalars/vector/distance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.

mod cos;
mod dot;
mod l2sq;

use std::borrow::Cow;
use std::fmt::Display;
use std::sync::Arc;
Expand All @@ -21,14 +25,14 @@ use common_query::prelude::Signature;
use datatypes::prelude::ConcreteDataType;
use datatypes::scalars::ScalarVectorBuilder;
use datatypes::value::ValueRef;
use datatypes::vectors::{Float64VectorBuilder, MutableVector, Vector, VectorRef};
use datatypes::vectors::{Float32VectorBuilder, MutableVector, Vector, VectorRef};
use snafu::ensure;

use crate::function::{Function, FunctionContext};
use crate::helper;

macro_rules! define_distance_function {
($StructName:ident, $display_name:expr, $similarity_method:ident) => {
($StructName:ident, $display_name:expr, $similarity_method:path) => {

/// A function calculates the distance between two vectors.

Expand All @@ -41,7 +45,7 @@ macro_rules! define_distance_function {
}

fn return_type(&self, _input_types: &[ConcreteDataType]) -> Result<ConcreteDataType> {
Ok(ConcreteDataType::float64_datatype())
Ok(ConcreteDataType::float32_datatype())
}

fn signature(&self) -> Signature {
Expand Down Expand Up @@ -71,7 +75,7 @@ macro_rules! define_distance_function {
let arg1 = &columns[1];

let size = arg0.len();
let mut result = Float64VectorBuilder::with_capacity(size);
let mut result = Float32VectorBuilder::with_capacity(size);
if size == 0 {
return Ok(result.to_vector());
}
Expand Down Expand Up @@ -101,9 +105,8 @@ macro_rules! define_distance_function {
}
);

let f = <f32 as simsimd::SpatialSimilarity>::$similarity_method;
// Safe: checked if the length of the vectors match
let d = f(vec0.as_ref(), vec1.as_ref()).unwrap();
// Checked if the length of the vectors match
let d = $similarity_method(vec0.as_ref(), vec1.as_ref());
result.push(Some(d));
} else {
result.push_null();
Expand All @@ -122,9 +125,9 @@ macro_rules! define_distance_function {
}
}

define_distance_function!(CosDistanceFunction, "cos_distance", cos);
define_distance_function!(L2SqDistanceFunction, "l2sq_distance", l2sq);
define_distance_function!(DotProductFunction, "dot_product", dot);
define_distance_function!(CosDistanceFunction, "vec_cos_distance", cos::cos);
define_distance_function!(L2SqDistanceFunction, "vec_l2sq_distance", l2sq::l2sq);
define_distance_function!(DotProductFunction, "vec_dot_product", dot::dot);

/// Parse a vector value if the value is a constant string.
fn parse_if_constant_string(arg: &Arc<dyn Vector>) -> Result<Option<Vec<f32>>> {
Expand All @@ -148,7 +151,7 @@ fn as_vector(arg: ValueRef<'_>) -> Result<Option<Cow<'_, [f32]>>> {
ConcreteDataType::Binary(_) => arg
.as_binary()
.unwrap() // Safe: checked if it is a binary
.map(|bytes| Ok(Cow::Borrowed(binary_as_vector(bytes)?)))
.map(binary_as_vector)
.transpose(),
ConcreteDataType::String(_) => arg
.as_string()
Expand All @@ -164,18 +167,28 @@ fn as_vector(arg: ValueRef<'_>) -> Result<Option<Cow<'_, [f32]>>> {
}

/// Convert a u8 slice to a vector value.
fn binary_as_vector(bytes: &[u8]) -> Result<&[f32]> {
if bytes.len() % 4 != 0 {
fn binary_as_vector(bytes: &[u8]) -> Result<Cow<'_, [f32]>> {
if bytes.len() % std::mem::size_of::<f32>() != 0 {
return InvalidFuncArgsSnafu {
err_msg: format!("Invalid binary length of vector: {}", bytes.len()),
}
.fail();
}

unsafe {
let num_floats = bytes.len() / 4;
let floats: &[f32] = std::slice::from_raw_parts(bytes.as_ptr() as *const f32, num_floats);
Ok(floats)
if cfg!(target_endian = "little") {
Ok(unsafe {
let vec = std::slice::from_raw_parts(
bytes.as_ptr() as *const f32,
bytes.len() / std::mem::size_of::<f32>(),
);
Cow::Borrowed(vec)
})
} else {
let v = bytes
.chunks_exact(std::mem::size_of::<f32>())
.map(|chunk| f32::from_le_bytes(chunk.try_into().unwrap()))
.collect::<Vec<f32>>();
Ok(Cow::Owned(v))
}
}

Expand Down Expand Up @@ -460,7 +473,7 @@ mod tests {
fn test_binary_as_vector() {
let bytes = [0, 0, 128, 63];
let result = binary_as_vector(&bytes).unwrap();
assert_eq!(result, &[1.0]);
assert_eq!(result.as_ref(), &[1.0]);

let invalid_bytes = [0, 0, 128];
let result = binary_as_vector(&invalid_bytes);
Expand Down
87 changes: 87 additions & 0 deletions src/common/function/src/scalars/vector/distance/cos.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use nalgebra::DVectorView;

/// Calculates the cosine distance (`1 - cosine similarity`) between two vectors.
///
/// Returns `1.0` when the product of the two Euclidean norms is zero (for
/// example when either input is an all-zero vector), since the cosine
/// similarity is undefined in that case. Results whose magnitude is below
/// `f32::EPSILON` are snapped to exactly `0.0`.
///
/// **Note:** Must ensure that the length of the two vectors are the same.
pub fn cos(lhs: &[f32], rhs: &[f32]) -> f32 {
    let lhs_vec = DVectorView::from_slice(lhs, lhs.len());
    let rhs_vec = DVectorView::from_slice(rhs, rhs.len());

    // Cosine distance is scale-invariant, so the only degenerate case is a
    // zero denominator. Comparing the dot product or the norms against the
    // absolute threshold `f32::EPSILON` would misclassify small-magnitude but
    // non-zero vectors (e.g. `[1e-4, 0.0]` against itself) as "no similarity".
    let norm_product = lhs_vec.norm() * rhs_vec.norm();
    if norm_product == 0.0 {
        return 1.0;
    }

    let cos_similar = lhs_vec.dot(&rhs_vec) / norm_product;
    let res = 1.0 - cos_similar;
    // Hide rounding noise so (near-)parallel vectors report a clean zero.
    if res.abs() < f32::EPSILON {
        0.0
    } else {
        res
    }
}

#[cfg(test)]
mod tests {
    use approx::assert_relative_eq;

    use super::*;

    /// Checks `cos` against a table of `(lhs, rhs, expected distance)` cases,
    /// each compared with an absolute tolerance of `1e-2`.
    #[test]
    fn test_cos_scalar() {
        let cases: [([f32; 3], [f32; 3], f32); 9] = [
            ([1.0, 2.0, 3.0], [1.0, 2.0, 3.0], 0.0),
            ([1.0, 2.0, 3.0], [4.0, 5.0, 6.0], 0.025),
            ([1.0, 2.0, 3.0], [7.0, 8.0, 9.0], 0.04),
            // A zero vector on the left-hand side yields a distance of 1.0.
            ([0.0, 0.0, 0.0], [1.0, 2.0, 3.0], 1.0),
            ([0.0, 0.0, 0.0], [4.0, 5.0, 6.0], 1.0),
            ([0.0, 0.0, 0.0], [7.0, 8.0, 9.0], 1.0),
            ([7.0, 8.0, 9.0], [1.0, 2.0, 3.0], 0.04),
            ([7.0, 8.0, 9.0], [4.0, 5.0, 6.0], 0.0),
            ([7.0, 8.0, 9.0], [7.0, 8.0, 9.0], 0.0),
        ];

        for &(lhs, rhs, expected) in cases.iter() {
            assert_relative_eq!(cos(&lhs, &rhs), expected, epsilon = 1e-2);
        }
    }
}
Loading

0 comments on commit db345c9

Please sign in to comment.