From 767efc0145ab771a6c7733c4f2417698811cf0dc Mon Sep 17 00:00:00 2001
From: david qiu
Date: Fri, 5 May 2023 07:46:23 -0700
Subject: [PATCH] Runtime model configurability (#146)

* Refactored provider load, decompose logic, added model provider list api
* Renamed model
* Sorted the provider names
* WIP: Embedding providers
* Added embeddings provider api
* Added missing import
* Moved providers to ray actor, added config actor
* Ability to load llm and embeddings from config
* Moved llm creation to specific actors
* Added apis for fetching, updating config. Fixed config update, error handling
* Updated as per PR feedback
* Fixes issue with cohere embeddings, api keys not working
* Added an error check when embedding change causes read error
* Delete and re-index docs when embedding model changes (#137)
* Added an error check when embedding change causes read error
* Refactored provider load, decompose logic, added model provider list api
* Re-indexes dirs when embeddings change, learn list command
* Fixed typo, simplified adding metadata
* Moved index dir, metadata path to constants
* Chat settings UI (#141)
* remove unused div
* automatically create config if not present
* allow all-caps envvars in config
* implement basic chat settings UI
* hide API key text inputs
* limit popup size, show success banner
* show welcome message if no LM is selected
* fix buggy UI with no selected LM/EM
* exclude legacy OpenAI chat provider used in magics
* Added a button with welcome message

---------

Co-authored-by: Jain

* Various chat chain enhancements and fixes (#144)
* fix /clear command
* use model IDs to compare LLMs instead
* specify stop sequence in chat chain
* add empty AI message, improve system prompt
* add RTD configuration

---------

Co-authored-by: Piyush Jain
Co-authored-by: Jain
---
 .readthedocs.yaml | 17 ++
 .../jupyter_ai_magics/__init__.py | 7 +
 .../jupyter_ai_magics/aliases.py | 6 +
 .../jupyter_ai_magics/embedding_providers.py | 75 +++++
 .../jupyter_ai_magics/magics.py | 40 +--
 .../jupyter_ai_magics/providers.py | 3 +
 .../jupyter_ai_magics/utils.py | 70 +++++
 packages/jupyter-ai-magics/pyproject.toml | 6 +
 packages/jupyter-ai/.gitignore | 3 -
 packages/jupyter-ai/jupyter_ai/actors/ask.py | 62 ++--
 packages/jupyter-ai/jupyter_ai/actors/base.py | 58 +++-
 .../jupyter_ai/actors/chat_provider.py | 40 +++
 .../jupyter-ai/jupyter_ai/actors/config.py | 61 ++++
 .../jupyter-ai/jupyter_ai/actors/default.py | 73 +++--
 .../jupyter_ai/actors/embeddings_provider.py | 52 ++++
 .../jupyter-ai/jupyter_ai/actors/generate.py | 23 +-
 .../jupyter-ai/jupyter_ai/actors/learn.py | 143 ++++++++--
 .../jupyter-ai/jupyter_ai/actors/providers.py | 42 +++
 .../jupyter-ai/jupyter_ai/actors/router.py | 5 +-
 packages/jupyter-ai/jupyter_ai/extension.py | 70 +++--
 packages/jupyter-ai/jupyter_ai/handlers.py | 122 +++++++-
 packages/jupyter-ai/jupyter_ai/models.py | 27 ++
 .../src/components/chat-settings.tsx | 266 ++++++++++++++++++
 packages/jupyter-ai/src/components/chat.tsx | 125 ++++++--
 packages/jupyter-ai/src/components/select.tsx | 48 ++++
 packages/jupyter-ai/src/handler.ts | 63 ++++-
 26 files changed, 1323 insertions(+), 184 deletions(-)
 create mode 100644 .readthedocs.yaml
 create mode 100644 packages/jupyter-ai-magics/jupyter_ai_magics/aliases.py
 create mode 100644 packages/jupyter-ai-magics/jupyter_ai_magics/embedding_providers.py
 create mode 100644 packages/jupyter-ai-magics/jupyter_ai_magics/utils.py
 create mode 100644 packages/jupyter-ai/jupyter_ai/actors/chat_provider.py
 create mode 100644
packages/jupyter-ai/jupyter_ai/actors/config.py create mode 100644 packages/jupyter-ai/jupyter_ai/actors/embeddings_provider.py create mode 100644 packages/jupyter-ai/jupyter_ai/actors/providers.py create mode 100644 packages/jupyter-ai/src/components/chat-settings.tsx create mode 100644 packages/jupyter-ai/src/components/select.tsx diff --git a/.readthedocs.yaml b/.readthedocs.yaml new file mode 100644 index 000000000..8654753a1 --- /dev/null +++ b/.readthedocs.yaml @@ -0,0 +1,17 @@ +# .readthedocs.yaml +# Read the Docs configuration file +# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details + +version: 2 + +build: + os: ubuntu-22.04 + tools: + python: "3.11" + +sphinx: + configuration: docs/source/conf.py + +python: + install: + - requirements: docs/requirements.txt diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/__init__.py b/packages/jupyter-ai-magics/jupyter_ai_magics/__init__.py index d7f5de232..ff35147db 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/__init__.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/__init__.py @@ -10,8 +10,15 @@ HfHubProvider, OpenAIProvider, ChatOpenAIProvider, + ChatOpenAINewProvider, SmEndpointProvider ) +# expose embedding model providers on the package root +from .embedding_providers import ( + OpenAIEmbeddingsProvider, + CohereEmbeddingsProvider, + HfHubEmbeddingsProvider +) from .providers import BaseProvider def load_ipython_extension(ipython): diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/aliases.py b/packages/jupyter-ai-magics/jupyter_ai_magics/aliases.py new file mode 100644 index 000000000..74be10485 --- /dev/null +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/aliases.py @@ -0,0 +1,6 @@ +MODEL_ID_ALIASES = { + "gpt2": "huggingface_hub:gpt2", + "gpt3": "openai:text-davinci-003", + "chatgpt": "openai-chat:gpt-3.5-turbo", + "gpt4": "openai-chat:gpt-4", +} \ No newline at end of file diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/embedding_providers.py b/packages/jupyter-ai-magics/jupyter_ai_magics/embedding_providers.py new file mode 100644 index 000000000..cfc8481bc --- /dev/null +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/embedding_providers.py @@ -0,0 +1,75 @@ +from typing import ClassVar, List, Type +from jupyter_ai_magics.providers import AuthStrategy, EnvAuthStrategy +from pydantic import BaseModel, Extra +from langchain.embeddings import OpenAIEmbeddings, CohereEmbeddings, HuggingFaceHubEmbeddings +from langchain.embeddings.base import Embeddings + + +class BaseEmbeddingsProvider(BaseModel): + """Base class for embedding providers""" + + class Config: + extra = Extra.allow + + id: ClassVar[str] = ... + """ID for this provider class.""" + + name: ClassVar[str] = ... + """User-facing name of this provider.""" + + models: ClassVar[List[str]] = ... + """List of supported models by their IDs. For registry providers, this will + be just ["*"].""" + + model_id_key: ClassVar[str] = ... + """Kwarg expected by the upstream LangChain provider.""" + + pypi_package_deps: ClassVar[List[str]] = [] + """List of PyPi package dependencies.""" + + auth_strategy: ClassVar[AuthStrategy] = None + """Authentication/authorization strategy. Declares what credentials are + required to use this model provider. 
Generally should not be `None`.""" + + model_id: str + + provider_klass: ClassVar[Type[Embeddings]] + + +class OpenAIEmbeddingsProvider(BaseEmbeddingsProvider): + id = "openai" + name = "OpenAI" + models = [ + "text-embedding-ada-002" + ] + model_id_key = "model" + pypi_package_deps = ["openai"] + auth_strategy = EnvAuthStrategy(name="OPENAI_API_KEY") + provider_klass = OpenAIEmbeddings + + +class CohereEmbeddingsProvider(BaseEmbeddingsProvider): + id = "cohere" + name = "Cohere" + models = [ + 'large', + 'multilingual-22-12', + 'small' + ] + model_id_key = "model" + pypi_package_deps = ["cohere"] + auth_strategy = EnvAuthStrategy(name="COHERE_API_KEY") + provider_klass = CohereEmbeddings + + +class HfHubEmbeddingsProvider(BaseEmbeddingsProvider): + id = "huggingface_hub" + name = "HuggingFace Hub" + models = ["*"] + model_id_key = "repo_id" + # ipywidgets needed to suppress tqdm warning + # https://stackoverflow.com/questions/67998191 + # tqdm is a dependency of huggingface_hub + pypi_package_deps = ["huggingface_hub", "ipywidgets"] + auth_strategy = EnvAuthStrategy(name="HUGGINGFACEHUB_API_TOKEN") + provider_klass = HuggingFaceHubEmbeddings diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py b/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py index f79a70973..927766ba9 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py @@ -2,16 +2,14 @@ import json import os import re -import traceback import warnings from typing import Optional -from importlib_metadata import entry_points from IPython import get_ipython from IPython.core.magic import Magics, magics_class, line_cell_magic from IPython.core.magic_arguments import magic_arguments, argument, parse_argstring -from IPython.display import HTML, Image, JSON, Markdown, Math - +from IPython.display import HTML, JSON, Markdown, Math +from jupyter_ai_magics.utils import decompose_model_id, load_providers from .providers import BaseProvider @@ -36,8 +34,8 @@ def _repr_mimebundle_(self, include=None, exclude=None): } ) -class TextWithMetadata: +class TextWithMetadata(object): def __init__(self, text, metadata): self.text = text self.metadata = metadata @@ -109,18 +107,7 @@ def __init__(self, shell): "no longer supported. Instead, please use: " "`from langchain.chat_models import ChatOpenAI`") - # load model providers from entry point - self.providers = {} - eps = entry_points() - model_provider_eps = eps.select(group="jupyter_ai.model_providers") - for model_provider_ep in model_provider_eps: - try: - Provider = model_provider_ep.load() - except: - print(f"Unable to load entry point {model_provider_ep.name}"); - traceback.print_exc() - continue - self.providers[Provider.id] = Provider + self.providers = load_providers() def _ai_help_command_markdown(self): table = ("| Command | Description |\n" @@ -272,24 +259,7 @@ def _append_exchange_openai(self, prompt: str, output: str): }) def _decompose_model_id(self, model_id: str): - """Breaks down a model ID into a two-tuple (provider_id, local_model_id). Returns (None, None) if indeterminate.""" - if model_id in MODEL_ID_ALIASES: - model_id = MODEL_ID_ALIASES[model_id] - - if ":" not in model_id: - # case: model ID was not provided with a prefix indicating the provider - # ID. try to infer the provider ID before returning (None, None). - - # naively search through the dictionary and return the first provider - # that provides a model of the same ID. 
- for provider_id, Provider in self.providers.items(): - if model_id in Provider.models: - return (provider_id, model_id) - - return (None, None) - - provider_id, local_model_id = model_id.split(":", 1) - return (provider_id, local_model_id) + return decompose_model_id(model_id, self.providers) def _get_provider(self, provider_id: Optional[str]) -> BaseProvider: """Returns the model provider ID and class for a model ID. Returns None if indeterminate.""" diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py b/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py index 852c6ce7f..22507acbc 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py @@ -15,6 +15,7 @@ SagemakerEndpoint ) from langchain.utils import get_from_dict_or_env +from langchain.llms.utils import enforce_stop_tokens from pydantic import BaseModel, Extra, root_validator from langchain.chat_models import ChatOpenAI @@ -298,3 +299,5 @@ class SmEndpointProvider(BaseProvider, SagemakerEndpoint): model_id_key = "endpoint_name" pypi_package_deps = ["boto3"] auth_strategy = AwsAuthStrategy() + + diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/utils.py b/packages/jupyter-ai-magics/jupyter_ai_magics/utils.py new file mode 100644 index 000000000..aab722240 --- /dev/null +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/utils.py @@ -0,0 +1,70 @@ +import logging +from typing import Dict, Optional, Tuple, Union +from importlib_metadata import entry_points +from jupyter_ai_magics.aliases import MODEL_ID_ALIASES + +from jupyter_ai_magics.embedding_providers import BaseEmbeddingsProvider + +from jupyter_ai_magics.providers import BaseProvider + + +Logger = Union[logging.Logger, logging.LoggerAdapter] + + +def load_providers(log: Optional[Logger] = None) -> Dict[str, BaseProvider]: + if not log: + log = logging.getLogger() + log.addHandler(logging.NullHandler()) + + providers = {} + eps = entry_points() + model_provider_eps = eps.select(group="jupyter_ai.model_providers") + for model_provider_ep in model_provider_eps: + try: + provider = model_provider_ep.load() + except Exception: + log.error(f"Unable to load model provider class from entry point `{model_provider_ep.name}`.") + continue + providers[provider.id] = provider + log.info(f"Registered model provider `{provider.id}`.") + + return providers + + +def load_embedding_providers(log: Optional[Logger] = None) -> Dict[str, BaseEmbeddingsProvider]: + if not log: + log = logging.getLogger() + log.addHandler(logging.NullHandler()) + providers = {} + eps = entry_points() + model_provider_eps = eps.select(group="jupyter_ai.embeddings_model_providers") + for model_provider_ep in model_provider_eps: + try: + provider = model_provider_ep.load() + except Exception: + log.error(f"Unable to load embeddings model provider class from entry point `{model_provider_ep.name}`.") + continue + providers[provider.id] = provider + log.info(f"Registered embeddings model provider `{provider.id}`.") + + return providers + +def decompose_model_id(model_id: str, providers: Dict[str, BaseProvider]) -> Tuple[str, str]: + """Breaks down a model ID into a two-tuple (provider_id, local_model_id). Returns (None, None) if indeterminate.""" + if model_id in MODEL_ID_ALIASES: + model_id = MODEL_ID_ALIASES[model_id] + + if ":" not in model_id: + # case: model ID was not provided with a prefix indicating the provider + # ID. try to infer the provider ID before returning (None, None). + + # naively search through the dictionary and return the first provider + # that provides a model of the same ID. + for provider_id, provider in providers.items(): + if model_id in provider.models: + return (provider_id, model_id) + + return (None, None) + + provider_id, local_model_id = model_id.split(":", 1) + return (provider_id, local_model_id)
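As a quick illustration of the resolution rules in `decompose_model_id` (the model IDs below, other than the aliases, are illustrative): aliases are expanded first, explicit `provider:model` prefixes are split on the first colon, and un-prefixed IDs fall back to a linear search over the registered providers.

```python
from jupyter_ai_magics.utils import decompose_model_id

# "chatgpt" is an alias for "openai-chat:gpt-3.5-turbo", so it resolves
# without consulting the provider registry at all.
assert decompose_model_id("chatgpt", {}) == ("openai-chat", "gpt-3.5-turbo")

# Explicitly prefixed IDs split on the first ":" only.
assert decompose_model_id("openai:text-davinci-003", {}) == ("openai", "text-davinci-003")

# Un-prefixed, un-aliased IDs are searched for among registered providers;
# with no providers registered, the ID is indeterminate.
assert decompose_model_id("my-local-model", {}) == (None, None)
```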
diff --git a/packages/jupyter-ai-magics/pyproject.toml b/packages/jupyter-ai-magics/pyproject.toml index 66a40e3a3..a95720ecc 100644 --- a/packages/jupyter-ai-magics/pyproject.toml +++ b/packages/jupyter-ai-magics/pyproject.toml @@ -54,8 +54,14 @@ cohere = "jupyter_ai_magics:CohereProvider" huggingface_hub = "jupyter_ai_magics:HfHubProvider" openai = "jupyter_ai_magics:OpenAIProvider" openai-chat = "jupyter_ai_magics:ChatOpenAIProvider" +openai-chat-new = "jupyter_ai_magics:ChatOpenAINewProvider" sagemaker-endpoint = "jupyter_ai_magics:SmEndpointProvider" +[project.entry-points."jupyter_ai.embeddings_model_providers"] +cohere = "jupyter_ai_magics:CohereEmbeddingsProvider" +huggingface_hub = "jupyter_ai_magics:HfHubEmbeddingsProvider" +openai = "jupyter_ai_magics:OpenAIEmbeddingsProvider" + [tool.hatch.version] source = "nodejs" diff --git a/packages/jupyter-ai/.gitignore b/packages/jupyter-ai/.gitignore index 7fa065974..56891ff87 100644 --- a/packages/jupyter-ai/.gitignore +++ b/packages/jupyter-ai/.gitignore @@ -119,9 +119,6 @@ dmypy.json # OSX files .DS_Store -# local config storing authn credentials -config.py - # vscode .vscode diff --git a/packages/jupyter-ai/jupyter_ai/actors/ask.py b/packages/jupyter-ai/jupyter_ai/actors/ask.py index 5676154f0..e78837ca4 100644 --- a/packages/jupyter-ai/jupyter_ai/actors/ask.py +++ b/packages/jupyter-ai/jupyter_ai/actors/ask.py @@ -1,10 +1,12 @@ import argparse +from typing import Dict, List, Type +from jupyter_ai_magics.providers import BaseProvider import ray from ray.util.queue import Queue -from langchain import OpenAI from langchain.chains import ConversationalRetrievalChain +from langchain.schema import BaseRetriever, Document from jupyter_ai.models import HumanChatMessage from jupyter_ai.actors.base import ACTOR_TYPE, BaseActor, Logger @@ -21,21 +23,18 @@ class AskActor(BaseActor): def __init__(self, reply_queue: Queue, log: Logger): super().__init__(reply_queue=reply_queue, log=log) - index_actor = ray.get_actor(ACTOR_TYPE.LEARN.value) - handle = index_actor.get_index.remote() - vectorstore = ray.get(handle) - if not vectorstore: - return - - self.chat_history = [] - self.chat_provider = ConversationalRetrievalChain.from_llm( - OpenAI(temperature=0, verbose=True), - vectorstore.as_retriever() - ) self.parser.prog = '/ask' self.parser.add_argument('query', nargs=argparse.REMAINDER) + def create_llm_chain(self, provider: Type[BaseProvider], provider_params: Dict[str, str]): + retriever = Retriever() + self.llm = provider(**provider_params) + self.chat_history = [] + self.llm_chain = ConversationalRetrievalChain.from_llm( + self.llm, + retriever + ) def _process_message(self, message: HumanChatMessage): args = self.parse_args(message) if args is None: return query = ' '.join(args.query) if not query: self.reply(f"{self.parser.format_usage()}", message) return + self.get_llm_chain() + + try: + result = self.llm_chain({"question": query, "chat_history": self.chat_history}) + response = result['answer'] + self.chat_history.append((query, response)) + self.reply(response, message) + except AssertionError as e: + self.log.error(e) + response = """Sorry, an error occurred while reading from the learned documents. + If you have changed the embedding provider, try deleting the existing index by running + the `/learn -d` command, then re-submitting your documents with `/learn`, and then + asking the question again. + """ + self.reply(response, message) + + +class Retriever(BaseRetriever): + """Wrapper retriever class that fetches relevant docs + from the vector store. This indirection is important because + the index is de-serialized inconsistently when it is + accessed directly from the ask actor. + """ + + def get_relevant_documents(self, question: str): index_actor = ray.get_actor(ACTOR_TYPE.LEARN.value) - handle = index_actor.get_index.remote() - vectorstore = ray.get(handle) - # Have to reference the latest index - self.chat_provider.retriever = vectorstore.as_retriever() - - result = self.chat_provider({"question": query, "chat_history": self.chat_history}) - response = result['answer'] - self.chat_history.append((query, response)) - self.reply(response, message) + docs = ray.get(index_actor.get_relevant_documents.remote(question)) + return docs + + async def aget_relevant_documents(self, query: str) -> List[Document]: + return await super().aget_relevant_documents(query) \ No newline at end of file
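The new `Retriever` indirection exists because the FAISS index deserializes inconsistently when pulled into the ask actor directly, so each query round-trips to the `learn` actor that owns the index. A rough interactive sketch of that round trip, assuming a running Ray instance in which the `learn` actor is already registered:

```python
import ray
from jupyter_ai.actors.ask import Retriever

ray.init(address="auto", ignore_reinit_error=True)

retriever = Retriever()
# Internally: ray.get_actor("learn").get_relevant_documents.remote(question)
docs = retriever.get_relevant_documents("What is in my notebooks directory?")
for doc in docs:
    print(doc.page_content[:80])
```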
diff --git a/packages/jupyter-ai/jupyter_ai/actors/base.py b/packages/jupyter-ai/jupyter_ai/actors/base.py index 587f84560..84a62bf8b 100644 --- a/packages/jupyter-ai/jupyter_ai/actors/base.py +++ b/packages/jupyter-ai/jupyter_ai/actors/base.py @@ -3,22 +3,33 @@ from uuid import uuid4 import time import logging -from typing import Union +from typing import Dict, Optional, Type, Union import traceback +from jupyter_ai_magics.providers import BaseProvider +import ray + from ray.util.queue import Queue from jupyter_ai.models import HumanChatMessage, AgentChatMessage - Logger = Union[logging.Logger, logging.LoggerAdapter] class ACTOR_TYPE(str, Enum): + # the top level actor that routes incoming messages to the appropriate actor + ROUTER = "router" + + # the default actor that responds to messages using a language model DEFAULT = "default" + ASK = "ask" LEARN = 'learn' MEMORY = 'memory' GENERATE = 'generate' + PROVIDERS = 'providers' + CONFIG = 'config' + CHAT_PROVIDER = 'chat_provider' + EMBEDDINGS_PROVIDER = 'embeddings_provider' COMMANDS = { '/ask': ACTOR_TYPE.ASK, @@ -37,6 +48,12 @@ def __init__( self.log = log self.reply_queue = reply_queue self.parser = argparse.ArgumentParser() + self.llm = None + self.llm_params = None + self.llm_chain = None + self.embeddings = None + self.embeddings_params = None + self.embedding_model_id = None def process_message(self, message: HumanChatMessage): """Processes the message passed by the `Router`""" @@ -51,14 +68,47 @@ def _process_message(self, message: HumanChatMessage): """Processes the message passed by the `Router`""" raise NotImplementedError("Should be implemented by subclasses.") - def reply(self, response, message: HumanChatMessage): + def reply(self, response, message: Optional[HumanChatMessage] = None): m = AgentChatMessage( id=uuid4().hex, time=time.time(), body=response, - reply_to=message.id + reply_to=message.id if message else "" ) self.reply_queue.put(m) + + def get_llm_chain(self): + actor = ray.get_actor(ACTOR_TYPE.CHAT_PROVIDER) + lm_provider = ray.get(actor.get_provider.remote()) + lm_provider_params = ray.get(actor.get_provider_params.remote()) + + curr_lm_id = f'{self.llm.id}:{self.llm.model_id}' if self.llm else None + next_lm_id = f'{lm_provider.id}:{lm_provider_params["model_id"]}' if lm_provider else None + + if not lm_provider: + return None + + if curr_lm_id != next_lm_id: + self.log.info(f"Switching chat language model from {curr_lm_id} to {next_lm_id}.") + self.create_llm_chain(lm_provider, lm_provider_params) + return self.llm_chain + + def get_embeddings(self): + actor = ray.get_actor(ACTOR_TYPE.EMBEDDINGS_PROVIDER) + provider = ray.get(actor.get_provider.remote()) + embedding_params = ray.get(actor.get_provider_params.remote()) + embedding_model_id = ray.get(actor.get_model_id.remote()) + + if not provider: + return None + + if embedding_model_id != self.embedding_model_id: + self.embeddings = provider(**embedding_params) + + return self.embeddings + + def create_llm_chain(self, provider: Type[BaseProvider], provider_params: Dict[str, str]): + raise NotImplementedError("Should be implemented by subclasses") def parse_args(self, message): args = message.body.split(' ')
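`BaseActor` now defines a small template-method contract: subclasses implement `create_llm_chain()`, and call `get_llm_chain()` at the top of `_process_message()` so the chain is lazily rebuilt only when the configured model changes. A minimal sketch of a conforming subclass (`EchoActor` is hypothetical, and the real actors are additionally decorated with `@ray.remote`):

```python
from typing import Dict, Type

from langchain import LLMChain
from langchain.prompts import PromptTemplate

from jupyter_ai.actors.base import BaseActor
from jupyter_ai.models import HumanChatMessage
from jupyter_ai_magics.providers import BaseProvider


class EchoActor(BaseActor):
    def create_llm_chain(self, provider: Type[BaseProvider], provider_params: Dict[str, str]):
        # instantiate the configured provider and wrap it in a trivial chain
        self.llm = provider(**provider_params)
        prompt = PromptTemplate(input_variables=["input"], template="Reply briefly: {input}")
        self.llm_chain = LLMChain(llm=self.llm, prompt=prompt)

    def _process_message(self, message: HumanChatMessage):
        self.get_llm_chain()  # rebuilds the chain only if the model selection changed
        self.reply(self.llm_chain.predict(input=message.body), message)
```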
diff --git a/packages/jupyter-ai/jupyter_ai/actors/chat_provider.py b/packages/jupyter-ai/jupyter_ai/actors/chat_provider.py new file mode 100644 index 000000000..5885e8e9d --- /dev/null +++ b/packages/jupyter-ai/jupyter_ai/actors/chat_provider.py @@ -0,0 +1,40 @@ +from jupyter_ai.actors.base import Logger, ACTOR_TYPE +from jupyter_ai.models import GlobalConfig +import ray + +@ray.remote +class ChatProviderActor(): + + def __init__(self, log: Logger): + self.log = log + self.provider = None + self.provider_params = None + + def update(self, config: GlobalConfig): + model_id = config.model_provider_id + actor = ray.get_actor(ACTOR_TYPE.PROVIDERS.value) + local_model_id, provider = ray.get( + actor.get_model_provider_data.remote(model_id) + ) + + if not provider: + raise ValueError(f"No provider found for model ID '{model_id}'.") + + provider_params = { "model_id": local_model_id} + + auth_strategy = provider.auth_strategy + if auth_strategy and auth_strategy.type == "env": + api_keys = config.api_keys + name = auth_strategy.name + if name not in api_keys: + raise ValueError(f"Missing value for '{auth_strategy.name}' in the config.") + provider_params[name.lower()] = api_keys[name] + + self.provider = provider + self.provider_params = provider_params + + def get_provider(self): + return self.provider + + def get_provider_params(self): + return self.provider_params \ No newline at end of file
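For concreteness, here is roughly what `ChatProviderActor.update()` derives from a `GlobalConfig` (the values are illustrative): the model ID is decomposed by the providers actor, and a key declared by an `EnvAuthStrategy` is passed to the provider class under the lower-cased environment variable name.

```python
from jupyter_ai.models import GlobalConfig

config = GlobalConfig(
    model_provider_id="openai-chat:gpt-3.5-turbo",
    api_keys={"OPENAI_API_KEY": "sk-..."},
)

# After update(config), the actor would hold approximately:
#   self.provider        -> ChatOpenAIProvider (the class registered under "openai-chat")
#   self.provider_params -> {"model_id": "gpt-3.5-turbo",
#                            "openai_api_key": "sk-..."}
```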
+ """ + + def __init__(self, log: Logger): + self.log = log + self.save_dir = os.path.join(jupyter_data_dir(), 'jupyter_ai') + self.save_path = os.path.join(self.save_dir, 'config.json') + self.config = None + self._load() + + def update(self, config: GlobalConfig, save_to_disk: bool = True): + self._update_chat_provider(config) + self._update_embeddings_provider(config) + if save_to_disk: + self._save(config) + self.config = config + + def _update_chat_provider(self, config: GlobalConfig): + if not config.model_provider_id: + return + + actor = ray.get_actor(ACTOR_TYPE.CHAT_PROVIDER) + ray.get(actor.update.remote(config)) + + def _update_embeddings_provider(self, config: GlobalConfig): + if not config.embeddings_provider_id: + return + + actor = ray.get_actor(ACTOR_TYPE.EMBEDDINGS_PROVIDER) + ray.get(actor.update.remote(config)) + + def _save(self, config: GlobalConfig): + if not os.path.exists: + os.makedirs(self.save_dir) + + with open(self.save_path, 'w') as f: + f.write(config.json()) + + def _load(self): + if os.path.exists(self.save_path): + with open(self.save_path, 'r', encoding='utf-8') as f: + config = GlobalConfig(**json.loads(f.read())) + self.update(config, False) + return + + # otherwise, create a new empty config file + self.update(GlobalConfig(), True) + + def get_config(self): + return self.config \ No newline at end of file diff --git a/packages/jupyter-ai/jupyter_ai/actors/default.py b/packages/jupyter-ai/jupyter_ai/actors/default.py index 6ab33ae08..7914daad3 100644 --- a/packages/jupyter-ai/jupyter_ai/actors/default.py +++ b/packages/jupyter-ai/jupyter_ai/actors/default.py @@ -1,42 +1,75 @@ +from typing import Dict, Type, List import ray -from ray.util.queue import Queue from langchain import ConversationChain from langchain.prompts import ( ChatPromptTemplate, MessagesPlaceholder, - SystemMessagePromptTemplate, - HumanMessagePromptTemplate + HumanMessagePromptTemplate, + SystemMessagePromptTemplate +) +from langchain.schema import ( + AIMessage, ) -from jupyter_ai.actors.base import BaseActor, Logger, ACTOR_TYPE +from jupyter_ai.actors.base import BaseActor, ACTOR_TYPE from jupyter_ai.actors.memory import RemoteMemory -from jupyter_ai.models import HumanChatMessage -from jupyter_ai_magics.providers import ChatOpenAINewProvider +from jupyter_ai.models import HumanChatMessage, ClearMessage, ChatMessage +from jupyter_ai_magics.providers import BaseProvider -SYSTEM_PROMPT = "The following is a friendly conversation between a human and an AI, whose name is Jupyter AI. The AI is talkative and provides lots of specific details from its context. If the AI does not know the answer to a question, it truthfully says it does not know." +SYSTEM_PROMPT = """ +You are Jupyter AI, a conversational assistant living in JupyterLab to help users. +You are not a language model, but rather an application built on a foundation model from {provider_name} called {local_model_id}. +You are talkative and provides lots of specific details from its context. +You may use Markdown to format your response. +Code blocks must be formatted in Markdown. +Math should be rendered with inline TeX markup, surrounded by $. +If you do not know the answer to a question, answer truthfully by responding that you do not know. +The following is a friendly conversation between you and a human. 
+""".strip() @ray.remote class DefaultActor(BaseActor): - def __init__(self, reply_queue: Queue, log: Logger): - super().__init__(reply_queue=reply_queue, log=log) - provider = ChatOpenAINewProvider(model_id="gpt-3.5-turbo") - - # Create a conversation memory - memory = RemoteMemory(actor_name=ACTOR_TYPE.MEMORY) + def __init__(self, chat_history: List[ChatMessage], *args, **kwargs): + super().__init__(*args, **kwargs) + self.memory = None + self.chat_history = chat_history + + def create_llm_chain(self, provider: Type[BaseProvider], provider_params: Dict[str, str]): + llm = provider(**provider_params) + self.memory = RemoteMemory(actor_name=ACTOR_TYPE.MEMORY) prompt_template = ChatPromptTemplate.from_messages([ - SystemMessagePromptTemplate.from_template(SYSTEM_PROMPT), + SystemMessagePromptTemplate.from_template(SYSTEM_PROMPT).format(provider_name=llm.name, local_model_id=llm.model_id), MessagesPlaceholder(variable_name="history"), - HumanMessagePromptTemplate.from_template("{input}") + HumanMessagePromptTemplate.from_template("{input}"), + AIMessage(content="") ]) - chain = ConversationChain( - llm=provider, + self.llm = llm + self.llm_chain = ConversationChain( + llm=llm, prompt=prompt_template, verbose=True, - memory=memory + memory=self.memory ) - self.chat_provider = chain + + def clear_memory(self): + if not self.memory: + return + + # clear chain memory + self.memory.clear() + + # clear transcript for existing chat clients + reply_message = ClearMessage() + self.reply_queue.put(reply_message) + + # clear transcript for new chat clients + self.chat_history.clear() def _process_message(self, message: HumanChatMessage): - response = self.chat_provider.predict(input=message.body) + self.get_llm_chain() + response = self.llm_chain.predict( + input=message.body, + stop=["\nHuman:"] + ) self.reply(response, message) diff --git a/packages/jupyter-ai/jupyter_ai/actors/embeddings_provider.py b/packages/jupyter-ai/jupyter_ai/actors/embeddings_provider.py new file mode 100644 index 000000000..068ce0388 --- /dev/null +++ b/packages/jupyter-ai/jupyter_ai/actors/embeddings_provider.py @@ -0,0 +1,52 @@ +from jupyter_ai.actors.base import Logger, ACTOR_TYPE +from jupyter_ai.models import GlobalConfig +import ray + +@ray.remote +class EmbeddingsProviderActor(): + + def __init__(self, log: Logger): + self.log = log + self.provider = None + self.provider_params = None + self.model_id = None + + def update(self, config: GlobalConfig): + model_id = config.embeddings_provider_id + actor = ray.get_actor(ACTOR_TYPE.PROVIDERS.value) + local_model_id, provider = ray.get( + actor.get_embeddings_provider_data.remote(model_id) + ) + + if not provider: + raise ValueError(f"No provider and model found with '{model_id}'") + + provider_params = {} + provider_params[provider.model_id_key] = local_model_id + + auth_strategy = provider.auth_strategy + if auth_strategy and auth_strategy.type == "env": + api_keys = config.api_keys + name = auth_strategy.name + if name not in api_keys: + raise ValueError(f"Missing value for '{auth_strategy.name}' in the config.") + provider_params[name.lower()] = api_keys[name] + + self.provider = provider.provider_klass + self.provider_params = provider_params + previous_model_id = self.model_id + self.model_id = model_id + + if previous_model_id and previous_model_id != model_id: + # delete the index + actor = ray.get_actor(ACTOR_TYPE.LEARN) + actor.delete_and_relearn.remote() + + def get_provider(self): + return self.provider + + def get_provider_params(self): + return 
self.provider_params + + def get_model_id(self): + return self.model_id \ No newline at end of file diff --git a/packages/jupyter-ai/jupyter_ai/actors/generate.py b/packages/jupyter-ai/jupyter_ai/actors/generate.py index bbcdef88b..a240078b5 100644 --- a/packages/jupyter-ai/jupyter_ai/actors/generate.py +++ b/packages/jupyter-ai/jupyter_ai/actors/generate.py @@ -1,22 +1,20 @@ import json import os -import time -from uuid import uuid4 +from typing import Dict, Type import ray from ray.util.queue import Queue from langchain.llms import BaseLLM -from langchain.chat_models import ChatOpenAI from langchain.prompts import PromptTemplate from langchain.llms import BaseLLM from langchain.chains import LLMChain import nbformat -from jupyter_ai.models import AgentChatMessage, HumanChatMessage +from jupyter_ai.models import HumanChatMessage from jupyter_ai.actors.base import BaseActor, Logger -from jupyter_ai_magics.providers import ChatOpenAINewProvider +from jupyter_ai_magics.providers import BaseProvider, ChatOpenAINewProvider schema = """{ "$schema": "http://json-schema.org/draft-07/schema#", @@ -67,8 +65,6 @@ def from_llm(cls, llm: BaseLLM, verbose: bool=False) -> LLMChain: def generate_outline(description, llm=None, verbose=False): """Generate an outline of sections given a description of a notebook.""" - if llm is None: - llm = ChatOpenAINewProvider(model_id='gpt-3.5-turbo') chain = NotebookOutlineChain.from_llm(llm=llm, verbose=verbose) outline = chain.predict(description=description, schema=schema) return json.loads(outline) @@ -125,8 +121,6 @@ def from_llm(cls, llm: BaseLLM, verbose: bool=False) -> LLMChain: def generate_code(outline, llm=None, verbose=False): """Generate source code for a section given a description of the notebook and section.""" - if llm is None: - llm = ChatOpenAINewProvider(model_id='gpt-3.5-turbo') chain = NotebookSectionCodeChain.from_llm(llm=llm, verbose=verbose) code_so_far = [] for section in outline['sections']: @@ -177,8 +171,6 @@ def from_llm(cls, llm: BaseLLM, verbose: bool=False) -> LLMChain: def generate_title_and_summary(outline, llm=None, verbose=False): """Generate a title and summary of a notebook outline using an LLM.""" - if llm is None: - llm = ChatOpenAINewProvider(model_id='gpt-3.5-turbo') summary_chain = NotebookSummaryChain.from_llm(llm=llm, verbose=verbose) title_chain = NotebookTitleChain.from_llm(llm=llm, verbose=verbose) summary = summary_chain.predict(content=outline) @@ -210,9 +202,16 @@ class GenerateActor(BaseActor): def __init__(self, reply_queue: Queue, root_dir: str, log: Logger): super().__init__(log=log, reply_queue=reply_queue) self.root_dir = os.path.abspath(os.path.expanduser(root_dir)) - self.llm = ChatOpenAINewProvider(model_id='gpt-3.5-turbo') + self.llm = None + + def create_llm_chain(self, provider: Type[BaseProvider], provider_params: Dict[str, str]): + llm = provider(**provider_params) + self.llm = llm + return llm def _process_message(self, message: HumanChatMessage): + self.get_llm_chain() + response = "👍 Great, I will get started on your notebook. It may take a few minutes, but I will reply here when the notebook is ready. In the meantime, you can continue to ask me other questions." 
self.reply(response, message) diff --git a/packages/jupyter-ai/jupyter_ai/actors/learn.py b/packages/jupyter-ai/jupyter_ai/actors/learn.py index c32c1b054..d65002238 100644 --- a/packages/jupyter-ai/jupyter_ai/actors/learn.py +++ b/packages/jupyter-ai/jupyter_ai/actors/learn.py @@ -1,7 +1,8 @@ +import json import os -import traceback -from collections import Counter import argparse +import time +from typing import List import ray from ray.util.queue import Queue @@ -9,44 +10,51 @@ from jupyter_core.paths import jupyter_data_dir from langchain import FAISS -from langchain.embeddings.openai import OpenAIEmbeddings from langchain.text_splitter import ( RecursiveCharacterTextSplitter, PythonCodeTextSplitter, MarkdownTextSplitter, LatexTextSplitter ) +from langchain.schema import Document -from jupyter_ai.models import HumanChatMessage +from jupyter_ai.models import HumanChatMessage, IndexedDir, IndexMetadata from jupyter_ai.actors.base import BaseActor, Logger -from jupyter_ai_magics.providers import ChatOpenAINewProvider from jupyter_ai.document_loaders.directory import RayRecursiveDirectoryLoader from jupyter_ai.document_loaders.splitter import ExtensionSplitter, NotebookSplitter +INDEX_SAVE_DIR = os.path.join(jupyter_data_dir(), 'jupyter_ai', 'indices') +METADATA_SAVE_PATH = os.path.join(INDEX_SAVE_DIR, 'metadata.json') @ray.remote class LearnActor(BaseActor): def __init__(self, reply_queue: Queue, log: Logger, root_dir: str): super().__init__(reply_queue=reply_queue, log=log) self.root_dir = root_dir - self.index_save_dir = os.path.join(jupyter_data_dir(), 'jupyter_ai', 'indices') self.chunk_size = 2000 self.chunk_overlap = 100 self.parser.prog = '/learn' self.parser.add_argument('-v', '--verbose', action='store_true') self.parser.add_argument('-d', '--delete', action='store_true') + self.parser.add_argument('-l', '--list', action='store_true') self.parser.add_argument('path', nargs=argparse.REMAINDER) self.index_name = 'default' self.index = None - - if ChatOpenAINewProvider.auth_strategy.name not in os.environ: - return + self.metadata = IndexMetadata(dirs=[]) - if not os.path.exists(self.index_save_dir): - os.makedirs(self.index_save_dir) - - self.load_or_create() + if not os.path.exists(INDEX_SAVE_DIR): + os.makedirs(INDEX_SAVE_DIR) + + self.load_or_create() def _process_message(self, message: HumanChatMessage): + if not self.index: + self.load_or_create() + + # If the index is still not there, embeddings are not present + if not self.index: + self.reply("Sorry, please select an embedding provider before using the `/learn` command.", message) + return + args = self.parse_args(message) if args is None: return @@ -55,6 +63,10 @@ self.delete() self.reply(f"👍 I have deleted everything I previously learned.", message) return + + if args.list: + self.reply(self._build_list_response()) + return # Make sure the path exists. if not len(args.path) == 1: @@ -70,6 +82,24 @@ if args.verbose: self.reply(f"Loading and splitting files for {load_path}", message) + self.learn_dir(load_path) + self.save() + + response = f"""🎉 I have learned documents at **{load_path}** and I am ready to answer questions about them. + You can ask questions about these docs by prefixing your message with **/ask**.""" + self.reply(response, message) + + def _build_list_response(self): + if not self.metadata.dirs: + return "There are no docs that have been learned yet." + + dirs = [dir.path for dir in self.metadata.dirs] + dir_list = "\n- " + "\n- ".join(dirs) + "\n\n" + message = f"""I can answer questions from docs in these directories: + {dir_list}""" + return message
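The `/learn -l` listing is rendered from the persisted `IndexMetadata`. A small sketch of that rendering, with two hypothetical learned directories:

```python
from jupyter_ai.models import IndexedDir, IndexMetadata

metadata = IndexMetadata(dirs=[IndexedDir(path="docs/"), IndexedDir(path="notebooks/")])

dirs = [dir.path for dir in metadata.dirs]
dir_list = "\n- " + "\n- ".join(dirs) + "\n\n"
print(f"I can answer questions from docs in these directories:{dir_list}")
```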
+ + def learn_dir(self, path: str): splitters={ '.py': PythonCodeTextSplitter(chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap), '.md': MarkdownTextSplitter(chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap), @@ -81,39 +111,94 @@ def _process_message(self, message: HumanChatMessage): default_splitter=RecursiveCharacterTextSplitter(chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap) ) - loader = RayRecursiveDirectoryLoader(load_path) - texts = loader.load_and_split(text_splitter=splitter) + loader = RayRecursiveDirectoryLoader(path) + texts = loader.load_and_split(text_splitter=splitter) self.index.add_documents(texts) - self.save() - - response = f"""🎉 I have indexed documents at **{load_path}** and I am ready to answer questions about them. - You can ask questions about these docs by prefixing your message with **/ask**.""" - self.reply(response, message) - - def get_index(self): - return self.index + self._add_dir_to_metadata(path) + + def _add_dir_to_metadata(self, path: str): + dirs = self.metadata.dirs + index = next((i for i, dir in enumerate(dirs) if dir.path == path), None) + if index is None: + dirs.append(IndexedDir(path=path)) + self.metadata.dirs = dirs + + def delete_and_relearn(self): + if not self.metadata.dirs: + self.delete() + return + message = """🔔 Hi there! It seems like you have updated the embeddings model. For the **/ask** + command to work with the new model, I have to re-learn the documents you had previously + submitted for learning. Please wait to use the **/ask** command until I am done with this task.""" + self.reply(message) + + metadata = self.metadata + self.delete() + self.relearn(metadata) def delete(self): self.index = None - paths = [os.path.join(self.index_save_dir, self.index_name+ext) for ext in ['.pkl', '.faiss']] + self.metadata = IndexMetadata(dirs=[]) + paths = [os.path.join(INDEX_SAVE_DIR, self.index_name+ext) for ext in ['.pkl', '.faiss']] for path in paths: if os.path.isfile(path): os.remove(path) self.create()
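Deleting an index amounts to removing the two files FAISS writes per index name (`<name>.faiss` and `<name>.pkl`). A minimal persistence round trip using the same LangChain calls the actor uses, assuming OpenAI embeddings are configured and `OPENAI_API_KEY` is set:

```python
from langchain import FAISS
from langchain.embeddings import OpenAIEmbeddings

embeddings = OpenAIEmbeddings()
index = FAISS.from_texts(["Jupyter AI knows about your filesystem."], embeddings)

# Writes default.faiss and default.pkl under the given directory.
index.save_local("/tmp/jupyter_ai_indices", index_name="default")
restored = FAISS.load_local("/tmp/jupyter_ai_indices", embeddings, index_name="default")
```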
+ + def relearn(self, metadata: IndexMetadata): + # Index all dirs in the metadata + if not metadata.dirs: + return + + for dir in metadata.dirs: + self.learn_dir(dir.path) + + self.save() + + dir_list = "\n- " + "\n- ".join([dir.path for dir in self.metadata.dirs]) + "\n\n" + message = f"""🎉 I am done learning docs in these directories: + {dir_list} I am ready to answer questions about them. + You can ask questions about these docs by prefixing your message with **/ask**.""" + self.reply(message) + def create(self): - embeddings = OpenAIEmbeddings() + embeddings = self.get_embeddings() + if not embeddings: + return self.index = FAISS.from_texts(["Jupyter AI knows about your filesystem, to ask questions first use the /learn command."], embeddings) self.save() def save(self): if self.index is not None: - self.index.save_local(self.index_save_dir, index_name=self.index_name) + self.index.save_local(INDEX_SAVE_DIR, index_name=self.index_name) + + self.save_metadata() + + def save_metadata(self): + with open(METADATA_SAVE_PATH, 'w') as f: + f.write(self.metadata.json()) def load_or_create(self): - embeddings = OpenAIEmbeddings() + embeddings = self.get_embeddings() + if not embeddings: + return if self.index is None: try: - self.index = FAISS.load_local(self.index_save_dir, embeddings, index_name=self.index_name) + self.index = FAISS.load_local(INDEX_SAVE_DIR, embeddings, index_name=self.index_name) + self.load_metadata() except Exception as e: self.create() + + def load_metadata(self): + if not os.path.exists(METADATA_SAVE_PATH): + return + + with open(METADATA_SAVE_PATH, 'r', encoding='utf-8') as f: + j = json.loads(f.read()) + self.metadata = IndexMetadata(**j) + + def get_relevant_documents(self, question: str) -> List[Document]: + if self.index: + docs = self.index.similarity_search(question) + return docs + return [] diff --git a/packages/jupyter-ai/jupyter_ai/actors/providers.py b/packages/jupyter-ai/jupyter_ai/actors/providers.py new file mode 100644 index 000000000..fd249ede9 --- /dev/null +++ b/packages/jupyter-ai/jupyter_ai/actors/providers.py @@ -0,0 +1,42 @@ +from typing import Optional, Tuple, Type +from jupyter_ai_magics.embedding_providers import BaseEmbeddingsProvider +from jupyter_ai_magics.providers import BaseProvider +from jupyter_ai_magics.utils import decompose_model_id, load_embedding_providers, load_providers +import ray +from jupyter_ai.actors.base import BaseActor, Logger +from ray.util.queue import Queue + +@ray.remote +class ProvidersActor(): + """Actor that loads model and embedding providers from + entry points. Also provides utility functions to get the + providers and provider class matching a provider id. + """ + + def __init__(self, log: Logger): + self.log = log + self.model_providers = load_providers(log=log) + self.embeddings_providers = load_embedding_providers(log=log) + + def get_model_providers(self): + """Returns dictionary of registered LLM providers""" + return self.model_providers + + def get_model_provider_data(self, model_id: str) -> Tuple[str, Type[BaseProvider]]: + """Returns a tuple of the local model ID and the provider class matching the given model ID""" + provider_id, local_model_id = decompose_model_id(model_id, self.model_providers) + provider = self.model_providers.get(provider_id, None) + return local_model_id, provider + + def get_embeddings_providers(self): + """Returns dictionary of registered embedding providers""" + return self.embeddings_providers + + def get_embeddings_provider_data(self, model_id: str) -> Tuple[str, Type[BaseEmbeddingsProvider]]: + """Returns a tuple of the local model ID and the embeddings provider class matching the given model ID""" + provider_id, local_model_id = decompose_model_id(model_id, self.embeddings_providers) + provider = self.embeddings_providers.get(provider_id, None) + return local_model_id, provider + + + \ No newline at end of file
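Other actors and the REST handlers resolve providers through this actor by name. A rough usage sketch, assuming the actor was registered under the name `providers` at extension startup, as done in `extension.py` below:

```python
import ray

providers_actor = ray.get_actor("providers")
local_model_id, provider = ray.get(
    providers_actor.get_model_provider_data.remote("openai-chat:gpt-3.5-turbo")
)
# local_model_id -> "gpt-3.5-turbo"
# provider       -> the provider class registered under "openai-chat",
#                   or None if no such provider is registered.
```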
+ """ + + def __init__(self, log: Logger): + self.log = log + self.model_providers = load_providers(log=log) + self.embeddings_providers = load_embedding_providers(log=log) + + def get_model_providers(self): + """Returns dictionary of registered LLM providers""" + return self.model_providers + + def get_model_provider_data(self, model_id: str) -> Tuple[str, Type[BaseProvider]]: + """Returns the model provider class that matches the provider id""" + provider_id, local_model_id = decompose_model_id(model_id, self.model_providers) + provider = self.model_providers.get(provider_id, None) + return local_model_id, provider + + def get_embeddings_providers(self): + """Returns dictionary of registered embedding providers""" + return self.embeddings_providers + + def get_embeddings_provider_data(self, model_id: str) -> Tuple[str, Type[BaseEmbeddingsProvider]]: + """Returns the embedding provider class that matches the provider id""" + provider_id, local_model_id = decompose_model_id(model_id, self.embeddings_providers) + provider = self.embeddings_providers.get(provider_id, None) + return local_model_id, provider + + + \ No newline at end of file diff --git a/packages/jupyter-ai/jupyter_ai/actors/router.py b/packages/jupyter-ai/jupyter_ai/actors/router.py index 7b417a0cd..fbc3234da 100644 --- a/packages/jupyter-ai/jupyter_ai/actors/router.py +++ b/packages/jupyter-ai/jupyter_ai/actors/router.py @@ -2,7 +2,6 @@ from ray.util.queue import Queue from jupyter_ai.actors.base import ACTOR_TYPE, COMMANDS, Logger, BaseActor -from jupyter_ai.models import ClearMessage @ray.remote class Router(BaseActor): @@ -25,7 +24,7 @@ def _process_message(self, message): actor = ray.get_actor(COMMANDS[command].value) actor.process_message.remote(message) if command == '/clear': - reply_message = ClearMessage() - self.reply_queue.put(reply_message) + actor = ray.get_actor(ACTOR_TYPE.DEFAULT) + actor.clear_memory.remote() else: default.process_message.remote(message) diff --git a/packages/jupyter-ai/jupyter_ai/extension.py b/packages/jupyter-ai/jupyter_ai/extension.py index e4b22a1f5..837bb68f8 100644 --- a/packages/jupyter-ai/jupyter_ai/extension.py +++ b/packages/jupyter-ai/jupyter_ai/extension.py @@ -1,6 +1,12 @@ import asyncio -import os -import queue + +from jupyter_ai.actors.chat_provider import ChatProviderActor +from jupyter_ai.actors.config import ConfigActor +from jupyter_ai.actors.embeddings_provider import EmbeddingsProviderActor +from jupyter_ai.actors.providers import ProvidersActor + +from jupyter_ai_magics.utils import load_providers + from langchain.memory import ConversationBufferWindowMemory from jupyter_ai.actors.default import DefaultActor from jupyter_ai.actors.ask import AskActor @@ -11,24 +17,37 @@ from jupyter_ai.actors.base import ACTOR_TYPE from jupyter_ai.reply_processor import ReplyProcessor from jupyter_server.extension.application import ExtensionApp -from .handlers import ChatHandler, ChatHistoryHandler, PromptAPIHandler, TaskAPIHandler + +from .handlers import ( + ChatHandler, + ChatHistoryHandler, + EmbeddingsModelProviderHandler, + ModelProviderHandler, + PromptAPIHandler, + TaskAPIHandler, + GlobalConfigHandler +) + from importlib_metadata import entry_points import inspect from .engine import BaseModelEngine -from jupyter_ai_magics.providers import ChatOpenAINewProvider, ChatOpenAIProvider import ray from ray.util.queue import Queue +from jupyter_ai_magics.utils import load_providers class AiExtension(ExtensionApp): name = "jupyter_ai" handlers = [ + ("api/ai/config", 
GlobalConfigHandler), ("api/ai/prompt", PromptAPIHandler), (r"api/ai/tasks/?", TaskAPIHandler), (r"api/ai/tasks/([\w\-:]*)", TaskAPIHandler), (r"api/ai/chats/?", ChatHandler), (r"api/ai/chats/history?", ChatHistoryHandler), + (r"api/ai/providers?", ModelProviderHandler), + (r"api/ai/providers/embeddings?", EmbeddingsModelProviderHandler), ] @property @@ -91,35 +110,46 @@ def initialize_settings(self): self.settings["ai_default_tasks"] = default_tasks self.log.info("Registered all default tasks.") - if ChatOpenAINewProvider.auth_strategy.name not in os.environ: - raise EnvironmentError(f"`{ChatOpenAINewProvider.auth_strategy.name}` value not set in environment. For chat to work, this value should be provided.") - - ## load OpenAI provider - self.settings["openai_chat"] = ChatOpenAIProvider(model_id="gpt-3.5-turbo") + providers = load_providers(log=self.log) + self.settings["chat_providers"] = providers + self.log.info("Registered providers.") self.log.info(f"Registered {self.name} server extension") - # Add a message queue to the settings to be used by the chat handler - self.settings["chat_message_queue"] = queue.Queue() - # Store chat clients in a dictionary self.settings["chat_clients"] = {} self.settings["chat_handlers"] = {} # store chat messages in memory for now + # this is only used to render the UI, and is not the conversational + # memory object used by the LM chain. self.settings["chat_history"] = [] reply_queue = Queue() self.settings["reply_queue"] = reply_queue - router = Router.options(name="router").remote( + router = Router.options(name=ACTOR_TYPE.ROUTER).remote( reply_queue=reply_queue, - log=self.log + log=self.log, ) default_actor = DefaultActor.options(name=ACTOR_TYPE.DEFAULT.value).remote( reply_queue=reply_queue, - log=self.log + log=self.log, + chat_history=self.settings["chat_history"] + ) + + providers_actor = ProvidersActor.options(name=ACTOR_TYPE.PROVIDERS.value).remote( + log=self.log, + ) + config_actor = ConfigActor.options(name=ACTOR_TYPE.CONFIG.value).remote( + log=self.log, + ) + chat_provider_actor = ChatProviderActor.options(name=ACTOR_TYPE.CHAT_PROVIDER.value).remote( + log=self.log, + ) + embeddings_provider_actor = EmbeddingsProviderActor.options(name=ACTOR_TYPE.EMBEDDINGS_PROVIDER.value).remote( + log=self.log, ) learn_actor = LearnActor.options(name=ACTOR_TYPE.LEARN.value).remote( reply_queue=reply_queue, @@ -128,19 +158,23 @@ def initialize_settings(self): ) ask_actor = AskActor.options(name=ACTOR_TYPE.ASK.value).remote( reply_queue=reply_queue, - log=self.log + log=self.log, ) memory_actor = MemoryActor.options(name=ACTOR_TYPE.MEMORY.value).remote( log=self.log, - memory=ConversationBufferWindowMemory(return_messages=True, k=2) + memory=ConversationBufferWindowMemory(return_messages=True, k=2), ) generate_actor = GenerateActor.options(name=ACTOR_TYPE.GENERATE.value).remote( reply_queue=reply_queue, log=self.log, - root_dir=self.settings['server_root_dir'] + root_dir=self.settings['server_root_dir'], ) self.settings['router'] = router + self.settings['providers_actor'] = providers_actor + self.settings['config_actor'] = config_actor + self.settings['chat_provider_actor'] = chat_provider_actor + self.settings['embeddings_provider_actor'] = embeddings_provider_actor self.settings["default_actor"] = default_actor self.settings["learn_actor"] = learn_actor self.settings["ask_actor"] = ask_actor diff --git a/packages/jupyter-ai/jupyter_ai/handlers.py b/packages/jupyter-ai/jupyter_ai/handlers.py index 0bcfd8a62..24db73b2c 100644 --- 
a/packages/jupyter-ai/jupyter_ai/handlers.py +++ b/packages/jupyter-ai/jupyter_ai/handlers.py @@ -1,6 +1,7 @@ from dataclasses import asdict import json from typing import Dict, List +from jupyter_ai.actors.base import ACTOR_TYPE import ray import tornado import uuid @@ -16,7 +17,22 @@ from jupyter_server.utils import ensure_async from .task_manager import TaskManager -from .models import ChatHistory, PromptRequest, ChatRequest, ChatMessage, Message, AgentChatMessage, HumanChatMessage, ConnectionMessage, ChatClient, ChatUser + +from .models import ( + ChatHistory, + ChatUser, + ListProvidersEntry, + ListProvidersResponse, + PromptRequest, + ChatRequest, + ChatMessage, + Message, + AgentChatMessage, + HumanChatMessage, + ConnectionMessage, + ChatClient, + GlobalConfig +) class APIHandler(BaseAPIHandler): @@ -36,10 +52,6 @@ def task_manager(self): self.settings["task_manager"] = TaskManager(engines=self.engines, default_tasks=self.default_tasks) return self.settings["task_manager"] - @property - def openai_chat(self): - return self.settings["openai_chat"] - class PromptAPIHandler(APIHandler): @tornado.web.authenticated async def post(self): @@ -105,15 +117,6 @@ class ChatHandler( """ A websocket handler for chat. """ - - _chat_provider = None - _chat_message_queue = None - - @property - def chat_message_queue(self): - if self._chat_message_queue is None: - self._chat_message_queue = self.settings["chat_message_queue"] - return self._chat_message_queue @property def chat_handlers(self) -> Dict[str, 'ChatHandler']: @@ -254,3 +257,94 @@ def on_close(self): self.log.info(f"Client disconnected. ID: {self.client_id}") self.log.debug("Chat clients: %s", self.chat_handlers.keys()) + + +class ModelProviderHandler(BaseAPIHandler): + @property + def chat_providers(self): + actor = ray.get_actor("providers") + o = actor.get_model_providers.remote() + return ray.get(o) + + @web.authenticated + def get(self): + providers = [] + for provider in self.chat_providers.values(): + # skip old legacy OpenAI chat provider used only in magics + if provider.id == "openai-chat": + continue + + providers.append( + ListProvidersEntry( + id=provider.id, + name=provider.name, + models=provider.models, + auth_strategy=provider.auth_strategy + ) + ) + + response = ListProvidersResponse(providers=sorted(providers, key=lambda p: p.name)) + self.finish(response.json()) + + +class EmbeddingsModelProviderHandler(BaseAPIHandler): + + @property + def embeddings_providers(self): + actor = ray.get_actor("providers") + o = actor.get_embeddings_providers.remote() + return ray.get(o) + + @web.authenticated + def get(self): + providers = [] + for provider in self.embeddings_providers.values(): + providers.append( + ListProvidersEntry( + id=provider.id, + name=provider.name, + models=provider.models, + auth_strategy=provider.auth_strategy + ) + ) + + response = ListProvidersResponse(providers=sorted(providers, key=lambda p: p.name)) + self.finish(response.json()) + + +class GlobalConfigHandler(BaseAPIHandler): + """API handler for fetching and setting the + model and embeddings config.
+ """ + + @web.authenticated + def get(self): + actor = ray.get_actor(ACTOR_TYPE.CONFIG) + config = ray.get(actor.get_config.remote()) + if not config: + raise HTTPError(500, "No config found.") + + self.finish(config.json()) + + @web.authenticated + def post(self): + try: + config = GlobalConfig(**self.get_json_body()) + actor = ray.get_actor(ACTOR_TYPE.CONFIG) + ray.get(actor.update.remote(config)) + + self.set_status(204) + self.finish() + + except ValidationError as e: + self.log.exception(e) + raise HTTPError(500, str(e)) from e + except ValueError as e: + self.log.exception(e) + raise HTTPError(500, str(e.cause) if hasattr(e, 'cause') else str(e)) + except Exception as e: + self.log.exception(e) + raise HTTPError( + 500, "Unexpected error occurred while updating the config." + ) from e + diff --git a/packages/jupyter-ai/jupyter_ai/models.py b/packages/jupyter-ai/jupyter_ai/models.py index 7da9aa47f..a9f28768a 100644 --- a/packages/jupyter-ai/jupyter_ai/models.py +++ b/packages/jupyter-ai/jupyter_ai/models.py @@ -1,3 +1,5 @@ +from jupyter_ai_magics.providers import AuthStrategy + from pydantic import BaseModel from typing import Dict, List, Union, Literal, Optional @@ -80,3 +82,28 @@ class DescribeTaskResponse(BaseModel): class ChatHistory(BaseModel): """History of chat messages""" messages: List[ChatMessage] + + +class ListProvidersEntry(BaseModel): + """Model provider with supported models + and provider's authentication strategy + """ + id: str + name: str + models: List[str] + auth_strategy: AuthStrategy + + +class ListProvidersResponse(BaseModel): + providers: List[ListProvidersEntry] + +class IndexedDir(BaseModel): + path: str + +class IndexMetadata(BaseModel): + dirs: List[IndexedDir] + +class GlobalConfig(BaseModel): + model_provider_id: Optional[str] = None + embeddings_provider_id: Optional[str] = None + api_keys: Dict[str, str] = {} diff --git a/packages/jupyter-ai/src/components/chat-settings.tsx b/packages/jupyter-ai/src/components/chat-settings.tsx new file mode 100644 index 000000000..d36a3bf63 --- /dev/null +++ b/packages/jupyter-ai/src/components/chat-settings.tsx @@ -0,0 +1,266 @@ +import React, { useEffect, useState } from 'react'; +import { Box } from '@mui/system'; +import { + Alert, + Button, + MenuItem, + TextField, + CircularProgress +} from '@mui/material'; + +import { Select } from './select'; +import { AiService } from '../handler'; + +enum ChatSettingsState { + // chat settings is making initial fetches + Loading, + // chat settings is ready (happy path) + Ready, + // chat settings failed to make initial fetches + FetchError, + // chat settings failed to submit the save request + SubmitError, + // chat settings successfully submitted the save request + Success +} + +export function ChatSettings() { + const [state, setState] = useState( + ChatSettingsState.Loading + ); + // error message from initial fetch + const [fetchEmsg, setFetchEmsg] = useState(); + + // state fetched on initial render + const [config, setConfig] = useState(); + const [lmProviders, setLmProviders] = + useState(); + const [emProviders, setEmProviders] = + useState(); + + // user inputs + const [inputConfig, setInputConfig] = useState({ + model_provider_id: null, + embeddings_provider_id: null, + api_keys: {} + }); + + // whether the form is currently saving + const [saving, setSaving] = useState(false); + // error message from submission + const [saveEmsg, setSaveEmsg] = useState(); + + /** + * Effect: call APIs on initial render + */ + useEffect(() => { + async function getConfig() { + 
try { + const [config, lmProviders, emProviders] = await Promise.all([ + AiService.getConfig(), + AiService.listLmProviders(), + AiService.listEmProviders() + ]); + setConfig(config); + setInputConfig(config); + setLmProviders(lmProviders); + setEmProviders(emProviders); + setState(ChatSettingsState.Ready); + } catch (e) { + console.error(e); + if (e instanceof Error) { + setFetchEmsg(e.message); + } + setState(ChatSettingsState.FetchError); + } + } + getConfig(); + }, []); + + /** + * Effect: re-initialize API keys object whenever the selected LM/EM changes. + */ + useEffect(() => { + const selectedLmpId = inputConfig.model_provider_id?.split(':')[0]; + const selectedEmpId = inputConfig.embeddings_provider_id?.split(':')[0]; + const lmp = lmProviders?.providers.find( + provider => provider.id === selectedLmpId + ); + const emp = emProviders?.providers.find( + provider => provider.id === selectedEmpId + ); + const newApiKeys: Record = {}; + + if (lmp?.auth_strategy && lmp.auth_strategy.type === 'env') { + newApiKeys[lmp.auth_strategy.name] = + config?.api_keys[lmp.auth_strategy.name] || ''; + } + if (emp?.auth_strategy && emp.auth_strategy.type === 'env') { + newApiKeys[emp.auth_strategy.name] = + config?.api_keys[emp.auth_strategy.name] || ''; + } + + setInputConfig(inputConfig => ({ + ...inputConfig, + api_keys: { ...config?.api_keys, ...newApiKeys } + })); + }, [inputConfig.model_provider_id, inputConfig.embeddings_provider_id]); + + const handleSave = async () => { + const inputConfigCopy: AiService.Config = { + ...inputConfig, + api_keys: { ...inputConfig.api_keys } + }; + + // delete any empty api keys + for (const apiKey in inputConfigCopy.api_keys) { + if (inputConfigCopy.api_keys[apiKey] === '') { + delete inputConfigCopy.api_keys[apiKey]; + } + } + + setSaving(true); + try { + await AiService.updateConfig(inputConfigCopy); + } catch (e) { + console.error(e); + if (e instanceof Error) { + setSaveEmsg(e.message); + } + setState(ChatSettingsState.SubmitError); + setSaving(false); + return; + } + setState(ChatSettingsState.Success); + setSaving(false); + }; + + if (state === ChatSettingsState.Loading) { + return ( + + + ); + } + + if ( + state === ChatSettingsState.FetchError || + !lmProviders || + !emProviders || + !config + ) { + return ( + + + {fetchEmsg + ? `An error occurred. Error details:\n\n${fetchEmsg}` + : 'An unknown error occurred. Check the console for more details.'} + + + ); + } + + return ( + .MuiAlert-root': { marginBottom: 2 } + }} + > + {state === ChatSettingsState.SubmitError && ( + + {saveEmsg + ? `An error occurred. Error details:\n\n${saveEmsg}` + : 'An unknown error occurred. Check the console for more details.'} + + )} + {state === ChatSettingsState.Success && ( + Settings saved successfully.
+      )}
+
+      <Select
+        value={inputConfig.model_provider_id}
+        label="Language model"
+        onChange={e =>
+          setInputConfig(inputConfig => ({
+            ...inputConfig,
+            model_provider_id: e.target.value
+          }))
+        }
+      >
+        <MenuItem value="null">None</MenuItem>
+        {lmProviders.providers.map(lmp =>
+          lmp.models.map(model => (
+            <MenuItem key={`${lmp.id}:${model}`} value={`${lmp.id}:${model}`}>
+              {lmp.name} :: {model}
+            </MenuItem>
+          ))
+        )}
+      </Select>
+      <Select
+        value={inputConfig.embeddings_provider_id}
+        label="Embedding model"
+        onChange={e =>
+          setInputConfig(inputConfig => ({
+            ...inputConfig,
+            embeddings_provider_id: e.target.value
+          }))
+        }
+      >
+        <MenuItem value="null">None</MenuItem>
+        {emProviders.providers.map(emp =>
+          emp.models.map(model => (
+            <MenuItem key={`${emp.id}:${model}`} value={`${emp.id}:${model}`}>
+              {emp.name} :: {model}
+            </MenuItem>
+          ))
+        )}
+      </Select>
+
+      {Object.entries(inputConfig.api_keys).map(
+        ([apiKey, apiKeyValue], idx) => (
+          <TextField
+            key={idx}
+            label={apiKey}
+            value={apiKeyValue}
+            fullWidth
+            type="password"
+            onChange={e =>
+              setInputConfig(inputConfig => ({
+                ...inputConfig,
+                api_keys: {
+                  ...inputConfig.api_keys,
+                  [apiKey]: e.target.value
+                }
+              }))
+            }
+          />
+        )
+      )}
+
+      <Box sx={{ display: 'flex', justifyContent: 'flex-end' }}>
+        <Button variant="contained" onClick={handleSave} disabled={saving}>
+          {saving ? 'Saving...' : 'Save changes'}
+        </Button>
+      </Box>
+    </Box>
+  );
+}
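The settings form above and the chat body below both lean on one convention: a global model ID has the form `<provider-id>:<local-model-id>`, and a provider's `auth_strategy` names the environment variable holding its API key. A minimal illustrative sketch of that convention, using the `AuthStrategy` types from `handler.ts`; the concrete IDs are examples only:

```typescript
// Illustrative only: the ID and auth conventions used by the settings UI.
type EnvAuthStrategy = { type: 'env'; name: string };
type AwsAuthStrategy = { type: 'aws' };
type AuthStrategy = EnvAuthStrategy | AwsAuthStrategy | null;

// global model IDs take the form "<provider-id>:<local-model-id>"
const globalModelId = 'openai-chat:gpt-3.5-turbo';
const providerId = globalModelId.split(':')[0];

// the form renders one password field per 'env'-type auth strategy
function requiredEnvVar(auth: AuthStrategy): string | null {
  return auth?.type === 'env' ? auth.name : null;
}

console.log(providerId); // "openai-chat"
console.log(requiredEnvVar({ type: 'env', name: 'OPENAI_API_KEY' })); // "OPENAI_API_KEY"
console.log(requiredEnvVar({ type: 'aws' })); // null
```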
diff --git a/packages/jupyter-ai/src/components/chat.tsx b/packages/jupyter-ai/src/components/chat.tsx
index 0e68dad13..eaae48ebc 100644
--- a/packages/jupyter-ai/src/components/chat.tsx
+++ b/packages/jupyter-ai/src/components/chat.tsx
@@ -1,10 +1,14 @@
 import React, { useState, useEffect } from 'react';
 import { Box } from '@mui/system';
+import { Button, IconButton, Stack } from '@mui/material';
+import SettingsIcon from '@mui/icons-material/Settings';
+import ArrowBackIcon from '@mui/icons-material/ArrowBack';
 import type { Awareness } from 'y-protocols/awareness';
 
 import { JlThemeProvider } from './jl-theme-provider';
 import { ChatMessages } from './chat-messages';
 import { ChatInput } from './chat-input';
+import { ChatSettings } from './chat-settings';
 import { AiService } from '../handler';
 import {
   SelectionContextProvider,
@@ -17,10 +21,12 @@ import { ScrollContainer } from './scroll-container';
 
 type ChatBodyProps = {
   chatHandler: ChatHandler;
+  setChatView: (view: ChatView) => void;
 };
 
-function ChatBody({ chatHandler }: ChatBodyProps): JSX.Element {
+function ChatBody({ chatHandler, setChatView: chatViewHandler }: ChatBodyProps): JSX.Element {
   const [messages, setMessages] = useState<AiService.ChatMessage[]>([]);
+  const [showWelcomeMessage, setShowWelcomeMessage] = useState(false);
   const [includeSelection, setIncludeSelection] = useState(true);
   const [replaceSelection, setReplaceSelection] = useState(false);
   const [input, setInput] = useState('');
@@ -32,12 +38,17 @@ function ChatBody({ chatHandler }: ChatBodyProps): JSX.Element {
   useEffect(() => {
     async function fetchHistory() {
       try {
-        const history = await chatHandler.getHistory();
+        const [history, config] = await Promise.all([
+          chatHandler.getHistory(),
+          AiService.getConfig()
+        ]);
         setMessages(history.messages);
+        if (!config.model_provider_id) {
+          setShowWelcomeMessage(true);
+        }
       } catch (e) {
-
+        console.error(e);
       }
-
     }
 
     fetchHistory();
@@ -71,7 +82,9 @@ function ChatBody({ chatHandler }: ChatBodyProps): JSX.Element {
     const prompt =
       input +
-      (includeSelection && selection?.text ? '\n\n```\n' + selection.text + '```': '');
+      (includeSelection && selection?.text
+        ? '\n\n```\n' + selection.text + '```'
+        : '');
 
     // send message to backend
     const messageId = await chatHandler.sendMessage({ prompt });
@@ -90,23 +103,45 @@ function ChatBody({ chatHandler }: ChatBodyProps): JSX.Element {
     }
   };
 
+  const openSettingsView = () => {
+    setShowWelcomeMessage(false);
+    chatViewHandler(ChatView.Settings);
+  };
+
+  if (showWelcomeMessage) {
+    return (
+      <Box
+        sx={{
+          display: 'flex',
+          flexDirection: 'column',
+          justifyContent: 'center',
+          padding: 4,
+          boxSizing: 'border-box'
+        }}
+      >
+        <Stack spacing={4}>
+          <p>
+            Welcome to Jupyter AI! To get started, please select a language
+            model to chat with from the settings panel. You will also likely
+            need to provide API credentials, so be sure to have those handy.
+          </p>
+          <Button
+            variant="contained"
+            startIcon={<SettingsIcon />}
+            onClick={() => openSettingsView()}
+          >
+            Start Here
+          </Button>
+        </Stack>
+      </Box>
+    );
+  }
+
   return (
-    <Box sx={{ display: 'flex', flexDirection: 'column', height: '100%' }}>
+    <>
       <ScrollContainer sx={{ flexGrow: 1 }}>
-        {/* https://css-tricks.com/books/greatest-css-tricks/pin-scrolling-to-bottom/ */}
         <ChatMessages messages={messages} />
       </ScrollContainer>
       <ChatInput
         value={input}
         onChange={setInput}
         onSend={onSend}
-        helperText={<span>Press Shift + Enter to submit message</span>}
+        helperText={
+          <span>
+            Press <b>Shift</b> + <b>Enter</b> to submit message
+          </span>
+        }
       />
-    </Box>
+    </>
   );
 }
@@ -138,14 +177,56 @@ export type ChatProps = {
   selectionWatcher: SelectionWatcher;
   chatHandler: ChatHandler;
   globalAwareness: Awareness | null;
+  chatView?: ChatView;
 };
 
+enum ChatView {
+  Chat,
+  Settings
+}
+
 export function Chat(props: ChatProps) {
+  const [view, setView] = useState(props.chatView || ChatView.Chat);
+
   return (
     <JlThemeProvider>
       <SelectionContextProvider selectionWatcher={props.selectionWatcher}>
         <CollaboratorsContextProvider globalAwareness={props.globalAwareness}>
-          <ChatBody chatHandler={props.chatHandler} />
+          <Box
+            sx={{ display: 'flex', flexDirection: 'column', height: '100%' }}
+          >
+            {/* top bar */}
+            <Box sx={{ display: 'flex', justifyContent: 'space-between' }}>
+              {view !== ChatView.Chat ? (
+                <IconButton onClick={() => setView(ChatView.Chat)}>
+                  <ArrowBackIcon />
+                </IconButton>
+              ) : (
+                <Box />
+              )}
+              {view === ChatView.Chat ? (
+                <IconButton onClick={() => setView(ChatView.Settings)}>
+                  <SettingsIcon />
+                </IconButton>
+              ) : (
+                <Box />
+              )}
+            </Box>
+            {/* body */}
+            {view === ChatView.Chat && (
+              <ChatBody chatHandler={props.chatHandler} setChatView={setView} />
+            )}
+            {view === ChatView.Settings && <ChatSettings />}
+          </Box>
         </CollaboratorsContextProvider>
       </SelectionContextProvider>
     </JlThemeProvider>
   );
 }
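The `Select` wrapper introduced below communicates "no selection" through the string sentinel `'null'`, paired with a `MenuItem` whose value is `'null'`. A minimal usage sketch under that convention; the component and model names here are illustrative, not part of the change:

```tsx
// Illustrative usage of the Select wrapper defined in select.tsx below.
import React, { useState } from 'react';
import { MenuItem } from '@mui/material';
import { Select } from './select';

export function ModelPicker(): JSX.Element {
  const [modelId, setModelId] = useState<string | null>(null);

  return (
    <Select
      label="Language model"
      value={modelId}
      // the wrapper coerces the 'null' sentinel back to a real null
      // before this handler runs
      onChange={e => setModelId(e.target.value)}
    >
      <MenuItem value="null">None</MenuItem>
      <MenuItem value="openai-chat:gpt-3.5-turbo">
        OpenAI :: gpt-3.5-turbo
      </MenuItem>
    </Select>
  );
}
```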
diff --git a/packages/jupyter-ai/src/components/select.tsx b/packages/jupyter-ai/src/components/select.tsx
new file mode 100644
index 000000000..2e709812d
--- /dev/null
+++ b/packages/jupyter-ai/src/components/select.tsx
@@ -0,0 +1,48 @@
+import React from 'react';
+import { FormControl, InputLabel, Select as MuiSelect } from '@mui/material';
+import type {
+  SelectChangeEvent,
+  SelectProps as MuiSelectProps
+} from '@mui/material';
+
+export type SelectProps = Omit<MuiSelectProps<string>, 'value' | 'onChange'> & {
+  value: string | null;
+  onChange: (
+    event: SelectChangeEvent<string>,
+    child: React.ReactNode
+  ) => void;
+};
+
+/**
+ * A helpful wrapper around MUI's native `Select` component that provides the
+ * following services:
+ *
+ * - automatically wraps the base `Select` component in a `FormControl`
+ *   context and prepends an input label derived from `props.label`;
+ *
+ * - limits the max height of the dropdown menu;
+ *
+ * - handles `null` values by coercing them to the string `'null'`. The
+ *   corresponding `MenuItem` should have the value `'null'`.
+ */
+export function Select(props: SelectProps): JSX.Element {
+  return (
+    <FormControl fullWidth>
+      <InputLabel>{props.label}</InputLabel>
+      <MuiSelect
+        {...props}
+        value={props.value === null ? 'null' : props.value}
+        label={props.label}
+        onChange={(e, child) => {
+          if (e.target.value === 'null') {
+            e.target.value = null as any;
+          }
+          props.onChange?.(e, child);
+        }}
+        MenuProps={{ sx: { maxHeight: '50%', minHeight: 400 } }}
+      >
+        {props.children}
+      </MuiSelect>
+    </FormControl>
+  );
+}
diff --git a/packages/jupyter-ai/src/handler.ts b/packages/jupyter-ai/src/handler.ts
index 6ccad4545..55db58c49 100644
--- a/packages/jupyter-ai/src/handler.ts
+++ b/packages/jupyter-ai/src/handler.ts
@@ -99,11 +99,15 @@ export namespace AiService {
   };
 
   export type ClearMessage = {
-    type: 'clear'
-  }
+    type: 'clear';
+  };
 
   export type ChatMessage = AgentChatMessage | HumanChatMessage;
-  export type Message = AgentChatMessage | HumanChatMessage | ConnectionMessage | ClearMessage;
+  export type Message =
+    | AgentChatMessage
+    | HumanChatMessage
+    | ConnectionMessage
+    | ClearMessage;
 
   export type ChatHistory = {
     messages: ChatMessage[];
@@ -160,4 +164,57 @@ export namespace AiService {
   ): Promise<DescribeTaskResponse> {
     return requestAPI<DescribeTaskResponse>(`tasks/${id}`);
   }
+
+  export type Config = {
+    model_provider_id: string | null;
+    embeddings_provider_id: string | null;
+    api_keys: Record<string, string>;
+  };
+
+  export type GetConfigResponse = Config;
+
+  export type UpdateConfigRequest = Config;
+
+  export async function getConfig(): Promise<GetConfigResponse> {
+    return requestAPI<GetConfigResponse>('config');
+  }
+
+  export type EnvAuthStrategy = {
+    type: 'env';
+    name: string;
+  };
+
+  export type AwsAuthStrategy = {
+    type: 'aws';
+  };
+
+  export type AuthStrategy = EnvAuthStrategy | AwsAuthStrategy | null;
+
+  export type ListProvidersEntry = {
+    id: string;
+    name: string;
+    models: string[];
+    auth_strategy: AuthStrategy;
+  };
+
+  export type ListProvidersResponse = {
+    providers: ListProvidersEntry[];
+  };
+
+  export async function listLmProviders(): Promise<ListProvidersResponse> {
+    return requestAPI<ListProvidersResponse>('providers');
+  }
+
+  export async function listEmProviders(): Promise<ListProvidersResponse> {
+    return requestAPI<ListProvidersResponse>('providers/embeddings');
+  }
+
+  export async function updateConfig(
+    config: UpdateConfigRequest
+  ): Promise<void> {
+    return requestAPI<void>('config', {
+      method: 'POST',
+      body: JSON.stringify(config)
+    });
+  }
 }
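Taken together, the `AiService` additions above give frontend code a typed path through the new API. A sketch of a full round trip, selecting the first advertised language model; the function name is hypothetical, and the relative import assumes a module sitting next to `handler.ts` in `packages/jupyter-ai/src`:

```typescript
// Sketch: programmatically selecting a language model via AiService.
import { AiService } from './handler';

export async function selectFirstLanguageModel(): Promise<void> {
  const [config, { providers }] = await Promise.all([
    AiService.getConfig(),
    AiService.listLmProviders()
  ]);

  const lmp = providers[0];
  if (!lmp || lmp.models.length === 0) {
    throw new Error('no language model providers are available');
  }

  // global model IDs take the form "<provider-id>:<local-model-id>"
  await AiService.updateConfig({
    ...config,
    model_provider_id: `${lmp.id}:${lmp.models[0]}`
  });
}
```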