diff --git a/packages/jupyter-ai-module-cookiecutter/{{cookiecutter.root_dir_name}}/{{cookiecutter.python_name}}/slash_command.py b/packages/jupyter-ai-module-cookiecutter/{{cookiecutter.root_dir_name}}/{{cookiecutter.python_name}}/slash_command.py index f82bd5531..d8f7a1443 100644 --- a/packages/jupyter-ai-module-cookiecutter/{{cookiecutter.root_dir_name}}/{{cookiecutter.python_name}}/slash_command.py +++ b/packages/jupyter-ai-module-cookiecutter/{{cookiecutter.root_dir_name}}/{{cookiecutter.python_name}}/slash_command.py @@ -1,6 +1,11 @@ from jupyter_ai.chat_handlers.base import BaseChatHandler, SlashCommandRoutingType from jupyter_ai.models import HumanChatMessage +try: + from jupyterlab_collaborative_chat.ychat import YChat +except: + from typing import Any as YChat + class TestSlashCommand(BaseChatHandler): """ @@ -25,5 +30,5 @@ class TestSlashCommand(BaseChatHandler): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - async def process_message(self, message: HumanChatMessage): + async def process_message(self, message: HumanChatMessage, chat: YChat): self.reply("This is the `/test` slash command.") diff --git a/packages/jupyter-ai-test/jupyter_ai_test/test_slash_commands.py b/packages/jupyter-ai-test/jupyter_ai_test/test_slash_commands.py index f82bd5531..f4ae3dcfb 100644 --- a/packages/jupyter-ai-test/jupyter_ai_test/test_slash_commands.py +++ b/packages/jupyter-ai-test/jupyter_ai_test/test_slash_commands.py @@ -1,6 +1,11 @@ from jupyter_ai.chat_handlers.base import BaseChatHandler, SlashCommandRoutingType from jupyter_ai.models import HumanChatMessage +try: + from jupyterlab_collaborative_chat.ychat import YChat +except: + from typing import Any as YChat + class TestSlashCommand(BaseChatHandler): """ @@ -25,5 +30,5 @@ class TestSlashCommand(BaseChatHandler): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - async def process_message(self, message: HumanChatMessage): - self.reply("This is the `/test` slash command.") + async def process_message(self, message: HumanChatMessage, chat: YChat): + self.reply("This is the `/test` slash command.", chat) diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/__init__.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/__init__.py index a8fe9eb50..a46b74e4d 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/__init__.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/__init__.py @@ -1,3 +1,8 @@ +# The following import is to make sure jupyter_ydoc is imported before +# jupyterlab_collaborative_chat, otherwise it leads to circular import because of the +# YChat relying on YBaseDoc, and jupyter_ydoc registering YChat from the entry point. 
+import jupyter_ydoc + from .ask import AskChatHandler from .base import BaseChatHandler, SlashCommandRoutingType from .clear import ClearChatHandler diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py index b5c4fa38b..56228b43e 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py @@ -1,5 +1,5 @@ import argparse -from typing import Dict, Type +from typing import Dict, Optional, Type from jupyter_ai.models import HumanChatMessage from jupyter_ai_magics.providers import BaseProvider @@ -7,6 +7,11 @@ from langchain.memory import ConversationBufferWindowMemory from langchain_core.prompts import PromptTemplate +try: + from jupyterlab_collaborative_chat.ychat import YChat +except: + from typing import Any as YChat + from .base import BaseChatHandler, SlashCommandRoutingType PROMPT_TEMPLATE = """Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question. @@ -59,19 +64,19 @@ def create_llm_chain( verbose=False, ) - async def process_message(self, message: HumanChatMessage): - args = self.parse_args(message) + async def process_message(self, message: HumanChatMessage, chat: Optional[YChat]): + args = self.parse_args(message, chat) if args is None: return query = " ".join(args.query) if not query: - self.reply(f"{self.parser.format_usage()}", message) + self.reply(f"{self.parser.format_usage()}", chat, message) return self.get_llm_chain() try: - with self.pending("Searching learned documents", message): + with self.pending("Searching learned documents", message, chat=chat): assert self.llm_chain # TODO: migrate this class to use a LCEL `Runnable` instead of # `Chain`, then remove the below ignore comment. @@ -79,7 +84,7 @@ async def process_message(self, message: HumanChatMessage): {"question": query} ) response = result["answer"] - self.reply(response, message) + self.reply(response, chat, message) except AssertionError as e: self.log.error(e) response = """Sorry, an error occurred while reading the from the learned documents. @@ -87,4 +92,4 @@ async def process_message(self, message: HumanChatMessage): `/learn -d` command and then re-submitting the `learn ` to learn the documents, and then asking the question again. 
""" - self.reply(response, message) + self.reply(response, chat, message) diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py index c844650ad..853335cbb 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py @@ -8,6 +8,7 @@ TYPE_CHECKING, Any, Awaitable, + Callable, ClassVar, Dict, List, @@ -43,6 +44,11 @@ from langchain_core.runnables.config import merge_configs as merge_runnable_configs from langchain_core.runnables.utils import Input +try: + from jupyterlab_collaborative_chat.ychat import YChat +except: + from typing import Any as YChat + if TYPE_CHECKING: from jupyter_ai.context_providers import BaseCommandContextProvider from jupyter_ai.handlers import RootChatHandler @@ -156,6 +162,7 @@ def __init__( chat_handlers: Dict[str, "BaseChatHandler"], context_providers: Dict[str, "BaseCommandContextProvider"], message_interrupted: Dict[str, asyncio.Event], + write_message: Callable[[YChat, str, Optional[str]], None], ): self.log = log self.config_manager = config_manager @@ -183,7 +190,9 @@ def __init__( self.llm_params: Optional[dict] = None self.llm_chain: Optional[Runnable] = None - async def on_message(self, message: HumanChatMessage): + self.write_message = write_message + + async def on_message(self, message: HumanChatMessage, chat: Optional[YChat] = None): """ Method which receives a human message, calls `self.get_llm_chain()`, and processes the message via `self.process_message()`, calling @@ -198,7 +207,8 @@ async def on_message(self, message: HumanChatMessage): slash_command = "/" + routing_type.slash_id if routing_type.slash_id else "" if slash_command in lm_provider_klass.unsupported_slash_commands: self.reply( - "Sorry, the selected language model does not support this slash command." + "Sorry, the selected language model does not support this slash command.", + chat, ) return @@ -210,6 +220,7 @@ async def on_message(self, message: HumanChatMessage): if not lm_provider.allows_concurrency: self.reply( "The currently selected language model can process only one request at a time. Please wait for me to reply before sending another question.", + chat, message, ) return @@ -217,43 +228,47 @@ async def on_message(self, message: HumanChatMessage): BaseChatHandler._requests_count += 1 if self.__class__.supports_help: - args = self.parse_args(message, silent=True) + args = self.parse_args(message, chat, silent=True) if args and args.help: - self.reply(self.parser.format_help(), message) + self.reply(self.parser.format_help(), chat, message) return try: - await self.process_message(message) + await self.process_message(message, chat) except Exception as e: try: # we try/except `handle_exc()` in case it was overriden and # raises an exception by accident. - await self.handle_exc(e, message) + await self.handle_exc(e, message, chat) except Exception as e: - await self._default_handle_exc(e, message) + await self._default_handle_exc(e, message, chat) finally: BaseChatHandler._requests_count -= 1 - async def process_message(self, message: HumanChatMessage): + async def process_message(self, message: HumanChatMessage, chat: Optional[YChat]): """ Processes a human message routed to this chat handler. Chat handlers (subclasses) must implement this method. Don't forget to call - `self.reply(, message)` at the end! + `self.reply(, chat, message)` at the end! 
The method definition does not need to be wrapped in a try/except block; any exceptions raised here are caught by `self.handle_exc()`. """ raise NotImplementedError("Should be implemented by subclasses.") - async def handle_exc(self, e: Exception, message: HumanChatMessage): + async def handle_exc( + self, e: Exception, message: HumanChatMessage, chat: Optional[YChat] + ): """ Handles an exception raised by `self.process_message()`. A default implementation is provided, however chat handlers (subclasses) should implement this method to provide a more helpful error response. """ - await self._default_handle_exc(e, message) + await self._default_handle_exc(e, message, chat) - async def _default_handle_exc(self, e: Exception, message: HumanChatMessage): + async def _default_handle_exc( + self, e: Exception, message: HumanChatMessage, chat: Optional[YChat] + ): """ The default definition of `handle_exc()`. This is the default used when the `handle_exc()` excepts. @@ -263,13 +278,13 @@ async def _default_handle_exc(self, e: Exception, message: HumanChatMessage): if lm_provider and lm_provider.is_api_key_exc(e): provider_name = getattr(self.config_manager.lm_provider, "name", "") response = f"Oops! There's a problem connecting to {provider_name}. Please update your {provider_name} API key in the chat settings." - self.reply(response, message) + self.reply(response, chat, message) return formatted_e = traceback.format_exc() response = ( f"Sorry, an error occurred. Details below:\n\n```\n{formatted_e}\n```" ) - self.reply(response, message) + self.reply(response, chat, message) def broadcast_message(self, message: Message): """ @@ -291,20 +306,27 @@ def broadcast_message(self, message: Message): cast(ChatMessage, message) self._chat_history.append(message) - def reply(self, response: str, human_msg: Optional[HumanChatMessage] = None): + def reply( + self, + response: str, + chat: Optional[YChat], + human_msg: Optional[HumanChatMessage] = None, + ): """ Sends an agent message, usually in response to a received `HumanChatMessage`. """ - agent_msg = AgentChatMessage( - id=uuid4().hex, - time=time.time(), - body=response, - reply_to=human_msg.id if human_msg else "", - persona=self.persona, - ) - - self.broadcast_message(agent_msg) + if chat is not None: + self.write_message(chat, response, None) + else: + agent_msg = AgentChatMessage( + id=uuid4().hex, + time=time.time(), + body=response, + reply_to=human_msg.id if human_msg else "", + persona=self.persona, + ) + self.broadcast_message(agent_msg) @property def persona(self): @@ -315,6 +337,7 @@ def start_pending( text: str, human_msg: Optional[HumanChatMessage] = None, *, + chat: Optional[YChat] = None, ellipsis: bool = True, ) -> PendingMessage: """ @@ -333,10 +356,13 @@ def start_pending( ellipsis=ellipsis, ) - self.broadcast_message(pending_msg) + if chat is not None: + chat.awareness.set_local_state_field("isWriting", True) + else : + self.broadcast_message(pending_msg) return pending_msg - def close_pending(self, pending_msg: PendingMessage): + def close_pending(self, pending_msg: PendingMessage, chat: Optional[YChat] = None): """ Closes a pending message. 
""" @@ -347,7 +373,10 @@ def close_pending(self, pending_msg: PendingMessage): id=pending_msg.id, ) - self.broadcast_message(close_pending_msg) + if chat is not None: + chat.awareness.set_local_state_field("isWriting", False) + else: + self.broadcast_message(close_pending_msg) pending_msg.closed = True @contextlib.contextmanager @@ -356,18 +385,22 @@ def pending( text: str, human_msg: Optional[HumanChatMessage] = None, *, + chat: Optional[YChat] = None, ellipsis: bool = True, ): """ Context manager that sends a pending message to the client, and closes it after the block is executed. + + TODO: Simplify it by only modifying the awareness as soon as collaborative chat + is the only used chat. """ - pending_msg = self.start_pending(text, human_msg=human_msg, ellipsis=ellipsis) + pending_msg = self.start_pending(text, human_msg=human_msg, chat=chat, ellipsis=ellipsis) try: yield pending_msg finally: if not pending_msg.closed: - self.close_pending(pending_msg) + self.close_pending(pending_msg, chat=chat) def get_llm_chain(self): lm_provider = self.config_manager.lm_provider @@ -409,14 +442,14 @@ def create_llm_chain( ): raise NotImplementedError("Should be implemented by subclasses") - def parse_args(self, message, silent=False): + def parse_args(self, message, chat, silent=False): args = message.body.split(" ") try: args = self.parser.parse_args(args[1:]) except (argparse.ArgumentError, SystemExit) as e: if not silent: response = f"{self.parser.format_usage()}" - self.reply(response, message) + self.reply(response, chat, message) return None return args @@ -439,7 +472,9 @@ def output_dir(self) -> str: else: return self.root_dir - def send_help_message(self, human_msg: Optional[HumanChatMessage] = None) -> None: + def send_help_message( + self, chat: Optional[YChat], human_msg: Optional[HumanChatMessage] = None + ) -> None: """Sends a help message to all connected clients.""" lm_provider = self.config_manager.lm_provider unsupported_slash_commands = ( @@ -470,38 +505,46 @@ def send_help_message(self, human_msg: Optional[HumanChatMessage] = None) -> Non slash_commands_list=slash_commands_list, context_commands_list=context_commands_list, ) - help_message = AgentChatMessage( - id=uuid4().hex, - time=time.time(), - body=help_message_body, - reply_to=human_msg.id if human_msg else "", - persona=self.persona, - ) - self.broadcast_message(help_message) + if chat is not None: + self.write_message(chat, help_message_body, None) + else: + help_message = AgentChatMessage( + id=uuid4().hex, + time=time.time(), + body=help_message_body, + reply_to=human_msg.id if human_msg else "", + persona=self.persona, + ) + self.broadcast_message(help_message) - def _start_stream(self, human_msg: HumanChatMessage) -> str: + def _start_stream(self, human_msg: HumanChatMessage, chat: Optional[YChat]) -> str | None: """ Sends an `agent-stream` message to indicate the start of a response stream. Returns the ID of the message, denoted as the `stream_id`. 
""" - stream_id = uuid4().hex - stream_msg = AgentStreamMessage( - id=stream_id, - time=time.time(), - body="", - reply_to=human_msg.id, - persona=self.persona, - complete=False, - ) + if chat is not None: + stream_id = self.write_message(chat, "", None) + else: + stream_id = uuid4().hex + stream_msg = AgentStreamMessage( + id=stream_id, + time=time.time(), + body="", + reply_to=human_msg.id, + persona=self.persona, + complete=False, + ) + + self.broadcast_message(stream_msg) - self.broadcast_message(stream_msg) return stream_id def _send_stream_chunk( self, - stream_id: str, + stream_id: str | None, content: str, + chat: Optional[YChat], complete: bool = False, metadata: Optional[Dict[str, Any]] = None, ) -> None: @@ -509,18 +552,22 @@ def _send_stream_chunk( Sends an `agent-stream-chunk` message containing content that should be appended to an existing `agent-stream` message with ID `stream_id`. """ - if not metadata: - metadata = {} + if chat is not None: + self.write_message(chat, content, stream_id) + else: + if not metadata: + metadata = {} - stream_chunk_msg = AgentStreamChunkMessage( - id=stream_id, content=content, stream_complete=complete, metadata=metadata - ) - self.broadcast_message(stream_chunk_msg) + stream_chunk_msg = AgentStreamChunkMessage( + id=stream_id, content=content, stream_complete=complete, metadata=metadata + ) + self.broadcast_message(stream_chunk_msg) async def stream_reply( self, input: Input, human_msg: HumanChatMessage, + chat: Optional[YChat], pending_msg="Generating response", config: Optional[RunnableConfig] = None, ): @@ -555,7 +602,7 @@ async def stream_reply( merged_config: RunnableConfig = merge_runnable_configs(base_config, config) # start with a pending message - with self.pending(pending_msg, human_msg) as pending_message: + with self.pending(pending_msg, human_msg, chat=chat) as pending_message: # stream response in chunks. this works even if a provider does not # implement streaming, as `astream()` defaults to yielding `_call()` # when `_stream()` is not implemented on the LLM class. @@ -565,8 +612,8 @@ async def stream_reply( if not received_first_chunk: # when receiving the first chunk, close the pending message and # start the stream. 
- self.close_pending(pending_message) - stream_id = self._start_stream(human_msg=human_msg) + self.close_pending(pending_message, chat=chat) + stream_id = self._start_stream(human_msg=human_msg, chat=chat) received_first_chunk = True self.message_interrupted[stream_id] = asyncio.Event() @@ -589,9 +636,9 @@ async def stream_reply( break if isinstance(chunk, AIMessageChunk) and isinstance(chunk.content, str): - self._send_stream_chunk(stream_id, chunk.content) + self._send_stream_chunk(stream_id, chunk.content, chat=chat) elif isinstance(chunk, str): - self._send_stream_chunk(stream_id, chunk) + self._send_stream_chunk(stream_id, chunk, chat=chat) else: self.log.error(f"Unrecognized type of chunk yielded: {type(chunk)}") break @@ -603,6 +650,7 @@ async def stream_reply( self._send_stream_chunk( stream_id, stream_tombstone, + chat=chat, complete=True, metadata=metadata_handler.jai_metadata, ) diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/clear.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/clear.py index d5b0ab6c7..1d378009d 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/clear.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/clear.py @@ -1,5 +1,12 @@ +from typing import Optional + from jupyter_ai.models import ClearRequest +try: + from jupyterlab_collaborative_chat.ychat import YChat +except: + from typing import Any as YChat + from .base import BaseChatHandler, SlashCommandRoutingType @@ -16,7 +23,7 @@ class ClearChatHandler(BaseChatHandler): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - async def process_message(self, _): + async def process_message(self, _, chat: Optional[YChat]): # Clear chat by triggering `RootChatHandler.on_clear_request()`. for handler in self._root_chat_handlers.values(): if not handler: diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py index 266ad73ad..0e2765adb 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py @@ -1,11 +1,16 @@ import asyncio -from typing import Dict, Type +from typing import Dict, Optional, Type from jupyter_ai.models import HumanChatMessage from jupyter_ai_magics.providers import BaseProvider from langchain_core.runnables import ConfigurableFieldSpec from langchain_core.runnables.history import RunnableWithMessageHistory +try: + from jupyterlab_collaborative_chat.ychat import YChat +except: + from typing import Any as YChat + from ..context_providers import ContextProviderException, find_commands from .base import BaseChatHandler, SlashCommandRoutingType @@ -53,7 +58,7 @@ def create_llm_chain( ) self.llm_chain = runnable - async def process_message(self, message: HumanChatMessage): + async def process_message(self, message: HumanChatMessage, chat: Optional[YChat]): self.get_llm_chain() assert self.llm_chain @@ -68,7 +73,7 @@ async def process_message(self, message: HumanChatMessage): inputs["context"] = context_prompt inputs["input"] = self.replace_prompt(inputs["input"]) - await self.stream_reply(inputs, message) + await self.stream_reply(inputs, message, chat=chat) async def make_context_prompt(self, human_msg: HumanChatMessage) -> str: return "\n\n".join( diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/export.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/export.py index ed478f57e..d335a2b21 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/export.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/export.py 
@@ -1,10 +1,15 @@ import argparse import os from datetime import datetime -from typing import List +from typing import List, Optional from jupyter_ai.models import AgentChatMessage, HumanChatMessage +try: + from jupyterlab_collaborative_chat.ychat import YChat +except: + from typing import Any as YChat + from .base import BaseChatHandler, SlashCommandRoutingType @@ -31,11 +36,11 @@ def chat_message_to_markdown(self, message): return "" # Write the chat history to a markdown file with a timestamp - async def process_message(self, message: HumanChatMessage): + async def process_message(self, message: HumanChatMessage, chat: Optional[YChat]): markdown_content = "\n\n".join( self.chat_message_to_markdown(msg) for msg in self._chat_history ) - args = self.parse_args(message) + args = self.parse_args(message, chat) chat_filename = ( # if no filename, use "chat_history" + timestamp args.path[0] if (args.path and args.path[0] != "") @@ -46,4 +51,4 @@ async def process_message(self, message: HumanChatMessage): ) # Do not use timestamp if filename is entered as argument with open(chat_file, "w") as chat_history: chat_history.write(markdown_content) - self.reply(f"File saved to `{chat_file}`") + self.reply(f"File saved to `{chat_file}`", chat) diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/fix.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/fix.py index 390b93cf6..8e0b765a4 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/fix.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/fix.py @@ -1,9 +1,14 @@ -from typing import Dict, Type +from typing import Dict, Optional, Type from jupyter_ai.models import CellWithErrorSelection, HumanChatMessage from jupyter_ai_magics.providers import BaseProvider from langchain.prompts import PromptTemplate +try: + from jupyterlab_collaborative_chat.ychat import YChat +except: + from typing import Any as YChat + from .base import BaseChatHandler, SlashCommandRoutingType FIX_STRING_TEMPLATE = """ @@ -79,10 +84,11 @@ def create_llm_chain( runnable = prompt_template | llm # type:ignore self.llm_chain = runnable - async def process_message(self, message: HumanChatMessage): + async def process_message(self, message: HumanChatMessage, chat: Optional[YChat]): if not (message.selection and message.selection.type == "cell-with-error"): self.reply( "`/fix` requires an active code cell with error output. 
Please click on a cell with error output and retry.", + chat, message, ) return @@ -103,4 +109,4 @@ async def process_message(self, message: HumanChatMessage): "error_name": selection.error.name, "error_value": selection.error.value, } - await self.stream_reply(inputs, message, pending_msg="Analyzing error") + await self.stream_reply(inputs, message, pending_msg="Analyzing error", chat=chat) diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py index a69b5ed28..36273c78b 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py @@ -16,6 +16,11 @@ from langchain.schema.output_parser import BaseOutputParser from langchain_core.prompts import PromptTemplate +try: + from jupyterlab_collaborative_chat.ychat import YChat +except: + from typing import Any as YChat + class OutlineSection(BaseModel): title: str @@ -262,18 +267,20 @@ async def _generate_notebook(self, prompt: str): nbformat.write(notebook, final_path) return final_path - async def process_message(self, message: HumanChatMessage): + async def process_message(self, message: HumanChatMessage, chat: Optional[YChat]): self.get_llm_chain() # first send a verification message to user response = "👍 Great, I will get started on your notebook. It may take a few minutes, but I will reply here when the notebook is ready. In the meantime, you can continue to ask me other questions." - self.reply(response, message) + self.reply(response, chat, message) final_path = await self._generate_notebook(prompt=message.body) response = f"""🎉 I have created your notebook and saved it to the location {final_path}. I am still learning how to create notebooks, so please review all code before running it.""" - self.reply(response, message) + self.reply(response, chat, message) - async def handle_exc(self, e: Exception, message: HumanChatMessage): + async def handle_exc( + self, e: Exception, message: HumanChatMessage, chat: Optional[YChat] + ): timestamp = time.strftime("%Y-%m-%d-%H.%M.%S") default_log_dir = Path(self.output_dir) / "jupyter-ai-logs" log_dir = self.log_dir or default_log_dir @@ -283,4 +290,4 @@ async def handle_exc(self, e: Exception, message: HumanChatMessage): traceback.print_exc(file=log) response = f"An error occurred while generating the notebook. The error details have been saved to `./{log_path}`.\n\nTry running `/generate` again, as some language models require multiple attempts before a notebook is generated." 
- self.reply(response, message) + self.reply(response, chat, message) diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/help.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/help.py index cd8556863..82b6e8607 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/help.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/help.py @@ -1,5 +1,12 @@ +from typing import Optional + from jupyter_ai.models import HumanChatMessage +try: + from jupyterlab_collaborative_chat.ychat import YChat +except: + from typing import Any as YChat + from .base import BaseChatHandler, SlashCommandRoutingType @@ -15,5 +22,5 @@ class HelpChatHandler(BaseChatHandler): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - async def process_message(self, message: HumanChatMessage): - self.send_help_message(message) + async def process_message(self, message: HumanChatMessage, chat: Optional[YChat]): + self.send_help_message(chat, message) diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/learn.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/learn.py index e0c6139c0..a74c147de 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/learn.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/learn.py @@ -30,6 +30,11 @@ ) from langchain_community.vectorstores import FAISS +try: + from jupyterlab_collaborative_chat.ychat import YChat +except: + from typing import Any as YChat + from .base import BaseChatHandler, SlashCommandRoutingType INDEX_SAVE_DIR = os.path.join(jupyter_data_dir(), "jupyter_ai", "indices") @@ -128,26 +133,29 @@ def _load(self): ) self.log.error(e) - async def process_message(self, message: HumanChatMessage): + async def process_message(self, message: HumanChatMessage, chat: Optional[YChat]): # If no embedding provider has been selected em_provider_cls, em_provider_args = self.get_embedding_provider() if not em_provider_cls: self.reply( - "Sorry, please select an embedding provider before using the `/learn` command." + "Sorry, please select an embedding provider before using the `/learn` command.", + chat, ) return - args = self.parse_args(message) + args = self.parse_args(message, chat) if args is None: return if args.delete: self.delete() - self.reply(f"👍 I have deleted everything I previously learned.", message) + self.reply( + f"👍 I have deleted everything I previously learned.", chat, message + ) return if args.list: - self.reply(self._build_list_response()) + self.reply(self._build_list_response(), chat) return if args.remote: @@ -158,19 +166,23 @@ async def process_message(self, message: HumanChatMessage): args.path = [arxiv_to_text(id, self.output_dir)] self.reply( f"Learning arxiv file with id **{id}**, saved in **{args.path[0]}**.", + chat, message, ) except ModuleNotFoundError as e: self.log.error(e) self.reply( - "No `arxiv` package found. " "Install with `pip install arxiv`." + "No `arxiv` package found. " + "Install with `pip install arxiv`.", + chat, ) return except Exception as e: self.log.error(e) self.reply( "An error occurred while processing the arXiv file. " - f"Please verify that the arxiv id {id} is correct." 
+ f"Please verify that the arxiv id {id} is correct.", + chat, ) return @@ -186,7 +198,7 @@ async def process_message(self, message: HumanChatMessage): "- Learn on files in the root directory: `/learn *`\n" "- Learn all python files under the root directory recursively: `/learn **/*.py`" ) - self.reply(f"{self.parser.format_usage()}\n\n {no_path_arg_message}") + self.reply(f"{self.parser.format_usage()}\n\n {no_path_arg_message}", chat) return short_path = args.path[0] load_path = os.path.join(self.output_dir, short_path) @@ -196,13 +208,13 @@ async def process_message(self, message: HumanChatMessage): next(iglob(load_path)) except StopIteration: response = f"Sorry, that path doesn't exist: {load_path}" - self.reply(response, message) + self.reply(response, chat, message) return # delete and relearn index if embedding model was changed - await self.delete_and_relearn() + await self.delete_and_relearn(chat) - with self.pending(f"Loading and splitting files for {load_path}", message): + with self.pending(f"Loading and splitting files for {load_path}", message, chat=chat): try: await self.learn_dir( load_path, args.chunk_size, args.chunk_overlap, args.all_files @@ -218,7 +230,7 @@ async def process_message(self, message: HumanChatMessage): You can ask questions about these docs by prefixing your message with **/ask**.""" % ( load_path.replace("*", r"\*") ) - self.reply(response, message) + self.reply(response, chat, message) def _build_list_response(self): if not self.metadata.dirs: @@ -272,7 +284,7 @@ def _add_dir_to_metadata(self, path: str, chunk_size: int, chunk_overlap: int): ) self.metadata.dirs = dirs - async def delete_and_relearn(self): + async def delete_and_relearn(self, chat: Optional[YChat]=None): """Delete the vector store and relearn all indexed directories if necessary. If the embedding model is unchanged, this method does nothing.""" @@ -299,11 +311,11 @@ async def delete_and_relearn(self): documents you had previously submitted for learning. Please wait to use the **/ask** command until I am done with this task.""" - self.reply(message) + self.reply(message, chat) metadata = self.metadata self.delete() - await self.relearn(metadata) + await self.relearn(metadata, chat) self.prev_em_id = curr_em_id def delete(self): @@ -317,7 +329,7 @@ def delete(self): if os.path.isfile(path): os.remove(path) - async def relearn(self, metadata: IndexMetadata): + async def relearn(self, metadata: IndexMetadata, chat: Optional[YChat]): # Index all dirs in the metadata if not metadata.dirs: return @@ -337,7 +349,7 @@ async def relearn(self, metadata: IndexMetadata): message = f"""🎉 I am done learning docs in these directories: {dir_list} I am ready to answer questions about them. 
You can ask me about these documents by starting your message with **/ask**.""" - self.reply(message) + self.reply(message, chat) def create( self, diff --git a/packages/jupyter-ai/jupyter_ai/extension.py b/packages/jupyter-ai/jupyter_ai/extension.py index 08c8c5a47..9ecdf1cda 100644 --- a/packages/jupyter-ai/jupyter_ai/extension.py +++ b/packages/jupyter-ai/jupyter_ai/extension.py @@ -2,13 +2,23 @@ import re import time import types +from typing import Optional +import uuid +from functools import partial from dask.distributed import Client as DaskClient from importlib_metadata import entry_points from jupyter_ai.chat_handlers.learn import Retriever +from jupyter_ai.models import HumanChatMessage from jupyter_ai_magics import BaseProvider, JupyternautPersona from jupyter_ai_magics.utils import get_em_providers, get_lm_providers +from jupyter_collaboration import __version__ as jupyter_collaboration_version +from jupyter_collaboration.utils import JUPYTER_COLLABORATION_EVENTS_URI +from jupyter_events import EventLogger from jupyter_server.extension.application import ExtensionApp +from jupyter_server.utils import url_path_join +from jupyterlab_collaborative_chat.ychat import YChat +from pycrdt import ArrayEvent from tornado.web import StaticFileHandler from traitlets import Dict, Integer, List, Unicode @@ -43,6 +53,20 @@ ) +if int(jupyter_collaboration_version[0]) >= 3: + COLLAB_VERSION = 3 +else: + COLLAB_VERSION = 2 + +# The BOT currently has a fixed username, because this username is used has key in chats, +# it needs to constant. Do we need to change it ? +BOT = { + "username": '5f6a7570-7974-6572-6e61-75742d626f74', + "name": "Jupyternaut", + "display_name": "Jupyternaut", + "initials": "J" +} + DEFAULT_HELP_MESSAGE_TEMPLATE = """Hi there! I'm {persona_name}, your programming assistant. You can ask me a question using the text box below. You can also use these commands: {slash_commands_list} @@ -204,6 +228,122 @@ class AiExtension(ExtensionApp): config=True, ) + def initialize(self): + super().initialize() + self.event_logger = self.serverapp.web_app.settings["event_logger"] + self.event_logger.add_listener( + schema_id=JUPYTER_COLLABORATION_EVENTS_URI, listener=self.connect_chat + ) + + # Keep the message indexes to avoid extra computation looking for a message when + # updating it. + self.messages_indexes = {} + + async def connect_chat( + self, logger: EventLogger, schema_id: str, data: dict + ) -> None: + if ( + data["room"].startswith("text:chat:") + and data["action"] == "initialize" + and data["msg"] == "Room initialized" + ): + + self.log.info(f"Collaborative chat server is listening for {data['room']}") + chat = await self.get_chat(data["room"]) + + # Add the bot user to the chat document awareness. 
+ BOT["avatar_url"] = url_path_join( + self.settings.get("base_url", "/"), "api/ai/static/jupyternaut.svg" + ) + chat.awareness.set_local_state_field("user", BOT) + + callback = partial(self.on_change, chat) + chat.ymessages.observe(callback) + + async def get_chat(self, room_id: str) -> YChat: + if COLLAB_VERSION == 3: + collaboration = self.serverapp.web_app.settings["jupyter_server_ydoc"] + document = await collaboration.get_document(room_id=room_id, copy=False) + else: + collaboration = self.serverapp.web_app.settings["jupyter_collaboration"] + server = collaboration.ywebsocket_server + + room = await server.get_room(room_id) + document = room._document + return document + + def on_change(self, chat: YChat, events: ArrayEvent) -> None: + for change in events.delta: + if not "insert" in change.keys(): + continue + messages = change["insert"] + for message in messages: + + if message["sender"] == BOT["username"] or message["raw_time"]: + continue + try: + chat_message = HumanChatMessage( + id=message["id"], + time=time.time(), + body=message["body"], + prompt="", + selection=None, + client=None, + ) + except Exception as e: + self.log.error(e) + self.serverapp.io_loop.asyncio_loop.create_task( + self._route(chat_message, chat) + ) + + async def _route(self, message: HumanChatMessage, chat: YChat): + """Method that routes an incoming message to the appropriate handler.""" + chat_handlers = self.settings["jai_chat_handlers"] + default = chat_handlers["default"] + # Split on any whitespace, either spaces or newlines + maybe_command = message.body.split(None, 1)[0] + is_command = ( + message.body.startswith("/") + and maybe_command in chat_handlers.keys() + and maybe_command != "default" + ) + command = maybe_command if is_command else "default" + + start = time.time() + if is_command: + await chat_handlers[command].on_message(message, chat) + else: + await default.on_message(message, chat) + + latency_ms = round((time.time() - start) * 1000) + command_readable = "Default" if command == "default" else command + self.log.info(f"{command_readable} chat handler resolved in {latency_ms} ms.") + + def write_message(self, chat: YChat, body: str, id: Optional[str]=None) -> str: + bot = chat.get_user(BOT["username"]) + if not bot: + chat.set_user(BOT) + + index = self.messages_indexes[id] if id else None + id = id if id else str(uuid.uuid4()) + new_index = chat.set_message( + { + "type": "msg", + "body": body, + "id": id if id else str(uuid.uuid4()), + "time": time.time(), + "sender": BOT["username"], + "raw_time": False, + }, + index, + True, + ) + + if new_index != index: + self.messages_indexes[id] = new_index + + return id + def initialize_settings(self): start = time.time() @@ -320,7 +460,7 @@ def _show_help_message(self): default_chat_handler: DefaultChatHandler = self.settings["jai_chat_handlers"][ "default" ] - default_chat_handler.send_help_message() + default_chat_handler.send_help_message(None) async def _get_dask_client(self): return DaskClient(processes=False, asynchronous=True) @@ -365,6 +505,7 @@ def _init_chat_handlers(self): "preferred_dir": self.serverapp.contents_manager.preferred_dir, "help_message_template": self.help_message_template, "chat_handlers": chat_handlers, + "write_message": self.write_message, "context_providers": self.settings["jai_context_providers"], "message_interrupted": self.settings["jai_message_interrupted"], } diff --git a/packages/jupyter-ai/jupyter_ai/models.py b/packages/jupyter-ai/jupyter_ai/models.py index 48dbe6193..8117a933e 100644 --- 
a/packages/jupyter-ai/jupyter_ai/models.py +++ b/packages/jupyter-ai/jupyter_ai/models.py @@ -140,7 +140,7 @@ class HumanChatMessage(BaseModel): """The prompt typed into the chat input by the user.""" selection: Optional[Selection] """The selection included with the prompt, if any.""" - client: ChatClient + client: Optional[ChatClient] class ClearMessage(BaseModel): diff --git a/packages/jupyter-ai/package.json b/packages/jupyter-ai/package.json index db71f52ae..ff56520df 100644 --- a/packages/jupyter-ai/package.json +++ b/packages/jupyter-ai/package.json @@ -61,6 +61,7 @@ "dependencies": { "@emotion/react": "^11.10.5", "@emotion/styled": "^11.10.5", + "@jupyter/chat": "^0.5.0", "@jupyter/collaboration": "^1", "@jupyterlab/application": "^4.2.0", "@jupyterlab/apputils": "^4.2.0", diff --git a/packages/jupyter-ai/pyproject.toml b/packages/jupyter-ai/pyproject.toml index 88f5fc55f..177043594 100644 --- a/packages/jupyter-ai/pyproject.toml +++ b/packages/jupyter-ai/pyproject.toml @@ -33,6 +33,7 @@ dependencies = [ "typing_extensions>=4.5.0", "traitlets>=5.0", "deepmerge>=2.0,<3", + "jupyterlab-collaborative-chat>=0.5.0", ] dynamic = ["version", "description", "authors", "urls", "keywords"] diff --git a/packages/jupyter-ai/schema/plugin.json b/packages/jupyter-ai/schema/plugin.json index 78804b5c6..37e0a4671 100644 --- a/packages/jupyter-ai/schema/plugin.json +++ b/packages/jupyter-ai/schema/plugin.json @@ -12,6 +12,27 @@ "preventDefault": false } ], + "jupyter.lab.menus": { + "main": [ + { + "id": "jp-mainmenu-settings", + "items": [ + { + "type": "separator", + "rank": 110 + }, + { + "command": "jupyter-ai:open-settings", + "rank": 110 + }, + { + "type": "separator", + "rank": 110 + } + ] + } + ] + }, "additionalProperties": false, "type": "object" } diff --git a/packages/jupyter-ai/src/components/chat-settings.tsx b/packages/jupyter-ai/src/components/chat-settings.tsx index a1ad0a9b6..b9e9d8bd1 100644 --- a/packages/jupyter-ai/src/components/chat-settings.tsx +++ b/packages/jupyter-ai/src/components/chat-settings.tsx @@ -34,6 +34,9 @@ type ChatSettingsProps = { rmRegistry: IRenderMimeRegistry; completionProvider: IJaiCompletionProvider | null; openInlineCompleterSettings: () => void; + // The temporary input options, should be removed when the collaborative chat is + // the only chat. + inputOptions?: boolean; }; /** @@ -511,36 +514,42 @@ export function ChatSettings(props: ChatSettingsProps): JSX.Element { onSuccess={server.refetchApiKeys} /> - {/* Input */} -

[Remainder of this hunk: the removed lines are the original "Input" settings section — an "Input" heading followed by a radio group labeled "When writing a message, press Enter to:" with the options "Send the message" and "Start a new line (use Shift+Enter to send)", wired to setSendWse. The added lines re-insert the same heading and radio group under the comment {/* Input - to remove when the collaborative chat is the only chat */}, wrapped in {(props.inputOptions ?? true) && ( <> ... </> )} so the section only renders when inputOptions is not explicitly false.]
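
Note on downstream handlers: for third-party modules that subclass BaseChatHandler (such as the cookiecutter template above), the visible impact of this change is the extra `chat` argument threaded through `process_message`, `reply`, `pending`, and `parse_args`. Below is a minimal sketch of an adapted custom handler, not part of this diff: the `/echo` command and its class attributes are hypothetical, and it uses `except ImportError` for the optional dependency rather than the bare `except` used in the patch.

from typing import Optional

from jupyter_ai.chat_handlers.base import BaseChatHandler, SlashCommandRoutingType
from jupyter_ai.models import HumanChatMessage

try:
    # Available only when jupyterlab-collaborative-chat is installed.
    from jupyterlab_collaborative_chat.ychat import YChat
except ImportError:
    from typing import Any as YChat


class EchoSlashCommand(BaseChatHandler):
    """Hypothetical /echo command illustrating the new signature."""

    id = "echo"
    name = "Echo"
    help = "Echo the message back to the chat"
    routing_type = SlashCommandRoutingType(slash_id="echo")
    uses_llm = False

    async def process_message(
        self, message: HumanChatMessage, chat: Optional[YChat]
    ):
        # `chat` is the collaborative YChat document when the message came from
        # the collaborative chat, and None when it came from the legacy chat UI;
        # `reply()` and `pending()` pick the right transport either way.
        with self.pending("Echoing", message, chat=chat):
            self.reply(f"You said: {message.body}", chat, message)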