From dede426ad7edcd9ba5d439f85ba54961bc92e2dc Mon Sep 17 00:00:00 2001 From: Xin Liu Date: Mon, 16 Dec 2024 17:03:48 +0800 Subject: [PATCH] refactor!: improve APIs Signed-off-by: Xin Liu --- examples/src/main.rs | 22 +++++++++++++++++++--- src/lib.rs | 33 ++++++++++++++++++++++++++------- 2 files changed, 45 insertions(+), 10 deletions(-) diff --git a/examples/src/main.rs b/examples/src/main.rs index 4ceb2aa..62b1846 100644 --- a/examples/src/main.rs +++ b/examples/src/main.rs @@ -78,17 +78,33 @@ async fn main() -> Result<(), Box> { client.collection_info("my_test").await ); - let p = client.get_point("my_test", 2).await; + let p = client.get_point("my_test", &PointId::from(2)).await; println!("The second point is {:?}", p); - let ps = client.get_points("my_test", vec![1, 2, 3, 4, 5, 6]).await; + let ps = client + .get_points( + "my_test", + &vec![1, 2, 3, 4, 5, 6] + .into_iter() + .map(|id| PointId::from(id)) + .collect::>(), + ) + .await; println!("The 1-6 points are {:?}", ps); let q = vec![0.2, 0.1, 0.9, 0.7]; let r = client.search_points("my_test", q, 2, None).await; println!("Search result points are {:?}", r); - let r = client.delete_points("my_test", vec![1, 4]).await; + let r = client + .delete_points( + "my_test", + &vec![1, 4] + .into_iter() + .map(|id| PointId::from(id)) + .collect::>(), + ) + .await; println!("Delete points result is {:?}", r); println!( diff --git a/src/lib.rs b/src/lib.rs index b20542c..380802d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -6,13 +6,32 @@ use anyhow::{anyhow, bail, Error}; use serde::{Deserialize, Serialize}; use serde_json::json; use serde_json::{Map, Value}; +use std::fmt::Display; -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize)] #[serde(untagged)] pub enum PointId { Uuid(String), Num(u64), } +impl From for PointId { + fn from(num: u64) -> Self { + PointId::Num(num) + } +} +impl From for PointId { + fn from(uuid: String) -> Self { + PointId::Uuid(uuid) + } +} +impl Display for PointId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + PointId::Uuid(uuid) => write!(f, "{}", uuid), + PointId::Num(num) => write!(f, "{}", num), + } + } +} /// The point struct. /// A point is a record consisting of a vector and an optional payload. @@ -179,7 +198,7 @@ impl Qdrant { pub async fn search_points( &self, collection_name: &str, - point: Vec, + vector: Vec, limit: u64, score_threshold: Option, ) -> Result, Error> { @@ -192,7 +211,7 @@ impl Qdrant { }; let params = json!({ - "vector": point, + "vector": vector, "limit": limit, "with_payload": true, "with_vector": true, @@ -226,7 +245,7 @@ impl Qdrant { } } - pub async fn get_points(&self, collection_name: &str, ids: Vec) -> Vec { + pub async fn get_points(&self, collection_name: &str, ids: &[PointId]) -> Vec { #[cfg(feature = "logging")] info!(target: "stdout", "get points from collection '{}'", collection_name); @@ -246,7 +265,7 @@ impl Qdrant { ps } - pub async fn get_point(&self, collection_name: &str, id: u64) -> Point { + pub async fn get_point(&self, collection_name: &str, id: &PointId) -> Point { #[cfg(feature = "logging")] info!(target: "stdout", "get point from collection '{}' with id {}", collection_name, id); @@ -255,7 +274,7 @@ impl Qdrant { serde_json::from_value(r.clone()).unwrap() } - pub async fn delete_points(&self, collection_name: &str, ids: Vec) -> Result<(), Error> { + pub async fn delete_points(&self, collection_name: &str, ids: &[PointId]) -> Result<(), Error> { #[cfg(feature = "logging")] info!(target: "stdout", "delete points from collection '{}'", collection_name); @@ -651,7 +670,7 @@ impl Qdrant { Ok(json) } - pub async fn get_point_api(&self, collection_name: &str, id: u64) -> Result { + pub async fn get_point_api(&self, collection_name: &str, id: &PointId) -> Result { let url = format!( "{}/collections/{}/points/{}", self.url_base, collection_name, id,