Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Delete and re-index docs when embedding model changes #137

Merged
merged 5 commits into from
May 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions packages/jupyter-ai-magics/jupyter_ai_magics/magics.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,8 @@
from IPython import get_ipython
from IPython.core.magic import Magics, magics_class, line_cell_magic
from IPython.core.magic_arguments import magic_arguments, argument, parse_argstring

from IPython.display import HTML, JSON, Markdown, Math

from jupyter_ai_magics.utils import decompose_model_id, load_providers

from .providers import BaseProvider


Expand All @@ -37,6 +34,7 @@ def _repr_mimebundle_(self, include=None, exclude=None):
}
)


class TextWithMetadata(object):
def __init__(self, text, metadata):
self.text = text
Expand Down
1 change: 1 addition & 0 deletions packages/jupyter-ai-magics/jupyter_ai_magics/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
SagemakerEndpoint
)
from langchain.utils import get_from_dict_or_env
from langchain.llms.utils import enforce_stop_tokens

from pydantic import BaseModel, Extra, root_validator
from langchain.chat_models import ChatOpenAI
Expand Down
4 changes: 2 additions & 2 deletions packages/jupyter-ai-magics/jupyter_ai_magics/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Dict, Optional, Tuple, Union
from importlib_metadata import entry_points
from jupyter_ai_magics.aliases import MODEL_ID_ALIASES

from jupyter_ai_magics.embedding_providers import BaseEmbeddingsProvider

from jupyter_ai_magics.providers import BaseProvider
Expand All @@ -14,6 +15,7 @@ def load_providers(log: Optional[Logger] = None) -> Dict[str, BaseProvider]:
if not log:
log = logging.getLogger()
log.addHandler(logging.NullHandler())

providers = {}
eps = entry_points()
model_provider_eps = eps.select(group="jupyter_ai.model_providers")
Expand Down Expand Up @@ -47,8 +49,6 @@ def load_embedding_providers(log: Optional[Logger] = None) -> Dict[str, BaseEmbe

return providers



def decompose_model_id(model_id: str, providers: Dict[str, BaseProvider]) -> Tuple[str, str]:
"""Breaks down a model ID into a two-tuple (provider_id, local_model_id). Returns (None, None) if indeterminate."""
if model_id in MODEL_ID_ALIASES:
Expand Down
17 changes: 13 additions & 4 deletions packages/jupyter-ai/jupyter_ai/actors/ask.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,19 @@ def _process_message(self, message: HumanChatMessage):

self.get_llm_chain()

result = self.llm_chain({"question": query, "chat_history": self.chat_history})
response = result['answer']
self.chat_history.append((query, response))
self.reply(response, message)
try:
result = self.llm_chain({"question": query, "chat_history": self.chat_history})
response = result['answer']
self.chat_history.append((query, response))
self.reply(response, message)
except AssertionError as e:
self.log.error(e)
response = """Sorry, an error occurred while reading from the learned documents.
If you have changed the embedding provider, try deleting the existing index by running
the `/learn -d` command, then re-running `/learn <directory>` to learn the documents,
and then asking the question again.
"""
self.reply(response, message)


class Retriever(BaseRetriever):
Expand Down
11 changes: 7 additions & 4 deletions packages/jupyter-ai/jupyter_ai/actors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from uuid import uuid4
import time
import logging
from typing import Dict, Type, Union
from typing import Dict, Optional, Type, Union
import traceback

from jupyter_ai_magics.providers import BaseProvider
Expand Down Expand Up @@ -48,6 +48,7 @@ def __init__(
self.llm_chain = None
self.embeddings = None
self.embeddings_params = None
self.embedding_model_id = None

def process_message(self, message: HumanChatMessage):
"""Processes the message passed by the `Router`"""
Expand All @@ -62,12 +63,12 @@ def _process_message(self, message: HumanChatMessage):
"""Processes the message passed by the `Router`"""
raise NotImplementedError("Should be implemented by subclasses.")

def reply(self, response, message: Optional[HumanChatMessage] = None):
    """Post `response` to the chat as an agent message.

    `message` is the human message being replied to; when omitted (e.g. for
    unsolicited status updates such as re-indexing notices), `reply_to` is
    left empty.
    """
    m = AgentChatMessage(
        id=uuid4().hex,
        time=time.time(),
        body=response,
        reply_to=message.id if message else ""
    )
    self.reply_queue.put(m)

def get_embeddings(self):
    """Return an embeddings client for the currently configured provider.

    Returns None when no embeddings provider is configured. The client is
    rebuilt only when the configured embedding model id differs from the one
    the cached client was built for.
    """
    actor = ray.get_actor(ACTOR_TYPE.EMBEDDINGS_PROVIDER)
    provider = ray.get(actor.get_provider.remote())
    embedding_params = ray.get(actor.get_provider_params.remote())
    embedding_model_id = ray.get(actor.get_model_id.remote())

    if not provider:
        return None

    if embedding_model_id != self.embedding_model_id:
        self.embeddings = provider(**embedding_params)
        # BUG fix: record which model the cached client was built for;
        # without this assignment the comparison above stays true forever
        # and the client is re-instantiated on every call.
        self.embedding_model_id = embedding_model_id

    return self.embeddings

def create_llm_chain(self, provider: Type[BaseProvider], provider_params: Dict[str, str]):
Expand Down
10 changes: 4 additions & 6 deletions packages/jupyter-ai/jupyter_ai/actors/config.py
Original file line number Diff line number Diff line change
def update(self, config: GlobalConfig, save_to_disk: bool = True):
    """Propagate a new global config to the provider actors and store it.

    Pushes `config` to the chat and embeddings provider actors, optionally
    persists it to disk, and only then records it as the current config.
    """
    self._update_chat_provider(config)
    self._update_embeddings_provider(config)
    if save_to_disk:
        self._save(config)
    self.config = config

def _update_chat_provider(self, config: GlobalConfig):
    """Push `config` to the chat provider actor and block until applied."""
    actor = ray.get_actor(ACTOR_TYPE.CHAT_PROVIDER)
    ray.get(actor.update.remote(config))

def _update_embeddings_provider(self, config: GlobalConfig):
    """Push `config` to the embeddings provider actor and block until applied."""
    actor = ray.get_actor(ACTOR_TYPE.EMBEDDINGS_PROVIDER)
    ray.get(actor.update.remote(config))

def _save(self, config: GlobalConfig):
    """Serialize `config` to JSON at `self.save_path`.

    Creates `self.save_dir` first if it does not exist.
    """
    # BUG fix: the original `if not os.path.exists:` tested the function
    # object itself (always truthy), so the directory was never created.
    if not os.path.exists(self.save_dir):
        os.makedirs(self.save_dir)

    with open(self.save_path, 'w') as f:
        # config is a pydantic model; .json() serializes it faithfully,
        # unlike json.dumps() on the model object.
        f.write(config.json())

def _load(self):
if os.path.exists(self.save_path):
Expand Down
13 changes: 12 additions & 1 deletion packages/jupyter-ai/jupyter_ai/actors/embeddings_provider.py
Original file line number Diff line number Diff line change
def __init__(self, log: Logger):
    self.log = log
    # Embeddings provider class, its constructor kwargs, and the id of the
    # currently configured embedding model; all None until update() runs.
    self.provider = None
    self.provider_params = None
    self.model_id = None

def update(self, config: GlobalConfig):
model_id = config.embeddings_provider_id
Expand All @@ -33,9 +34,19 @@ def update(self, config: GlobalConfig):

self.provider = provider.provider_klass
self.provider_params = provider_params
previous_model_id = self.model_id
self.model_id = model_id

if previous_model_id and previous_model_id != model_id:
# delete the index
actor = ray.get_actor(ACTOR_TYPE.LEARN)
actor.delete_and_relearn.remote()

def get_provider(self):
    """Return the configured embeddings provider class, or None if unset."""
    return self.provider

def get_provider_params(self):
    """Return the kwargs used to instantiate the provider, or None if unset."""
    # Scrape defect fixed: the diff view duplicated this return line.
    return self.provider_params

def get_model_id(self):
    """Return the id of the currently configured embedding model, or None."""
    return self.model_id
109 changes: 91 additions & 18 deletions packages/jupyter-ai/jupyter_ai/actors/learn.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import json
import os
import argparse
import time
from typing import List

import ray
Expand All @@ -14,30 +16,34 @@
)
from langchain.schema import Document

from jupyter_ai.models import HumanChatMessage
from jupyter_ai.models import HumanChatMessage, IndexedDir, IndexMetadata
from jupyter_ai.actors.base import BaseActor, Logger
from jupyter_ai.document_loaders.directory import RayRecursiveDirectoryLoader
from jupyter_ai.document_loaders.splitter import ExtensionSplitter, NotebookSplitter


INDEX_SAVE_DIR = os.path.join(jupyter_data_dir(), 'jupyter_ai', 'indices')
METADATA_SAVE_PATH = os.path.join(INDEX_SAVE_DIR, 'metadata.json')

@ray.remote
class LearnActor(BaseActor):

def __init__(self, reply_queue: Queue, log: Logger, root_dir: str):
    """Set up the /learn command: argument parser, index state, and storage.

    Ensures INDEX_SAVE_DIR exists, then loads a persisted FAISS index if one
    is available (or creates a fresh one).
    """
    super().__init__(reply_queue=reply_queue, log=log)
    self.root_dir = root_dir
    self.chunk_size = 2000
    self.chunk_overlap = 100
    self.parser.prog = '/learn'
    self.parser.add_argument('-v', '--verbose', action='store_true')
    self.parser.add_argument('-d', '--delete', action='store_true')
    self.parser.add_argument('-l', '--list', action='store_true')
    self.parser.add_argument('path', nargs=argparse.REMAINDER)
    self.index_name = 'default'
    self.index = None
    self.metadata = IndexMetadata(dirs=[])

    if not os.path.exists(INDEX_SAVE_DIR):
        os.makedirs(INDEX_SAVE_DIR)

    self.load_or_create()

Expand All @@ -57,6 +63,10 @@ def _process_message(self, message: HumanChatMessage):
self.delete()
self.reply(f"👍 I have deleted everything I previously learned.", message)
return

if args.list:
self.reply(self._build_list_response())
return

# Make sure the path exists.
if not len(args.path) == 1:
Expand All @@ -72,6 +82,24 @@ def _process_message(self, message: HumanChatMessage):
if args.verbose:
self.reply(f"Loading and splitting files for {load_path}", message)

self.learn_dir(load_path)
self.save()

response = f"""🎉 I have learned documents at **{load_path}** and I am ready to answer questions about them.
You can ask questions about these docs by prefixing your message with **/ask**."""
self.reply(response, message)

def _build_list_response(self):
    """Build the reply for `/learn -l`: a bulleted list of learned dirs."""
    if not self.metadata.dirs:
        return "There are no docs that have been learned yet."

    dirs = [dir.path for dir in self.metadata.dirs]
    dir_list = "\n- " + "\n- ".join(dirs) + "\n\n"
    message = f"""I can answer questions from docs in these directories:
{dir_list}"""
    return message

def learn_dir(self, path: str):
splitters={
'.py': PythonCodeTextSplitter(chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap),
'.md': MarkdownTextSplitter(chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap),
Expand All @@ -83,26 +111,56 @@ def _process_message(self, message: HumanChatMessage):
default_splitter=RecursiveCharacterTextSplitter(chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap)
)

loader = RayRecursiveDirectoryLoader(load_path)
texts = loader.load_and_split(text_splitter=splitter)
loader = RayRecursiveDirectoryLoader(path)
texts = loader.load_and_split(text_splitter=splitter)
self.index.add_documents(texts)
self.save()

response = f"""🎉 I have indexed documents at **{load_path}** and I am ready to answer questions about them.
You can ask questions about these docs by prefixing your message with **/ask**."""
self.reply(response, message)

def get_index(self):
return self.index
self._add_dir_to_metadata(path)

def _add_dir_to_metadata(self, path: str):
    """Record `path` in the index metadata if it is not already listed."""
    dirs = self.metadata.dirs
    index = next((i for i, dir in enumerate(dirs) if dir.path == path), None)
    # BUG fix: the original `if not index:` treated a match at position 0
    # as "not found" (not 0 is True) and appended a duplicate entry.
    if index is None:
        dirs.append(IndexedDir(path=path))
    self.metadata.dirs = dirs

def delete_and_relearn(self):
    """Drop the existing index and re-learn all previously indexed dirs.

    Called when the embedding model changes: vectors produced by the old
    model are incompatible with the new one, so everything must be
    re-embedded. When nothing was learned yet, just reset the index.
    """
    if not self.metadata.dirs:
        self.delete()
        return

    # NOTE(review): the string's internal indentation was lost in the diff
    # scrape; reconstructed with continuation-line indentation — confirm
    # against the merged source.
    message = """🔔 Hi there, It seems like you have updated the embeddings model. For the **/ask**
    command to work with the new model, I have to re-learn the documents you had previously
    submitted for learning. Please wait to use the **/ask** command until I am done with this task."""
    self.reply(message)

    # Capture the directory list before delete() resets the metadata.
    metadata = self.metadata
    self.delete()
    self.relearn(metadata)

def delete(self):
    """Delete the in-memory index, its on-disk files, and the metadata,
    then create a fresh empty index."""
    self.index = None
    self.metadata = IndexMetadata(dirs=[])
    paths = [os.path.join(INDEX_SAVE_DIR, self.index_name + ext) for ext in ['.pkl', '.faiss']]
    for path in paths:
        if os.path.isfile(path):
            os.remove(path)
    self.create()

def relearn(self, metadata: IndexMetadata):
    """Re-index every directory recorded in `metadata`, then announce completion.

    No-op when `metadata` lists no directories.
    """
    # Index all dirs in the metadata
    if not metadata.dirs:
        return

    for dir in metadata.dirs:
        self.learn_dir(dir.path)

    self.save()

    dir_list = "\n- " + "\n- ".join([dir.path for dir in self.metadata.dirs]) + "\n\n"
    message = f"""🎉 I am done learning docs in these directories:
{dir_list} I am ready to answer questions about them.
You can ask questions about these docs by prefixing your message with **/ask**."""
    self.reply(message)

def create(self):
embeddings = self.get_embeddings()
if not embeddings:
Expand All @@ -112,18 +170,33 @@ def create(self):

def save(self):
    """Persist the FAISS index (if any) and the index metadata to disk."""
    if self.index is not None:
        self.index.save_local(INDEX_SAVE_DIR, index_name=self.index_name)

    self.save_metadata()

def save_metadata(self):
    """Write the index metadata as JSON to METADATA_SAVE_PATH."""
    with open(METADATA_SAVE_PATH, 'w') as f:
        # metadata is a pydantic model; .json() serializes it.
        f.write(self.metadata.json())

def load_or_create(self):
    """Load the persisted FAISS index, or create a new one if loading fails.

    No-op when embeddings are not configured yet.
    """
    embeddings = self.get_embeddings()
    if not embeddings:
        return
    if self.index is None:
        try:
            self.index = FAISS.load_local(INDEX_SAVE_DIR, embeddings, index_name=self.index_name)
            self.load_metadata()
        except Exception as e:
            # Loading can fail legitimately (no index yet, or one saved with
            # an incompatible embedding model) — start fresh in that case.
            self.create()

def load_metadata(self):
    """Load index metadata from METADATA_SAVE_PATH, if the file exists."""
    if not os.path.exists(METADATA_SAVE_PATH):
        return

    with open(METADATA_SAVE_PATH, 'r', encoding='utf-8') as f:
        j = json.loads(f.read())
    self.metadata = IndexMetadata(**j)

def get_relevant_documents(self, question: str) -> List[Document]:
if self.index:
docs = self.index.similarity_search(question)
Expand Down
Loading