Skip to content

Commit

Permalink
Fix typing jupyter-ai codebase (mostly)
Browse files Browse the repository at this point in the history
  • Loading branch information
krassowski authored and dlqqq committed Sep 12, 2024
1 parent b5318a9 commit e556dc1
Show file tree
Hide file tree
Showing 14 changed files with 53 additions and 45 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ class InlineCompletionItem(BaseModel):

class CompletionError(BaseModel):
type: str
title: str
traceback: str


Expand Down
Empty file.
1 change: 1 addition & 0 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ async def process_message(self, message: HumanChatMessage):

try:
with self.pending("Searching learned documents", message):
assert self.llm_chain
result = await self.llm_chain.acall({"question": query})
response = result["answer"]
self.reply(response, message)
Expand Down
27 changes: 14 additions & 13 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
Optional,
Type,
Union,
cast,
)
from uuid import uuid4

Expand All @@ -28,6 +29,7 @@
)
from jupyter_ai_magics import Persona
from jupyter_ai_magics.providers import BaseProvider
from langchain.chains import LLMChain
from langchain.pydantic_v1 import BaseModel

if TYPE_CHECKING:
Expand All @@ -36,8 +38,8 @@
from langchain_core.chat_history import BaseChatMessageHistory


def get_preferred_dir(root_dir: str, preferred_dir: str) -> Optional[str]:
if preferred_dir != "":
def get_preferred_dir(root_dir: str, preferred_dir: Optional[str]) -> Optional[str]:
if preferred_dir is not None and preferred_dir != "":
preferred_dir = os.path.expanduser(preferred_dir)
if not preferred_dir.startswith(root_dir):
preferred_dir = os.path.join(root_dir, preferred_dir)
Expand All @@ -47,7 +49,7 @@ def get_preferred_dir(root_dir: str, preferred_dir: str) -> Optional[str]:

# Chat handler type, with specific attributes for each
class HandlerRoutingType(BaseModel):
routing_method: ClassVar[Union[Literal["slash_command"]]] = ...
routing_method: ClassVar[Union[Literal["slash_command"]]]
"""The routing method that sends commands to this handler."""


Expand Down Expand Up @@ -83,17 +85,17 @@ class BaseChatHandler:
multiple chat handler classes."""

# Class attributes
id: ClassVar[str] = ...
id: ClassVar[str]
"""ID for this chat handler; should be unique"""

name: ClassVar[str] = ...
name: ClassVar[str]
"""User-facing name of this handler"""

help: ClassVar[str] = ...
help: ClassVar[str]
"""What this chat handler does, which third-party models it contacts,
the data it returns to the user, and so on, for display in the UI."""

routing_type: ClassVar[HandlerRoutingType] = ...
routing_type: ClassVar[HandlerRoutingType]

uses_llm: ClassVar[bool] = True
"""Class attribute specifying whether this chat handler uses the LLM
Expand Down Expand Up @@ -153,9 +155,9 @@ def __init__(
self.help_message_template = help_message_template
self.chat_handlers = chat_handlers

self.llm = None
self.llm_params = None
self.llm_chain = None
self.llm: Optional[BaseProvider] = None
self.llm_params: Optional[dict] = None
self.llm_chain: Optional[LLMChain] = None

async def on_message(self, message: HumanChatMessage):
"""
Expand All @@ -168,9 +170,8 @@ async def on_message(self, message: HumanChatMessage):

# ensure the current slash command is supported
if self.routing_type.routing_method == "slash_command":
slash_command = (
"/" + self.routing_type.slash_id if self.routing_type.slash_id else ""
)
routing_type = cast(SlashCommandRoutingType, self.routing_type)
slash_command = "/" + routing_type.slash_id if routing_type.slash_id else ""
if slash_command in lm_provider_klass.unsupported_slash_commands:
self.reply(
"Sorry, the selected language model does not support this slash command."
Expand Down
7 changes: 4 additions & 3 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,10 @@ def create_llm_chain(
prompt_template = llm.get_chat_prompt_template()
self.llm = llm

runnable = prompt_template | llm
runnable = prompt_template | llm # type:ignore
if not llm.manages_history:
runnable = RunnableWithMessageHistory(
runnable=runnable,
runnable=runnable, # type:ignore[arg-type]
get_session_history=self.get_llm_chat_memory,
input_messages_key="input",
history_messages_key="history",
Expand Down Expand Up @@ -106,6 +106,7 @@ async def process_message(self, message: HumanChatMessage):
# stream response in chunks. this works even if a provider does not
# implement streaming, as `astream()` defaults to yielding `_call()`
# when `_stream()` is not implemented on the LLM class.
assert self.llm_chain
async for chunk in self.llm_chain.astream(
{"input": message.body},
config={"configurable": {"last_human_msg": message}},
Expand All @@ -117,7 +118,7 @@ async def process_message(self, message: HumanChatMessage):
stream_id = self._start_stream(human_msg=message)
received_first_chunk = True

if isinstance(chunk, AIMessageChunk):
if isinstance(chunk, AIMessageChunk) and isinstance(chunk.content, str):
self._send_stream_chunk(stream_id, chunk.content)
elif isinstance(chunk, str):
self._send_stream_chunk(stream_id, chunk)
Expand Down
1 change: 1 addition & 0 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/fix.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ async def process_message(self, message: HumanChatMessage):

self.get_llm_chain()
with self.pending("Analyzing error", message):
assert self.llm_chain
response = await self.llm_chain.apredict(
extra_instructions=extra_instructions,
stop=["\nHuman:"],
Expand Down
3 changes: 2 additions & 1 deletion packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ class GenerateChatHandler(BaseChatHandler):
def __init__(self, log_dir: Optional[str], *args, **kwargs):
super().__init__(*args, **kwargs)
self.log_dir = Path(log_dir) if log_dir else None
self.llm = None
self.llm: Optional[BaseProvider] = None

def create_llm_chain(
self, provider: Type[BaseProvider], provider_params: Dict[str, str]
Expand All @@ -248,6 +248,7 @@ async def _generate_notebook(self, prompt: str):
# Save the user input prompt, the description property is now LLM generated.
outline["prompt"] = prompt

assert self.llm
if self.llm.allows_concurrency:
# fill the outline concurrently
await afill_outline(outline, llm=self.llm, verbose=True)
Expand Down
14 changes: 9 additions & 5 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/learn.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,9 @@ async def learn_dir(
}
splitter = ExtensionSplitter(
splitters=splitters,
default_splitter=RecursiveCharacterTextSplitter(**splitter_kwargs),
default_splitter=RecursiveCharacterTextSplitter(
**splitter_kwargs # type:ignore[arg-type]
),
)

delayed = split(path, all_files, splitter=splitter)
Expand Down Expand Up @@ -352,7 +354,7 @@ async def aget_relevant_documents(
self, query: str
) -> Coroutine[Any, Any, List[Document]]:
if not self.index:
return []
return [] # type:ignore[return-value]

await self.delete_and_relearn()
docs = self.index.similarity_search(query)
Expand All @@ -370,12 +372,14 @@ def get_embedding_model(self):


class Retriever(BaseRetriever):
learn_chat_handler: LearnChatHandler = None
learn_chat_handler: LearnChatHandler = None # type:ignore[assignment]

def _get_relevant_documents(self, query: str) -> List[Document]:
def _get_relevant_documents( # type:ignore[override]
self, query: str
) -> List[Document]:
raise NotImplementedError()

async def _aget_relevant_documents(
async def _aget_relevant_documents( # type:ignore[override]
self, query: str
) -> Coroutine[Any, Any, List[Document]]:
docs = await self.learn_chat_handler.aget_relevant_documents(query)
Expand Down
8 changes: 4 additions & 4 deletions packages/jupyter-ai/jupyter_ai/config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import os
import shutil
import time
from typing import List, Optional, Union
from typing import List, Optional, Type, Union

from deepmerge import always_merger as Merger
from jsonschema import Draft202012Validator as Validator
Expand Down Expand Up @@ -60,7 +60,7 @@ class BlockedModelError(Exception):
pass


def _validate_provider_authn(config: GlobalConfig, provider: AnyProvider):
def _validate_provider_authn(config: GlobalConfig, provider: Type[AnyProvider]):
# TODO: handle non-env auth strategies
if not provider.auth_strategy or provider.auth_strategy.type != "env":
return
Expand Down Expand Up @@ -147,7 +147,7 @@ def _init_config_schema(self):
os.makedirs(os.path.dirname(self.schema_path), exist_ok=True)
shutil.copy(OUR_SCHEMA_PATH, self.schema_path)

def _init_validator(self) -> Validator:
def _init_validator(self) -> None:
with open(OUR_SCHEMA_PATH, encoding="utf-8") as f:
schema = json.loads(f.read())
Validator.check_schema(schema)
Expand Down Expand Up @@ -364,7 +364,7 @@ def delete_api_key(self, key_name: str):
config_dict["api_keys"].pop(key_name, None)
self._write_config(GlobalConfig(**config_dict))

def update_config(self, config_update: UpdateConfigRequest):
def update_config(self, config_update: UpdateConfigRequest): # type:ignore
last_write = os.stat(self.config_path).st_mtime_ns
if config_update.last_read and config_update.last_read < last_write:
raise WriteConflictError(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def arxiv_to_text(id: str, output_dir: str) -> str:
output path to the downloaded TeX file
"""

import arxiv
import arxiv # type:ignore[import-not-found]

outfile = f"{id}-{datetime.now():%Y-%m-%d-%H-%M}.tex"
download_filename = "downloaded-paper.tar.gz"
Expand Down
2 changes: 1 addition & 1 deletion packages/jupyter-ai/jupyter_ai/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@

class AiExtension(ExtensionApp):
name = "jupyter_ai"
handlers = [
handlers = [ # type:ignore[assignment]
(r"api/ai/api_keys/(?P<api_key_name>\w+)", ApiKeysHandler),
(r"api/ai/config/?", GlobalConfigHandler),
(r"api/ai/chats/?", RootChatHandler),
Expand Down
10 changes: 4 additions & 6 deletions packages/jupyter-ai/jupyter_ai/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import uuid
from asyncio import AbstractEventLoop
from dataclasses import asdict
from typing import TYPE_CHECKING, Dict, List, Optional
from typing import TYPE_CHECKING, Dict, List, Optional, cast

import tornado
from jupyter_ai.chat_handlers import BaseChatHandler, SlashCommandRoutingType
Expand Down Expand Up @@ -42,14 +42,12 @@
from jupyter_ai_magics.embedding_providers import BaseEmbeddingsProvider
from jupyter_ai_magics.providers import BaseProvider

from .history import BoundChatHistory
from .history import BoundedChatHistory


class ChatHistoryHandler(BaseAPIHandler):
"""Handler to return message history"""

_messages = []

@property
def chat_history(self) -> List[ChatMessage]:
return self.settings["chat_history"]
Expand Down Expand Up @@ -103,7 +101,7 @@ def chat_history(self, new_history):
self.settings["chat_history"] = new_history

@property
def llm_chat_memory(self) -> "BoundChatHistory":
def llm_chat_memory(self) -> "BoundedChatHistory":
return self.settings["llm_chat_memory"]

@property
Expand Down Expand Up @@ -401,7 +399,7 @@ def filter_predicate(local_model_id: str):
if self.blocked_models:
return model_id not in self.blocked_models
else:
return model_id in self.allowed_models
return model_id in cast(List, self.allowed_models)

# filter out every model w/ model ID according to allow/blocklist
for provider in providers:
Expand Down
4 changes: 2 additions & 2 deletions packages/jupyter-ai/jupyter_ai/history.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class BoundedChatHistory(BaseChatMessageHistory, BaseModel):
_all_messages: List[BaseMessage] = PrivateAttr(default_factory=list)

@property
def messages(self) -> List[BaseMessage]:
def messages(self) -> List[BaseMessage]: # type:ignore[override]
if self.k is None:
return self._all_messages
return self._all_messages[-self.k * 2 :]
Expand Down Expand Up @@ -92,7 +92,7 @@ class WrappedBoundedChatHistory(BaseChatMessageHistory, BaseModel):
last_human_msg: HumanChatMessage

@property
def messages(self) -> List[BaseMessage]:
def messages(self) -> List[BaseMessage]: # type:ignore[override]
return self.history.messages

def add_message(self, message: BaseMessage) -> None:
Expand Down
18 changes: 9 additions & 9 deletions packages/jupyter-ai/jupyter_ai/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,7 @@ class ChatClient(ChatUser):
id: str


class AgentChatMessage(BaseModel):
type: Literal["agent"] = "agent"
class BaseAgentMessage(BaseModel):
id: str
time: float
body: str
Expand All @@ -89,7 +88,11 @@ class AgentChatMessage(BaseModel):
"""


class AgentStreamMessage(AgentChatMessage):
class AgentChatMessage(BaseAgentMessage):
type: Literal["agent"] = "agent"


class AgentStreamMessage(BaseAgentMessage):
type: Literal["agent-stream"] = "agent-stream"
complete: bool
# other attrs inherited from `AgentChatMessage`
Expand Down Expand Up @@ -138,15 +141,13 @@ class PendingMessage(BaseModel):


class ClosePendingMessage(BaseModel):
type: Literal["pending"] = "close-pending"
type: Literal["close-pending"] = "close-pending"
id: str


# the type of messages being broadcast to clients
ChatMessage = Union[
AgentChatMessage,
HumanChatMessage,
AgentStreamMessage,
AgentChatMessage, HumanChatMessage, AgentStreamMessage, AgentStreamChunkMessage
]


Expand All @@ -164,8 +165,7 @@ class ConnectionMessage(BaseModel):


Message = Union[
AgentChatMessage,
HumanChatMessage,
ChatMessage,
ConnectionMessage,
ClearMessage,
PendingMessage,
Expand Down

0 comments on commit e556dc1

Please sign in to comment.