diff --git a/task_whisperer/src/embedding/base.py b/task_whisperer/src/embedding/base.py
index a9bad32..63036e8 100644
--- a/task_whisperer/src/embedding/base.py
+++ b/task_whisperer/src/embedding/base.py
@@ -10,14 +10,16 @@
 class BaseEmbeddings(ABC):
     """BaseEmbeddings"""
 
-    @abstractmethod
     def __init__(
         self,
         api_key: str,
         faiss_index_root_path: str,
         embedding_model: str,
     ) -> None:
-        pass
+        self.embedder = None
+        self.api_key = api_key
+        self.embedding_model = embedding_model
+        self.faiss_index_root_path = faiss_index_root_path
 
     @abstractmethod
     def load_documents(
@@ -51,3 +53,7 @@ def generate_embeddings(
         description_col_name: str = "description_cleaned",
     ) -> Tuple[str, Any]:
         pass
+
+    @abstractmethod
+    def embed_query(self, query: str) -> List[float]:
+        pass
diff --git a/task_whisperer/src/embedding/openai.py b/task_whisperer/src/embedding/openai.py
index 4c6a44c..314917d 100644
--- a/task_whisperer/src/embedding/openai.py
+++ b/task_whisperer/src/embedding/openai.py
@@ -6,12 +6,15 @@
 from langchain_text_splitters import CharacterTextSplitter
 from langchain_openai import OpenAIEmbeddings
 from langchain_community.vectorstores.faiss import FAISS
-from langchain_community.callbacks import get_openai_callback
 
 
 EMBEDDING_MODEL = "text-embedding-ada-002"
 
 
+def get_embedder(api_key: str, embedding_model: str = EMBEDDING_MODEL):
+    return OpenAIEmbeddings(api_key=api_key, model=embedding_model)
+
+
 class OpenAIEmbeddingGenerator:
     """OpenAIEmbeddingGenerator"""
 
@@ -26,7 +29,7 @@ def __init__(
         assert api_key, "api_key is required"
         assert faiss_index_root_path, "faiss_index_root_path is required"
         assert embedding_model, "embedding_model is required"
-        self.api_key = api_key
+        self.embedder = get_embedder(api_key, embedding_model)
         self.embedding_model = embedding_model
         self.faiss_index_root_path = faiss_index_root_path
 
@@ -63,17 +66,15 @@ def split_documents(
     def embed_documents(
         self, project: str, documents: List[Document]
     ) -> Tuple[str, Any]:
-        embedder = OpenAIEmbeddings(api_key=self.api_key, model=self.embedding_model)
         embedding_path = os.path.join(
             self.faiss_index_root_path,
             self.kind,
             f"faiss_index_{project}_{self.embedding_model}",
         )
-        with get_openai_callback() as cb:
-            faiss_db = FAISS.from_documents(documents, embedder)
-            faiss_db.save_local(embedding_path)
+        faiss_db = FAISS.from_documents(documents, self.embedder)
+        faiss_db.save_local(embedding_path)
 
-        return embedding_path, cb
+        return embedding_path
 
     def generate_embeddings(
         self,
@@ -87,5 +88,8 @@ def generate_embeddings(
             project, issues_df, summary_col_name, description_col_name
         )
         splitted_docs = self.split_documents(docs, chunk_size)
-        embedding_path, cb = self.embed_documents(project, splitted_docs)
-        return embedding_path, cb
+        embedding_path = self.embed_documents(project, splitted_docs)
+        return embedding_path
+
+    def embed_query(self, query: str) -> List[float]:
+        return self.embedder.embed_query(query)
diff --git a/task_whisperer/src/page_helpers/generate_embeddings.py b/task_whisperer/src/page_helpers/generate_embeddings.py
index 4656bef..2cc97a4 100644
--- a/task_whisperer/src/page_helpers/generate_embeddings.py
+++ b/task_whisperer/src/page_helpers/generate_embeddings.py
@@ -42,13 +42,9 @@ def create_embeddings(self, project: str) -> str:
             embedding_model=self.llm_config["embedding_model"],
         )
 
-        embedding_path, embedding_cb = embedding_client.generate_embeddings(
+        embedding_path = embedding_client.generate_embeddings(
             project, processed_issues_df
         )
-
-        # TODO: use a logger here
-        print(embedding_cb)
-
         return embedding_path
 
     def load_metadata(self) -> Optional[List[Dict]]:
diff --git a/task_whisperer/src/page_helpers/generate_task_description.py b/task_whisperer/src/page_helpers/generate_task_description.py
index 8254d32..eee8ae8 100644
--- a/task_whisperer/src/page_helpers/generate_task_description.py
+++ b/task_whisperer/src/page_helpers/generate_task_description.py
@@ -2,7 +2,9 @@
 from typing import Any, Dict
 from task_whisperer import PROJECT_ROOT, CONFIG
-from task_whisperer.src.task_generation.factory import task_generator_factory
+from task_whisperer.src.task_generation import task_generator_factory
+from task_whisperer.src.embedding import embedding_factory
+from task_whisperer.src.vector_store import vector_store_factory
 
 EMBEDDINGS_ROOT_PATH = os.path.join(
     PROJECT_ROOT, CONFIG["datastore_path"], "embeddings"
 )
@@ -16,12 +18,20 @@ def create_task_description(
     task_summary: str,
     project: str,
 ):
-    task_generator_client = task_generator_factory.get(llm_kind)(
+    embedding_client = embedding_factory.get(llm_kind)(
         api_key=llm_config["api_key"],
         faiss_index_root_path=FAISS_ROOT_PATH,
-        model=llm_config["llm_model"],
         embedding_model=llm_config["embedding_model"],
     )
+    vector_store_client = vector_store_factory.get("faiss")(
+        faiss_index_root_path=FAISS_ROOT_PATH,
+        embedding_generator=embedding_client,
+    )
+    task_generator_client = task_generator_factory.get(llm_kind)(
+        api_key=llm_config["api_key"],
+        vector_store=vector_store_client,
+        model=llm_config["llm_model"],
+    )
     response = task_generator_client.create_task_description(
         project,
         task_summary,
diff --git a/task_whisperer/src/task_generation/base.py b/task_whisperer/src/task_generation/base.py
index 6a812c4..2b893fd 100644
--- a/task_whisperer/src/task_generation/base.py
+++ b/task_whisperer/src/task_generation/base.py
@@ -1,17 +1,6 @@
 from abc import ABC, abstractmethod
-import os
-from typing import Any, List, Tuple
-import pandas as pd
-from langchain_openai import OpenAIEmbeddings, ChatOpenAI
-from langchain_core.messages import HumanMessage, SystemMessage
-from langchain_community.vectorstores import FAISS
-from langchain_community.callbacks import get_openai_callback
-import tiktoken
-
-
-GPT_MODEL = "gpt-3.5-turbo"
-EMBEDDING_MODEL = "text-embedding-ada-002"
+from task_whisperer.src.vector_store.base import BaseVectorStore
 
 
 class BaseTaskGenerator(ABC):
     """BaseTaskGenerator"""
@@ -19,7 +8,10 @@ class BaseTaskGenerator(ABC):
 
     @abstractmethod
     def __init__(
-        self, api_key: str, faiss_index_root_path: str, model: str, embedding_model: str
+        self,
+        api_key: str,
+        vector_store: BaseVectorStore,
+        model: str,
     ) -> None:
         pass
diff --git a/task_whisperer/src/task_generation/openai.py b/task_whisperer/src/task_generation/openai.py
index e582b0f..ab32c2c 100644
--- a/task_whisperer/src/task_generation/openai.py
+++ b/task_whisperer/src/task_generation/openai.py
@@ -1,15 +1,14 @@
 import os
-from typing import List, Tuple
+from typing import List
 
 from jinja2 import Environment, FileSystemLoader
-from langchain_openai import OpenAIEmbeddings, ChatOpenAI
+from langchain_openai import ChatOpenAI
 from langchain_core.messages import HumanMessage, SystemMessage
-from langchain_community.vectorstores import FAISS
 from langchain_community.callbacks import get_openai_callback
-import tiktoken
 
 from task_whisperer.src.task_generation.base import BaseTaskGenerator
+from task_whisperer.src.vector_store.base import BaseVectorStore
 
 
 GPT_MODEL = "gpt-3.5-turbo"
 
@@ -27,63 +26,14 @@ class OpenAITaskGenerator(BaseTaskGenerator):
     def __init__(
         self,
         api_key: str,
-        faiss_index_root_path: str,
+        vector_store: BaseVectorStore,
         model: str = GPT_MODEL,
-        embedding_model: str = EMBEDDING_MODEL,
     ) -> None:
         assert api_key, "api_key is required"
-        assert faiss_index_root_path, "faiss_index_root_path is required"
         assert model, "model is required"
-        assert embedding_model, "embedding_model is required"
         self.api_key = api_key
         self.model = model
-        self.embedding_model = embedding_model
-        self.faiss_index_root_path = faiss_index_root_path
-
-    def get_n_tokens(self, query: str) -> int:
-        try:
-            encoding = tiktoken.encoding_for_model(self.model)
-        except KeyError:
-            encoding = tiktoken.get_encoding("cl100k_base")
-        return len(encoding.encode(query))
-
-    def read_embeddings(self, project: str):
-        embedding_path = os.path.join(
-            self.faiss_index_root_path,
-            self.kind,
-            f"faiss_index_{project}_{self.embedding_model}",
-        )
-
-        embedder = OpenAIEmbeddings(api_key=self.api_key, model=self.embedding_model)
-        faiss_db = FAISS.load_local(embedding_path, embedder)
-        return embedder, faiss_db
-
-    def get_task_embedding(
-        self, embedder: OpenAIEmbeddings, task_summary: str, task_desc: str = ""
-    ) -> Tuple[List[float], int]:
-        task_def = f"Summary: {task_summary}\nDescription: {task_desc}"
-        n_tokens = self.get_n_tokens(task_def)
-        embedded = embedder.embed_query(task_def)
-        return embedded, n_tokens
-
-    def get_similar_queries(
-        self,
-        faiss_db,
-        embedder: OpenAIEmbeddings,
-        task_summary: str,
-        task_desc: str = "",
-        n_similar: int = 5,
-    ):
-        task_embed, n_tokens = self.get_task_embedding(
-            embedder, task_summary, task_desc
-        )
-        similar_questions = faiss_db.similarity_search_by_vector(
-            task_embed, k=n_similar
-        )
-        similar_questions = [
-            similar_question.page_content for similar_question in similar_questions
-        ]
-        return similar_questions, n_tokens
+        self.vector_store = vector_store
 
     def get_system_prompt(self):
         with open(os.path.join(TEMPLATES_PATH, "system.txt"), "r") as f:
@@ -117,11 +67,9 @@ def create_task_description(
         n_similar_tasks: int = 5,
         temperature: float = 0,
     ):
-        embedder, faiss_db = self.read_embeddings(project)
-
         if n_similar_tasks > 0:
-            similar_tasks, n_tokens = self.get_similar_queries(
-                faiss_db, embedder, task_summary, task_desc, n_similar_tasks
+            similar_tasks, n_tokens = self.vector_store.similarity_search(
+                project, task_summary, task_desc, n_similar=n_similar_tasks
            )
         else:
             similar_tasks = []
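
Note: the new `task_whisperer/src/vector_store` package this refactor wires in (`BaseVectorStore`, `vector_store_factory`, and the FAISS-backed store) is not part of the patch above. As a rough sketch of what that store presumably looks like, reconstructed from the removed `read_embeddings`/`get_similar_queries` helpers and the call sites in this diff; the module path, class name, and the token-count detail are assumptions, not code from this commit:

```python
# Hypothetical sketch of task_whisperer/src/vector_store/faiss.py -- NOT part
# of this diff. Reconstructed from the removed OpenAITaskGenerator helpers and
# the call sites above; names, signatures, and token counting are assumptions.
import os
from typing import List, Tuple

from langchain_community.vectorstores import FAISS

from task_whisperer.src.embedding.base import BaseEmbeddings
from task_whisperer.src.vector_store.base import BaseVectorStore


class FAISSVectorStore(BaseVectorStore):
    """FAISS-backed store matching the vector_store_factory.get("faiss") call."""

    def __init__(
        self, faiss_index_root_path: str, embedding_generator: BaseEmbeddings
    ) -> None:
        self.faiss_index_root_path = faiss_index_root_path
        # A BaseEmbeddings instance: exposes .embedder, .embedding_model,
        # .kind, and the embed_query() method added in this diff.
        self.embedding_generator = embedding_generator

    def load_index(self, project: str) -> FAISS:
        # Same per-project layout the embedding side saves to:
        # <root>/<kind>/faiss_index_<project>_<embedding_model>
        embedding_path = os.path.join(
            self.faiss_index_root_path,
            self.embedding_generator.kind,
            f"faiss_index_{project}_{self.embedding_generator.embedding_model}",
        )
        return FAISS.load_local(embedding_path, self.embedding_generator.embedder)

    def similarity_search(
        self, project: str, task_summary: str, task_desc: str = "", n_similar: int = 5
    ) -> Tuple[List[str], int]:
        # Mirrors the removed get_task_embedding() + get_similar_queries():
        # embed the task definition, then search the project's index.
        task_def = f"Summary: {task_summary}\nDescription: {task_desc}"
        task_embed = self.embedding_generator.embed_query(task_def)
        docs = self.load_index(project).similarity_search_by_vector(
            task_embed, k=n_similar
        )
        # The caller unpacks (similar_tasks, n_tokens); the removed code counted
        # tokens with tiktoken, which a real implementation would likely keep.
        n_tokens = len(task_def.split())  # crude placeholder for a token count
        return [doc.page_content for doc in docs], n_tokens
```

Whatever the real implementation looks like, the shape of the refactor is the same: `OpenAITaskGenerator` no longer owns embedding or FAISS details and only calls `similarity_search`, so embedders and vector stores can be swapped behind their factories without touching task generation.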