-
Notifications
You must be signed in to change notification settings - Fork 46
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
create shared memory with vector db , (#251)
* create shared memory with vector db , change the agent to take user_type event , create a test that use qa agent to show sharedmemory with vector db fix pr review fix review issues create shared memory with vector db , change the agent to take user_type event , create a test that use qa agent to show sharedmemory with vector db fix issues raised on pr review * reformat the code * fix offline testing * fix prompt * remove redundant class * Update ChromaVectorStore class and integration tests * reformat with make * Add INDEX_NAME_FILE_STORAGE to .env-sample * create shared memory with vector db , change the agent to take user_type event , create a test that use qa agent to show sharedmemory with vector db fix pr review fix review issues create shared memory with vector db , change the agent to take user_type event , create a test that use qa agent to show sharedmemory with vector db fix issues raised on pr review * reformat the code * fix offline testing * remove redundant class * Update ChromaVectorStore class and integration tests * reformat with make * Refactor imports in user_usage_tracker.py * create shared memory with vector db , change the agent to take user_type event , create a test that use qa agent to show sharedmemory with vector db fix pr review fix review issues create shared memory with vector db , change the agent to take user_type event , create a test that use qa agent to show sharedmemory with vector db fix issues raised on pr review * Add pal dependency and remove unused test file * Fix bug in login functionality * Remove faiss-cpu dependency * Mock tests * Update poetry lock --------- Co-authored-by: Boqi Chen <[email protected]>
- Loading branch information
1 parent
38a2fc3
commit cb512c8
Showing
14 changed files
with
540 additions
and
4,527 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,4 +5,5 @@ __pycache__ | |
**/dist/ | ||
**/env | ||
**/.coverage | ||
**/db | ||
**.db |
This file was deleted.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
import chromadb | ||
from chromadb.utils import embedding_functions | ||
from langchain.text_splitter import CharacterTextSplitter | ||
from langchain.docstore.document import Document | ||
import uuid | ||
import sherpa_ai.config as cfg | ||
|
||
class ChromaVectorStore: | ||
def __init__(self, db) -> None: | ||
self.db = db | ||
|
||
@classmethod | ||
def chroma_from_texts(cls, texts, embedding=None , meta_datas=None): | ||
openai_ef = embedding_functions.OpenAIEmbeddingFunction( | ||
model_name="text-embedding-ada-002" | ||
) | ||
embeded_data = openai_ef(texts) | ||
meta_datas = [] if meta_datas is None else meta_datas | ||
client =chromadb.PersistentClient(path="./db") | ||
db = client.get_or_create_collection(name=cfg.INDEX_NAME_FILE_STORAGE,embedding_function=openai_ef) | ||
db.add( | ||
embeddings = embeded_data, | ||
documents = texts, | ||
metadatas = meta_datas, | ||
ids = [str(uuid.uuid1()) for text in texts] | ||
) | ||
|
||
return cls(db) | ||
|
||
@classmethod | ||
def chroma_from_existing(cls): | ||
openai_ef = embedding_functions.OpenAIEmbeddingFunction( | ||
model_name="text-embedding-ada-002" | ||
) | ||
|
||
client =chromadb.PersistentClient(path="./db") | ||
db = client.get_or_create_collection(name=cfg.INDEX_NAME_FILE_STORAGE,embedding_function=openai_ef) | ||
|
||
return cls(db) | ||
|
||
@classmethod | ||
def file_text_splitter(cls,data , meta_data): | ||
text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0) | ||
texts = text_splitter.split_text(data) | ||
metadatas = [] | ||
temp_texts = [] | ||
for doc in texts: | ||
metadatas.append(meta_data) | ||
temp_texts.append(f"""'file_content': '{doc}' ,{meta_data}""") | ||
texts = temp_texts | ||
|
||
return {'texts':texts ,'meta_datas':metadatas} | ||
|
||
|
||
def similarity_search(self, query: str="", session_id: str = None): | ||
filter = {} if session_id is None else {"session_id": session_id} | ||
results = self.db.query( | ||
query_texts=[query], | ||
n_results=2, | ||
where=filter, | ||
include=['documents','metadatas'], | ||
) | ||
documents = [] | ||
if results is not None: | ||
for i in range(0,len(results['documents'][0])): | ||
documents.append(Document(metadata=results['metadatas'][0][i],page_content=results['documents'][0][i])) | ||
return documents | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,131 @@ | ||
import uuid | ||
|
||
import chromadb | ||
from chromadb.utils import embedding_functions | ||
from langchain.docstore.document import Document | ||
from langchain.text_splitter import CharacterTextSplitter | ||
|
||
import sherpa_ai.config as cfg | ||
|
||
|
||
class ChromaVectorStore: | ||
""" | ||
A class used to represent a Chroma Vector Store. | ||
This class provides methods to create a Chroma Vector Store from texts or from an existing store, | ||
split file text, and perform a similarity search. | ||
... | ||
Attributes | ||
---------- | ||
db : chromadb.PersistentClient | ||
a persistent client to interact with the ChromaDB | ||
Methods | ||
------- | ||
chroma_from_texts(texts, embedding, meta_datas) | ||
Class method to create a Chroma Vector Store from given texts. | ||
chroma_from_existing(embedding) | ||
Class method to create a Chroma Vector Store from an existing store. | ||
file_text_splitter(data, meta_data) | ||
Class method to split file text into chunks. | ||
similarity_search(query, session_id) | ||
Method to perform a similarity search in the Chroma Vector Store. | ||
""" | ||
|
||
def __init__(self, db, path="./db") -> None: | ||
self.db = db | ||
self.path = path | ||
|
||
@classmethod | ||
def chroma_from_texts( | ||
cls, | ||
texts, | ||
embedding=None, | ||
meta_datas=None, | ||
path="./db", | ||
): | ||
# Use OpenAIEmbeddingFunction as default embedding function, this cannot be in the | ||
# method signature for mocking purposes | ||
if embedding is None: | ||
embedding = embedding_functions.OpenAIEmbeddingFunction( | ||
model_name="text-embedding-ada-002" | ||
) | ||
|
||
embeded_data = embedding(texts) | ||
meta_datas = [] if meta_datas is None else meta_datas | ||
client = chromadb.PersistentClient(path=path) | ||
db = client.get_or_create_collection( | ||
name=cfg.INDEX_NAME_FILE_STORAGE, embedding_function=embedding | ||
) | ||
db.add( | ||
embeddings=embeded_data, | ||
documents=texts, | ||
metadatas=meta_datas, | ||
ids=[str(uuid.uuid1()) for _ in texts], | ||
) | ||
|
||
return cls(db, path) | ||
|
||
@classmethod | ||
def chroma_from_existing( | ||
cls, | ||
embedding=None, | ||
path="./db", | ||
): | ||
# Use OpenAIEmbeddingFunction as default embedding function, this cannot be in the | ||
# method signature for mocking purposes | ||
if embedding is None: | ||
embedding = embedding_functions.OpenAIEmbeddingFunction( | ||
model_name="text-embedding-ada-002" | ||
) | ||
|
||
client = chromadb.PersistentClient(path=path) | ||
db = client.get_or_create_collection( | ||
name=cfg.INDEX_NAME_FILE_STORAGE, embedding_function=embedding | ||
) | ||
|
||
return cls(db) | ||
|
||
@classmethod | ||
def file_text_splitter( | ||
cls, | ||
data, | ||
meta_data, | ||
content_key="file_content", | ||
chunk_size=1000, | ||
chunk_overlap=0, | ||
): | ||
text_splitter = CharacterTextSplitter( | ||
chunk_size=chunk_size, chunk_overlap=chunk_overlap | ||
) | ||
texts = text_splitter.split_text(data) | ||
metadatas = [] | ||
temp_texts = [] | ||
for doc in texts: | ||
metadatas.append(meta_data) | ||
temp_texts.append(f"'{content_key}': '{doc}', {meta_data}") | ||
|
||
return {"texts": temp_texts, "meta_datas": metadatas} | ||
|
||
def similarity_search( | ||
self, query: str = "", session_id: str = None, number_of_results=2 | ||
): | ||
filter = {} if session_id is None else {"session_id": session_id} | ||
results = self.db.query( | ||
query_texts=[query], | ||
n_results=number_of_results, | ||
where=filter, | ||
include=["documents", "metadatas"], | ||
) | ||
documents = [] | ||
if results is not None: | ||
for i in range(0, len(results["documents"][0])): | ||
documents.append( | ||
Document( | ||
metadata=results["metadatas"][0][i], | ||
page_content=results["documents"][0][i], | ||
) | ||
) | ||
return documents |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,6 +8,7 @@ class EventType(Enum): | |
feedback = 4 | ||
action = 5 | ||
action_output = 6 | ||
user_input = 7 | ||
|
||
|
||
class Event: | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
from typing import List, Optional | ||
|
||
from langchain.embeddings.openai import OpenAIEmbeddings | ||
|
||
from sherpa_ai.actions.planning import Plan | ||
from sherpa_ai.agents import AgentPool | ||
from sherpa_ai.connectors.chroma_vector_store import ChromaVectorStore | ||
from sherpa_ai.events import Event, EventType | ||
from sherpa_ai.memory import SharedMemory | ||
from sherpa_ai.memory.belief import Belief | ||
|
||
|
||
class SharedMemoryWithVectorDB(SharedMemory): | ||
""" | ||
Custom implementation of SharedMemory that integrates with ChromaVectorStore. | ||
Use this class whenever context retrieval from a vector database is needed. | ||
Attributes: | ||
session_id (str): Unique identifier for the current session. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
objective: str, | ||
session_id: str, | ||
agent_pool: AgentPool = None, | ||
): | ||
self.objective = objective | ||
self.agent_pool = agent_pool | ||
self.events: List[Event] = [] | ||
self.plan: Optional[Plan] = None | ||
self.current_step = None | ||
self.session_id = session_id | ||
|
||
def observe(self, belief: Belief): | ||
vec_db = ChromaVectorStore.chroma_from_existing() | ||
|
||
tasks = super().get_by_type(EventType.task) | ||
|
||
task = tasks[-1] if len(tasks) > 0 else None | ||
|
||
# based on the current task search similarity on the context and add it as an | ||
# event type user_input which is going to be used as a context on the prompt | ||
contexts = vec_db.similarity_search(task.content, session_id=self.session_id) | ||
|
||
# Loop through the similarity search results, add the chunks as user_input events which will be added as a context in the belief class. | ||
for context in contexts: | ||
super().add( | ||
agent="", | ||
event_type=EventType.user_input, | ||
content=context.page_content, | ||
) | ||
|
||
belief.set_current_task(task) | ||
|
||
for event in self.events: | ||
if event.event_type in [ | ||
EventType.task, | ||
EventType.result, | ||
EventType.user_input, | ||
]: | ||
belief.update(event) |
1 change: 1 addition & 0 deletions
1
src/tests/data/test_qa_agent_with_vector_shared_memory_test_shared_memory_with_vector.jsonl
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
{"input": [{"text": "You are a **question answering assistant** who solves user questions and offers a detailed solution. Your name is QA Agent. Context: summerize the file rtgfqq 'file_content': 'Avocados are a fruit, not a vegetable. They're technically considered a single-seeded berry, believe it or not. The Eiffel Tower can be 15 cm taller during the summer, due to thermal expansion meaning the iron heats up, the particles gain kinetic energy and take up more space. Trypophobia is the fear of closely-packed holes. Or more specifically, 'an aversion to the sight of irregular patterns or clusters of small holes or bumps.' No crumpets for them, then.Allodoxaphobia is the fear of other people's opinions. It's a rare social phobia that's characterised by an irrational and overwhelming fear of what other people think. Australia is wider than the moon. The moon sits at 3400km in diameter, while Australia’s diameter from east to west is almost 4000km. 'Mellifluous' is a sound that is pleasingly smooth and musical to hear. The Spice Girls were originally a band called Touch. 'When we first started [with the name Touch], we were pretty bland,' Mel C told The Guardian in 2018. 'We felt like we had to fit into a mould.' Emma Bunton auditioned for the role of Bianca Butcher in Eastenders. Baby Spice already had a small part in the soap back in the 90s but tried out for a full-time role. She was pipped to the post by Patsy Palmer but ended up auditioning for the Spice Girls not long after. Human teeth are the only part of the body that cannot heal themselves. Teeth are coated in enamel which is not a living tissue. It's illegal to own just one guinea pig in Switzerland. It's considered animal abuse because they're social beings and get lonely. The Ancient Romans used to drop a piece of toast into their wine for good health - hence why we 'raise a toast'. The heart of a shrimp is located in its head. They also have an open circulatory system, which means they have no arteries and their organs float directly in blood. Amy Poehler was only seven years older than Rachel McAdams when she took on the role of 'cool mom' in Mean Girls. Rachel was 25 as Regina George - Amy was 32 as her mum. People are more creative in the shower. When we take a warm shower, we experience an increased dopamine flow that makes us more creative. Baby rabbits are called kits. Cute! my dog died on march 2021. The unicorn is the national animal of Scotland. It was apparently chosen because of its connection with dominance and chivalry as well as purity and innocence in Celtic mythology. The first aeroplane flew on December 17,1903 and . Wilbur and Orville Wright made four brief flights at Kitty Hawk, North Carolina, with their first powered aircraft, aka the first airplane. Venus is the only planet to spin clockwise. It travels around the sun once every 225 Earth days but it rotates clockwise once every 243 days. Nutmeg is a hallucinogen. The spice contains myristicin, a natural compound that has mind-altering effects if ingested in large doses. A 73-year-old bottle of French Burgundy became the most expensive bottle of wine ever sold at auction in 2018, going for $558,000 (approx £439,300). The bottle of 1945 Romanee-Conti sold at Sotheby for more than 17 times its original estimate of $32,000.' ,'session_id': '6', 'file_name': 'rtgfqq', 'file_type': 'pdf', 'title': 'NoMeaning', 'data_type': 'user_input' Action - Result History: Google Search'query': 'Summary of file rtgfqq' Answer: Rich Text Format (RTF) is a file Format that lets you exchange text files between different word processors in different operating systems (OSes). For example, you can create a file in Microsoft Word and then open it in another word processor, such as Apple Pages or Google Docs. Link:https://www.fileformat.info/format/rtf/egff.htm Google Search 'query': 'Summary of file rtgfqq' Answer: Rich Text Format (RTF) is a file Format that lets you exchange text files between different word processors in different operating systems (OSes). For example, you can create a file in Microsoft Word and then open it in another word processor, such as Apple Pages or Google Docs. Link:https://www.fileformat.info/format/rtf/egff.htm Google Search'query': 'Summary of file rtgfqq' Answer: Rich Text Format (RTF) is a file Format that lets you exchange text files between different word processors in different operating systems (OSes). For example, you can create a file in Microsoft Word and then open it in another word processor, such as Apple Pages or Google Docs. Link:https://www.fileformat.info/format/rtf/egff.htm Given the context and the action-result history, please complete the task mentioned. Include any links you used from the context and history in the result. Task: summerize the file rtgfqq Result:', 'agent': 'human'"}], "output": "Based on the file content provided, the file 'rtgfqq' contains various interesting facts and trivia on different topics such as avocados, the Eiffel Tower, phobias, geography, music, biology, pop culture, history, astronomy, and more. Some highlights include the fact that avocados are fruits, not vegetables, Australia is wider than the moon, human teeth cannot heal themselves, it's illegal to own just one guinea pig in Switzerland, and Venus is the only planet to spin clockwise. Additionally, the file mentions fun facts like the Spice Girls were originally a band called Touch, baby rabbits are called kits, and the unicorn is the national animal of Scotland. It also includes historical events such as the first airplane flight by the Wright brothers and the most expensive bottle of wine sold at auction. For more information about the file format RTF, you can refer to the following link: [Rich Text Format (RTF)](https://www.fileformat.info/format/rtf/egff.htm)", "llm_name": "gpt-3.5-turbo"} |
Oops, something went wrong.