forked from jupyterlab/jupyter-ai
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Runtime model configurability (jupyterlab#146)
* Refactored provider load, decompose logic, aded model provider list api * Renamed model * Sorted the provider names * WIP: Embedding providers * Added embeddings provider api * Added missing import * Moved providers to ray actor, added config actor * Ability to load llm and embeddings from config * Moved llm creation to specific actors * Added apis for fetching, updating config. Fixed config update, error handling * Updated as per PR feedback * Fixes issue with cohere embeddings, api keys not working * Added an error check when embedding change causes read error * Delete and re-index docs when embedding model changes (jupyterlab#137) * Added an error check when embedding change causes read error * Refactored provider load, decompose logic, aded model provider list api * Re-indexes dirs when embeddings change, learn list command * Fixed typo, simplified adding metadata * Moved index dir, metadata path to constants * Chat settings UI (jupyterlab#141) * remove unused div * automatically create config if not present * allow all-caps envvars in config * implement basic chat settings UI * hide API key text inputs * limit popup size, show success banner * show welcome message if no LM is selected * fix buggy UI with no selected LM/EM * exclude legacy OpenAI chat provider used in magics * Added a button with welcome message --------- Co-authored-by: Jain <[email protected]> * Various chat chain enhancements and fixes (jupyterlab#144) * fix /clear command * use model IDs to compare LLMs instead * specify stop sequence in chat chain * add empty AI message, improve system prompt * add RTD configuration --------- Co-authored-by: Piyush Jain <[email protected]> Co-authored-by: Jain <[email protected]>
- Loading branch information
Showing
26 changed files
with
1,323 additions
and
184 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
# .readthedocs.yaml | ||
# Read the Docs configuration file | ||
# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details | ||
|
||
version: 2 | ||
|
||
build: | ||
os: ubuntu-22.04 | ||
tools: | ||
python: "3.11" | ||
|
||
sphinx: | ||
configuration: docs/source/conf.py | ||
|
||
python: | ||
install: | ||
- requirements: docs/requirements.txt |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
MODEL_ID_ALIASES = { | ||
"gpt2": "huggingface_hub:gpt2", | ||
"gpt3": "openai:text-davinci-003", | ||
"chatgpt": "openai-chat:gpt-3.5-turbo", | ||
"gpt4": "openai-chat:gpt-4", | ||
} |
75 changes: 75 additions & 0 deletions
75
packages/jupyter-ai-magics/jupyter_ai_magics/embedding_providers.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
from typing import ClassVar, List, Type | ||
from jupyter_ai_magics.providers import AuthStrategy, EnvAuthStrategy | ||
from pydantic import BaseModel, Extra | ||
from langchain.embeddings import OpenAIEmbeddings, CohereEmbeddings, HuggingFaceHubEmbeddings | ||
from langchain.embeddings.base import Embeddings | ||
|
||
|
||
class BaseEmbeddingsProvider(BaseModel): | ||
"""Base class for embedding providers""" | ||
|
||
class Config: | ||
extra = Extra.allow | ||
|
||
id: ClassVar[str] = ... | ||
"""ID for this provider class.""" | ||
|
||
name: ClassVar[str] = ... | ||
"""User-facing name of this provider.""" | ||
|
||
models: ClassVar[List[str]] = ... | ||
"""List of supported models by their IDs. For registry providers, this will | ||
be just ["*"].""" | ||
|
||
model_id_key: ClassVar[str] = ... | ||
"""Kwarg expected by the upstream LangChain provider.""" | ||
|
||
pypi_package_deps: ClassVar[List[str]] = [] | ||
"""List of PyPi package dependencies.""" | ||
|
||
auth_strategy: ClassVar[AuthStrategy] = None | ||
"""Authentication/authorization strategy. Declares what credentials are | ||
required to use this model provider. Generally should not be `None`.""" | ||
|
||
model_id: str | ||
|
||
provider_klass: ClassVar[Type[Embeddings]] | ||
|
||
|
||
class OpenAIEmbeddingsProvider(BaseEmbeddingsProvider): | ||
id = "openai" | ||
name = "OpenAI" | ||
models = [ | ||
"text-embedding-ada-002" | ||
] | ||
model_id_key = "model" | ||
pypi_package_deps = ["openai"] | ||
auth_strategy = EnvAuthStrategy(name="OPENAI_API_KEY") | ||
provider_klass = OpenAIEmbeddings | ||
|
||
|
||
class CohereEmbeddingsProvider(BaseEmbeddingsProvider): | ||
id = "cohere" | ||
name = "Cohere" | ||
models = [ | ||
'large', | ||
'multilingual-22-12', | ||
'small' | ||
] | ||
model_id_key = "model" | ||
pypi_package_deps = ["cohere"] | ||
auth_strategy = EnvAuthStrategy(name="COHERE_API_KEY") | ||
provider_klass = CohereEmbeddings | ||
|
||
|
||
class HfHubEmbeddingsProvider(BaseEmbeddingsProvider): | ||
id = "huggingface_hub" | ||
name = "HuggingFace Hub" | ||
models = ["*"] | ||
model_id_key = "repo_id" | ||
# ipywidgets needed to suppress tqdm warning | ||
# https://stackoverflow.com/questions/67998191 | ||
# tqdm is a dependency of huggingface_hub | ||
pypi_package_deps = ["huggingface_hub", "ipywidgets"] | ||
auth_strategy = EnvAuthStrategy(name="HUGGINGFACEHUB_API_TOKEN") | ||
provider_klass = HuggingFaceHubEmbeddings |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
import logging | ||
from typing import Dict, Optional, Tuple, Union | ||
from importlib_metadata import entry_points | ||
from jupyter_ai_magics.aliases import MODEL_ID_ALIASES | ||
|
||
from jupyter_ai_magics.embedding_providers import BaseEmbeddingsProvider | ||
|
||
from jupyter_ai_magics.providers import BaseProvider | ||
|
||
|
||
Logger = Union[logging.Logger, logging.LoggerAdapter] | ||
|
||
|
||
def load_providers(log: Optional[Logger] = None) -> Dict[str, BaseProvider]: | ||
if not log: | ||
log = logging.getLogger() | ||
log.addHandler(logging.NullHandler()) | ||
|
||
providers = {} | ||
eps = entry_points() | ||
model_provider_eps = eps.select(group="jupyter_ai.model_providers") | ||
for model_provider_ep in model_provider_eps: | ||
try: | ||
provider = model_provider_ep.load() | ||
except: | ||
log.error(f"Unable to load model provider class from entry point `{model_provider_ep.name}`.") | ||
continue | ||
providers[provider.id] = provider | ||
log.info(f"Registered model provider `{provider.id}`.") | ||
|
||
return providers | ||
|
||
|
||
def load_embedding_providers(log: Optional[Logger] = None) -> Dict[str, BaseEmbeddingsProvider]: | ||
if not log: | ||
log = logging.getLogger() | ||
log.addHandler(logging.NullHandler()) | ||
providers = {} | ||
eps = entry_points() | ||
model_provider_eps = eps.select(group="jupyter_ai.embeddings_model_providers") | ||
for model_provider_ep in model_provider_eps: | ||
try: | ||
provider = model_provider_ep.load() | ||
except: | ||
log.error(f"Unable to load embeddings model provider class from entry point `{model_provider_ep.name}`.") | ||
continue | ||
providers[provider.id] = provider | ||
log.info(f"Registered embeddings model provider `{provider.id}`.") | ||
|
||
return providers | ||
|
||
def decompose_model_id(model_id: str, providers: Dict[str, BaseProvider]) -> Tuple[str, str]: | ||
"""Breaks down a model ID into a two-tuple (provider_id, local_model_id). Returns (None, None) if indeterminate.""" | ||
if model_id in MODEL_ID_ALIASES: | ||
model_id = MODEL_ID_ALIASES[model_id] | ||
|
||
if ":" not in model_id: | ||
# case: model ID was not provided with a prefix indicating the provider | ||
# ID. try to infer the provider ID before returning (None, None). | ||
|
||
# naively search through the dictionary and return the first provider | ||
# that provides a model of the same ID. | ||
for provider_id, provider in providers.items(): | ||
if model_id in provider.models: | ||
return (provider_id, model_id) | ||
|
||
return (None, None) | ||
|
||
provider_id, local_model_id = model_id.split(":", 1) | ||
return (provider_id, local_model_id) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.