Skip to content

Commit

Permalink
refactoring to simplify task_generation modules by composition of vector store
Browse files: browse the repository at this point in the history
  • Loading branch information
erensahin committed May 20, 2024
1 parent dbb0bcc commit afa8558
Show file tree
Hide file tree
Showing 6 changed files with 47 additions and 91 deletions.
10 changes: 8 additions & 2 deletions task_whisperer/src/embedding/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,16 @@
class BaseEmbeddings(ABC):
"""BaseEmbeddings"""

@abstractmethod
def __init__(
self,
api_key: str,
faiss_index_root_path: str,
embedding_model: str,
) -> None:
pass
self.embedder = None
self.api_key = api_key
self.embedding_model = embedding_model
self.faiss_index_root_path = faiss_index_root_path

@abstractmethod
def load_documents(
Expand Down Expand Up @@ -51,3 +53,7 @@ def generate_embeddings(
description_col_name: str = "description_cleaned",
) -> Tuple[str, Any]:
pass

@abstractmethod
def embed_query(self, query: str) -> List[float]:
pass
22 changes: 13 additions & 9 deletions task_whisperer/src/embedding/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,15 @@
from langchain_text_splitters import CharacterTextSplitter
from langchain_openai import OpenAIEmbeddings
from langchain_community.vectorstores.faiss import FAISS
from langchain_community.callbacks import get_openai_callback


EMBEDDING_MODEL = "text-embedding-ada-002"


def get_embedder(api_key: str, embedding_model: str = EMBEDDING_MODEL):
    """Build an OpenAI embeddings client.

    Args:
        api_key: Forwarded to ``OpenAIEmbeddings`` as ``api_key``.
        embedding_model: Forwarded to ``OpenAIEmbeddings`` as ``model``;
            defaults to the module-level ``EMBEDDING_MODEL``.

    Returns:
        The newly constructed ``OpenAIEmbeddings`` instance.
    """
    embedder = OpenAIEmbeddings(model=embedding_model, api_key=api_key)
    return embedder


class OpenAIEmbeddingGenerator:
"""OpenAIEmbeddingGenerator"""

Expand All @@ -26,7 +29,7 @@ def __init__(
assert api_key, "api_key is required"
assert faiss_index_root_path, "faiss_index_root_path is required"
assert embedding_model, "embedding_model is required"
self.api_key = api_key
self.embedder = get_embedder(api_key, embedding_model)
self.embedding_model = embedding_model
self.faiss_index_root_path = faiss_index_root_path

Expand Down Expand Up @@ -63,17 +66,15 @@ def split_documents(
def embed_documents(
self, project: str, documents: List[Document]
) -> Tuple[str, Any]:
embedder = OpenAIEmbeddings(api_key=self.api_key, model=self.embedding_model)
embedding_path = os.path.join(
self.faiss_index_root_path,
self.kind,
f"faiss_index_{project}_{self.embedding_model}",
)
with get_openai_callback() as cb:
faiss_db = FAISS.from_documents(documents, embedder)
faiss_db.save_local(embedding_path)
faiss_db = FAISS.from_documents(documents, self.embedder)
faiss_db.save_local(embedding_path)

return embedding_path, cb
return embedding_path

def generate_embeddings(
self,
Expand All @@ -87,5 +88,8 @@ def generate_embeddings(
project, issues_df, summary_col_name, description_col_name
)
splitted_docs = self.split_documents(docs, chunk_size)
embedding_path, cb = self.embed_documents(project, splitted_docs)
return embedding_path, cb
embedding_path = self.embed_documents(project, splitted_docs)
return embedding_path

def embed_query(self, query: str) -> List[float]:
    """Embed a single query string.

    Delegates to ``self.embedder.embed_query`` and returns its result
    unchanged.

    Args:
        query: Text to embed.

    Returns:
        Whatever the underlying embedder returns for ``query``.
    """
    embedder = self.embedder
    return embedder.embed_query(query)
6 changes: 1 addition & 5 deletions task_whisperer/src/page_helpers/generate_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,9 @@ def create_embeddings(self, project: str) -> str:
embedding_model=self.llm_config["embedding_model"],
)

embedding_path, embedding_cb = embedding_client.generate_embeddings(
embedding_path = embedding_client.generate_embeddings(
project, processed_issues_df
)

# TODO: use a logger here
print(embedding_cb)

return embedding_path

def load_metadata(self) -> Optional[List[Dict]]:
Expand Down
16 changes: 13 additions & 3 deletions task_whisperer/src/page_helpers/generate_task_description.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
from typing import Any, Dict

from task_whisperer import PROJECT_ROOT, CONFIG
from task_whisperer.src.task_generation.factory import task_generator_factory
from task_whisperer.src.task_generation import task_generator_factory
from task_whisperer.src.embedding import embedding_factory
from task_whisperer.src.vector_store import vector_store_factory

EMBEDDINGS_ROOT_PATH = os.path.join(
PROJECT_ROOT, CONFIG["datastore_path"], "embeddings"
Expand All @@ -16,12 +18,20 @@ def create_task_description(
task_summary: str,
project: str,
):
task_generator_client = task_generator_factory.get(llm_kind)(
embedding_client = embedding_factory.get(llm_kind)(
api_key=llm_config["api_key"],
faiss_index_root_path=FAISS_ROOT_PATH,
model=llm_config["llm_model"],
embedding_model=llm_config["embedding_model"],
)
vector_store_client = vector_store_factory.get("faiss")(
faiss_index_root_path=FAISS_ROOT_PATH,
embedding_generator=embedding_client,
)
task_generator_client = task_generator_factory.get(llm_kind)(
api_key=llm_config["api_key"],
vector_store=vector_store_client,
model=llm_config["llm_model"],
)
response = task_generator_client.create_task_description(
project,
task_summary,
Expand Down
18 changes: 5 additions & 13 deletions task_whisperer/src/task_generation/base.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,17 @@
from abc import ABC, abstractmethod
import os
from typing import Any, List, Tuple

import pandas as pd
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_community.vectorstores import FAISS
from langchain_community.callbacks import get_openai_callback
import tiktoken


GPT_MODEL = "gpt-3.5-turbo"
EMBEDDING_MODEL = "text-embedding-ada-002"
from task_whisperer.src.vector_store.base import BaseVectorStore


class BaseTaskGenerator(ABC):
"""OpenAITaskGenerator"""

@abstractmethod
def __init__(
self, api_key: str, faiss_index_root_path: str, model: str, embedding_model: str
self,
api_key: str,
vector_store: BaseVectorStore,
model: str,
) -> None:
pass

Expand Down
66 changes: 7 additions & 59 deletions task_whisperer/src/task_generation/openai.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
import os
from typing import List, Tuple
from typing import List

from jinja2 import Environment, FileSystemLoader

from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_community.vectorstores import FAISS
from langchain_community.callbacks import get_openai_callback
import tiktoken

from task_whisperer.src.task_generation.base import BaseTaskGenerator
from task_whisperer.src.vector_store.base import BaseVectorStore


GPT_MODEL = "gpt-3.5-turbo"
Expand All @@ -27,63 +26,14 @@ class OpenAITaskGenerator(BaseTaskGenerator):
def __init__(
self,
api_key: str,
faiss_index_root_path: str,
vector_store: BaseVectorStore,
model: str = GPT_MODEL,
embedding_model: str = EMBEDDING_MODEL,
) -> None:
assert api_key, "api_key is required"
assert faiss_index_root_path, "faiss_index_root_path is required"
assert model, "model is required"
assert embedding_model, "embedding_model is required"
self.api_key = api_key
self.model = model
self.embedding_model = embedding_model
self.faiss_index_root_path = faiss_index_root_path

def get_n_tokens(self, query: str) -> int:
    """Count the tokens in ``query``.

    The tiktoken encoding is looked up from ``self.model``. If tiktoken
    raises ``KeyError`` for that model name, the "cl100k_base" encoding is
    used instead.

    Args:
        query: Text to tokenize.

    Returns:
        Number of tokens produced by encoding ``query``.
    """
    try:
        tokenizer = tiktoken.encoding_for_model(self.model)
    except KeyError:
        tokenizer = tiktoken.get_encoding("cl100k_base")
    tokens = tokenizer.encode(query)
    return len(tokens)

def read_embeddings(self, project: str):
    """Load the saved FAISS index for ``project``.

    The index location is
    ``<faiss_index_root_path>/<kind>/faiss_index_<project>_<embedding_model>``.

    Args:
        project: Project identifier used in the index directory name.

    Returns:
        A ``(embedder, faiss_db)`` pair: the ``OpenAIEmbeddings`` client
        built from ``self.api_key`` / ``self.embedding_model``, and the
        object returned by ``FAISS.load_local`` for that path and embedder.
    """
    path_parts = (
        self.faiss_index_root_path,
        self.kind,
        f"faiss_index_{project}_{self.embedding_model}",
    )
    index_path = os.path.join(*path_parts)

    openai_embedder = OpenAIEmbeddings(
        api_key=self.api_key, model=self.embedding_model
    )
    loaded_index = FAISS.load_local(index_path, openai_embedder)
    return openai_embedder, loaded_index

def get_task_embedding(
    self, embedder: OpenAIEmbeddings, task_summary: str, task_desc: str = ""
) -> Tuple[List[float], int]:
    """Embed a task and report how many tokens its text uses.

    The summary and description are combined into one
    "Summary: ...\\nDescription: ..." string. That string is token-counted
    via ``self.get_n_tokens`` and then embedded via ``embedder.embed_query``.

    Args:
        embedder: Object exposing ``embed_query(text)``.
        task_summary: Short task title.
        task_desc: Optional longer description; defaults to an empty string.

    Returns:
        ``(embedding, n_tokens)`` for the combined task text.
    """
    task_text = f"Summary: {task_summary}\nDescription: {task_desc}"
    # Token count is taken before the embedding call, as in the original.
    token_count = self.get_n_tokens(task_text)
    vector = embedder.embed_query(task_text)
    return vector, token_count

def get_similar_queries(
    self,
    faiss_db,
    embedder: OpenAIEmbeddings,
    task_summary: str,
    task_desc: str = "",
    n_similar: int = 5,
):
    """Find stored entries most similar to the given task.

    The task is embedded via ``self.get_task_embedding``, then
    ``faiss_db.similarity_search_by_vector`` is queried with ``k=n_similar``.

    Args:
        faiss_db: Object exposing ``similarity_search_by_vector(vector, k=...)``.
        embedder: Passed through to ``self.get_task_embedding``.
        task_summary: Short task title.
        task_desc: Optional longer description; defaults to an empty string.
        n_similar: Number of neighbours to request.

    Returns:
        ``(contents, n_tokens)``: the ``page_content`` of each hit, in the
        order returned by the search, and the token count reported by
        ``get_task_embedding``.
    """
    query_vector, token_count = self.get_task_embedding(
        embedder, task_summary, task_desc
    )
    matches = faiss_db.similarity_search_by_vector(query_vector, k=n_similar)

    contents = []
    for match in matches:
        contents.append(match.page_content)
    return contents, token_count
self.vector_store = vector_store

def get_system_prompt(self):
with open(os.path.join(TEMPLATES_PATH, "system.txt"), "r") as f:
Expand Down Expand Up @@ -117,11 +67,9 @@ def create_task_description(
n_similar_tasks: int = 5,
temperature: float = 0,
):
embedder, faiss_db = self.read_embeddings(project)

if n_similar_tasks > 0:
similar_tasks, n_tokens = self.get_similar_queries(
faiss_db, embedder, task_summary, task_desc, n_similar_tasks
similar_tasks, n_tokens = self.vector_store.similarity_search(
project, task_summary, task_desc, n_similar=n_similar_tasks
)
else:
similar_tasks = []
Expand Down

0 comments on commit afa8558

Please sign in to comment.