diff --git a/src/agent/rag/mod.rs b/src/agent/rag/mod.rs index 34e50a2..3553790 100644 --- a/src/agent/rag/mod.rs +++ b/src/agent/rag/mod.rs @@ -1,5 +1,6 @@ use anyhow::Result; use async_trait::async_trait; +use naive::NaiveVectorStore; use serde::{Deserialize, Serialize}; use super::generator::Client; @@ -23,10 +24,21 @@ pub struct Document { #[async_trait] pub trait VectorStore: Send { #[allow(clippy::borrowed_box)] - fn new_with_generator(generator: Box) -> Result + async fn new(embedder: Box, config: Configuration) -> Result where Self: Sized; async fn add(&mut self, document: Document) -> Result<()>; async fn retrieve(&self, query: &str, top_k: usize) -> Result>; } + +pub async fn factory( + flavor: &str, + embedder: Box, + config: Configuration, +) -> Result> { + match flavor { + "naive" => Ok(Box::new(NaiveVectorStore::new(embedder, config).await?)), + _ => Err(anyhow!("rag flavor '{flavor}' not supported yet")), + } +} diff --git a/src/agent/rag/naive.rs b/src/agent/rag/naive.rs index f7a3878..5774367 100644 --- a/src/agent/rag/naive.rs +++ b/src/agent/rag/naive.rs @@ -8,23 +8,39 @@ use async_trait::async_trait; use colored::Colorize; use glob::glob; -use super::{Document, Embeddings, VectorStore}; +use super::{Configuration, Document, Embeddings, VectorStore}; use crate::agent::{generator::Client, rag::metrics}; // TODO: integrate other more efficient vector databases. 
pub struct NaiveVectorStore { + config: Configuration, embedder: Box, documents: HashMap, embeddings: HashMap, } -impl NaiveVectorStore { - // TODO: add persistency - pub async fn from_indexed_path(generator: Box, path: &str) -> Result { - let path = std::fs::canonicalize(path)?.display().to_string(); +#[async_trait] +impl VectorStore for NaiveVectorStore { + #[allow(clippy::borrowed_box)] + async fn new(embedder: Box, config: Configuration) -> Result + where + Self: Sized, + { + // TODO: add persistency + let documents = HashMap::new(); + let embeddings = HashMap::new(); + let mut store = Self { + config, + documents, + embeddings, + embedder, + }; + + let path = std::fs::canonicalize(&store.config.path)? + .display() + .to_string(); let expr = format!("{}/**/*.txt", path); - let mut store = NaiveVectorStore::new_with_generator(generator)?; for path in (glob(&expr)?).flatten() { let doc_name = path.display(); @@ -39,24 +55,6 @@ impl NaiveVectorStore { Ok(store) } -} - -#[async_trait] -impl VectorStore for NaiveVectorStore { - #[allow(clippy::borrowed_box)] - fn new_with_generator(embedder: Box) -> Result - where - Self: Sized, - { - let documents = HashMap::new(); - let embeddings = HashMap::new(); - - Ok(Self { - documents, - embeddings, - embedder, - }) - } async fn add(&mut self, document: Document) -> Result<()> { if self.documents.contains_key(&document.name) { diff --git a/src/agent/state/mod.rs b/src/agent/state/mod.rs index 058684e..c19aa80 100644 --- a/src/agent/state/mod.rs +++ b/src/agent/state/mod.rs @@ -7,7 +7,7 @@ use metrics::Metrics; use super::{ generator::{Client, Message}, namespaces::{self, Namespace}, - rag::{naive::NaiveVectorStore, Document, VectorStore}, + rag::{Document, VectorStore}, task::Task, Invocation, }; @@ -84,12 +84,11 @@ impl State { // add RAG namespace let rag: Option> = if let Some(config) = task.get_rag_config() { - let v_store: NaiveVectorStore = - NaiveVectorStore::from_indexed_path(embedder, &config.path).await?; + let 
v_store = super::rag::factory("naive", embedder, config).await?; namespaces.push(namespaces::NAMESPACES.get("rag").unwrap()()); - Some(Box::new(v_store)) + Some(v_store) } else { None };