Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Azure OpenAI and OpenAI proxy support #322

Merged
merged 2 commits into from
Aug 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/source/users/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ see the following interface:
Each of the additional fields under "Language model" is required. These fields
should contain the following data:

- **Local model ID**: The name of your endpoint. This can be retrieved from the
- **Endpoint name**: The name of your endpoint. This can be retrieved from the
AWS Console at the URL
`https://<region>.console.aws.amazon.com/sagemaker/home?region=<region>#/endpoints`.

Expand Down
1 change: 1 addition & 0 deletions packages/jupyter-ai-magics/jupyter_ai_magics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from .providers import (
AI21Provider,
AnthropicProvider,
AzureChatOpenAIProvider,
BaseProvider,
BedrockProvider,
ChatOpenAINewProvider,
Expand Down
50 changes: 46 additions & 4 deletions packages/jupyter-ai-magics/jupyter_ai_magics/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from jsonpath_ng import parse
from langchain import PromptTemplate
from langchain.chat_models import ChatOpenAI
from langchain.chat_models import AzureChatOpenAI, ChatOpenAI
from langchain.llms import (
AI21,
Anthropic,
Expand Down Expand Up @@ -100,6 +100,9 @@ class Config:
model_id_key: ClassVar[str] = ...
"""Kwarg expected by the upstream LangChain provider."""

model_id_label: ClassVar[str] = ""
"""Human-readable label of the model ID."""

pypi_package_deps: ClassVar[List[str]] = []
"""List of PyPI package dependencies."""

Expand Down Expand Up @@ -415,6 +418,40 @@ class ChatOpenAINewProvider(BaseProvider, ChatOpenAI):
pypi_package_deps = ["openai"]
auth_strategy = EnvAuthStrategy(name="OPENAI_API_KEY")

fields = [
TextField(
key="openai_api_base", label="Base API URL (optional)", format="text"
),
TextField(
key="openai_organization", label="Organization (optional)", format="text"
),
TextField(key="openai_proxy", label="Proxy (optional)", format="text"),
]


class AzureChatOpenAIProvider(BaseProvider, AzureChatOpenAI):
"""Jupyter AI provider entry wrapping LangChain's ``AzureChatOpenAI``.

NOTE(review): the diff view has stripped the class-body indentation here;
the real file indents every attribute below by one level.
"""

# Unique provider id; must match the entry-point key in pyproject.toml.
id = "azure-chat-openai"
# Human-readable provider name shown in the UI.
name = "Azure OpenAI"
# Wildcard: the model id is free-form (the user supplies a deployment name).
models = ["*"]
# Kwarg name expected by the upstream AzureChatOpenAI constructor.
model_id_key = "deployment_name"
# UI label for the model-id field (replaces the generic "Local model ID").
model_id_label = "Deployment name"
pypi_package_deps = ["openai"]
# Reads the API key from the OPENAI_API_KEY environment variable.
auth_strategy = EnvAuthStrategy(name="OPENAI_API_KEY")
# NOTE(review): presumably flags this as a user-registered ("registry")
# provider rather than a fixed-model one — confirm against BaseProvider.
registry = True

# Extra per-provider configuration fields surfaced in the settings UI.
fields = [
    TextField(
        key="openai_api_base", label="Base API URL (required)", format="text"
    ),
    TextField(
        key="openai_api_version", label="API version (required)", format="text"
    ),
    TextField(
        key="openai_organization", label="Organization (optional)", format="text"
    ),
    TextField(key="openai_proxy", label="Proxy (optional)", format="text"),
]


class JsonContentHandler(LLMContentHandler):
content_type = "application/json"
Expand Down Expand Up @@ -452,6 +489,7 @@ class SmEndpointProvider(BaseProvider, SagemakerEndpoint):
name = "SageMaker endpoint"
models = ["*"]
model_id_key = "endpoint_name"
model_id_label = "Endpoint name"
# This all needs to be on one line of markdown, for use in a table
help = (
"Specify an endpoint name as the model ID. "
Expand All @@ -464,9 +502,13 @@ class SmEndpointProvider(BaseProvider, SagemakerEndpoint):
auth_strategy = AwsAuthStrategy()
registry = True
fields = [
TextField(key="region_name", label="Region name", format="text"),
MultilineTextField(key="request_schema", label="Request schema", format="json"),
TextField(key="response_path", label="Response path", format="jsonpath"),
TextField(key="region_name", label="Region name (required)", format="text"),
MultilineTextField(
key="request_schema", label="Request schema (required)", format="json"
),
TextField(
key="response_path", label="Response path (required)", format="jsonpath"
),
]

def __init__(self, *args, **kwargs):
Expand Down
1 change: 1 addition & 0 deletions packages/jupyter-ai-magics/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ huggingface_hub = "jupyter_ai_magics:HfHubProvider"
openai = "jupyter_ai_magics:OpenAIProvider"
openai-chat = "jupyter_ai_magics:ChatOpenAIProvider"
openai-chat-new = "jupyter_ai_magics:ChatOpenAINewProvider"
azure-chat-openai = "jupyter_ai_magics:AzureChatOpenAIProvider"
sagemaker-endpoint = "jupyter_ai_magics:SmEndpointProvider"
amazon-bedrock = "jupyter_ai_magics:BedrockProvider"

Expand Down
15 changes: 12 additions & 3 deletions packages/jupyter-ai/jupyter_ai/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import uuid
from asyncio import AbstractEventLoop
from dataclasses import asdict
from typing import Dict, List
from typing import TYPE_CHECKING, Dict, List

import tornado
from jupyter_ai.chat_handlers import BaseChatHandler
Expand All @@ -29,6 +29,10 @@
Message,
)

if TYPE_CHECKING:
from jupyter_ai_magics.embedding_providers import BaseEmbeddingsProvider
from jupyter_ai_magics.providers import BaseProvider


class ChatHistoryHandler(BaseAPIHandler):
"""Handler to return message history"""
Expand Down Expand Up @@ -237,7 +241,7 @@ def on_close(self):

class ModelProviderHandler(BaseAPIHandler):
@property
def lm_providers(self):
def lm_providers(self) -> Dict[str, "BaseProvider"]:
return self.settings["lm_providers"]

@web.authenticated
Expand All @@ -248,6 +252,10 @@ def get(self):
if provider.id == "openai-chat":
continue

optionals = {}
if provider.model_id_label:
optionals["model_id_label"] = provider.model_id_label

providers.append(
ListProvidersEntry(
id=provider.id,
Expand All @@ -256,6 +264,7 @@ def get(self):
auth_strategy=provider.auth_strategy,
registry=provider.registry,
fields=provider.fields,
**optionals,
)
)

Expand All @@ -267,7 +276,7 @@ def get(self):

class EmbeddingsModelProviderHandler(BaseAPIHandler):
@property
def em_providers(self):
def em_providers(self) -> Dict[str, "BaseEmbeddingsProvider"]:
return self.settings["em_providers"]

@web.authenticated
Expand Down
1 change: 1 addition & 0 deletions packages/jupyter-ai/jupyter_ai/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ class ListProvidersEntry(BaseModel):

id: str
name: str
model_id_label: Optional[str]
models: List[str]
auth_strategy: AuthStrategy
registry: bool
Expand Down
2 changes: 1 addition & 1 deletion packages/jupyter-ai/src/components/chat-settings.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ export function ChatSettings(): JSX.Element {
</Select>
{showLmLocalId && (
<TextField
label="Local model ID"
label={lmProvider?.model_id_label || 'Local model ID'}
value={lmLocalId}
onChange={e => setLmLocalId(e.target.value)}
fullWidth
Expand Down
1 change: 1 addition & 0 deletions packages/jupyter-ai/src/handler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ export namespace AiService {
export type ListProvidersEntry = {
id: string;
name: string;
model_id_label?: string;
models: string[];
auth_strategy: AuthStrategy;
registry: boolean;
Expand Down
Loading