Skip to content

Commit

Permalink
pre-commit
Browse files Browse the repository at this point in the history
  • Loading branch information
dlqqq committed Aug 7, 2023
1 parent 04a1814 commit 0a65fcb
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 9 deletions.
32 changes: 24 additions & 8 deletions packages/jupyter-ai-magics/jupyter_ai_magics/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from jsonpath_ng import parse
from langchain import PromptTemplate
from langchain.chat_models import ChatOpenAI, AzureChatOpenAI
from langchain.chat_models import AzureChatOpenAI, ChatOpenAI
from langchain.llms import (
AI21,
Anthropic,
Expand Down Expand Up @@ -419,11 +419,16 @@ class ChatOpenAINewProvider(BaseProvider, ChatOpenAI):
auth_strategy = EnvAuthStrategy(name="OPENAI_API_KEY")

fields = [
TextField(key="openai_api_base", label="Base API URL (optional)", format="text"),
TextField(key="openai_organization", label="Organization (optional)", format="text"),
TextField(
key="openai_api_base", label="Base API URL (optional)", format="text"
),
TextField(
key="openai_organization", label="Organization (optional)", format="text"
),
TextField(key="openai_proxy", label="Proxy (optional)", format="text"),
]


class AzureChatOpenAIProvider(BaseProvider, AzureChatOpenAI):
id = "azure-chat-openai"
name = "Azure OpenAI"
Expand All @@ -435,12 +440,19 @@ class AzureChatOpenAIProvider(BaseProvider, AzureChatOpenAI):
registry = True

fields = [
TextField(key="openai_api_base", label="Base API URL (required)", format="text"),
TextField(key="openai_api_version", label="API version (required)", format="text"),
TextField(key="openai_organization", label="Organization (optional)", format="text"),
TextField(
key="openai_api_base", label="Base API URL (required)", format="text"
),
TextField(
key="openai_api_version", label="API version (required)", format="text"
),
TextField(
key="openai_organization", label="Organization (optional)", format="text"
),
TextField(key="openai_proxy", label="Proxy (optional)", format="text"),
]


class JsonContentHandler(LLMContentHandler):
content_type = "application/json"
accepts = "application/json"
Expand Down Expand Up @@ -491,8 +503,12 @@ class SmEndpointProvider(BaseProvider, SagemakerEndpoint):
registry = True
fields = [
TextField(key="region_name", label="Region name (required)", format="text"),
MultilineTextField(key="request_schema", label="Request schema (required)", format="json"),
TextField(key="response_path", label="Response path (required)", format="jsonpath"),
MultilineTextField(
key="request_schema", label="Request schema (required)", format="json"
),
TextField(
key="response_path", label="Response path (required)", format="jsonpath"
),
]

def __init__(self, *args, **kwargs):
Expand Down
3 changes: 2 additions & 1 deletion packages/jupyter-ai/jupyter_ai/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,9 @@
)

if TYPE_CHECKING:
from jupyter_ai_magics.providers import BaseProvider
from jupyter_ai_magics.embedding_providers import BaseEmbeddingsProvider
from jupyter_ai_magics.providers import BaseProvider


class ChatHistoryHandler(BaseAPIHandler):
"""Handler to return message history"""
Expand Down

0 comments on commit 0a65fcb

Please sign in to comment.