Various chat chain enhancements and fixes #144

Merged: 5 commits, merged on May 5, 2023
17 changes: 17 additions & 0 deletions .readthedocs.yaml
@@ -0,0 +1,17 @@
+# .readthedocs.yaml
+# Read the Docs configuration file
+# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
+
+version: 2
+
+build:
+  os: ubuntu-22.04
+  tools:
+    python: "3.11"
+
+sphinx:
+  configuration: docs/source/conf.py
+
+python:
+  install:
+    - requirements: docs/requirements.txt
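For orientation, the `sphinx.configuration` key above tells Read the Docs to build with the Sphinx configuration at `docs/source/conf.py`. A minimal sketch of what such a file might look like follows; the project name, author, extensions, and theme are illustrative assumptions, not contents of this PR.

```python
# docs/source/conf.py -- hypothetical minimal Sphinx configuration.
# All values here are assumptions for illustration; the real file may differ.

project = "jupyter-ai"      # assumed project name
author = "Project Jupyter"  # assumed author

# Assumed extension list; the actual docs may enable others.
extensions = [
    "myst_parser",  # Markdown source support
]

html_theme = "alabaster"  # Sphinx's default theme, used here as a placeholder
```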
19 changes: 14 additions & 5 deletions packages/jupyter-ai/jupyter_ai/actors/base.py
@@ -16,7 +16,12 @@
 Logger = Union[logging.Logger, logging.LoggerAdapter]
 
 class ACTOR_TYPE(str, Enum):
+    # the top level actor that routes incoming messages to the appropriate actor
+    ROUTER = "router"
+
+    # the default actor that responds to messages using a language model
     DEFAULT = "default"
+
     ASK = "ask"
     LEARN = 'learn'
     MEMORY = 'memory'
@@ -74,14 +79,18 @@ def reply(self, response, message: Optional[HumanChatMessage] = None):
 
     def get_llm_chain(self):
         actor = ray.get_actor(ACTOR_TYPE.CHAT_PROVIDER)
-        llm = ray.get(actor.get_provider.remote())
-        llm_params = ray.get(actor.get_provider_params.remote())
+        lm_provider = ray.get(actor.get_provider.remote())
+        lm_provider_params = ray.get(actor.get_provider_params.remote())
 
+        curr_lm_id = f'{self.llm.id}:{lm_provider_params["model_id"]}' if self.llm else None
+        next_lm_id = f'{lm_provider.id}:{lm_provider_params["model_id"]}' if lm_provider else None
+
-        if not llm:
+        if not lm_provider:
             return None
 
-        if llm.__class__.__name__ != self.llm.__class__.__name__:
-            self.create_llm_chain(llm, llm_params)
+        if curr_lm_id != next_lm_id:
+            self.log.info(f"Switching chat language model from {curr_lm_id} to {next_lm_id}.")
+            self.create_llm_chain(lm_provider, lm_provider_params)
         return self.llm_chain
 
     def get_embeddings(self):
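The new composite `provider_id:model_id` comparison fixes a subtle gap: the old `__class__.__name__` check could not detect a switch between two models served by the same provider class. A standalone sketch of the difference, with illustrative IDs:

```python
# Standalone sketch, not code from this PR; the IDs are illustrative.
curr_lm_id = "openai-chat:gpt-3.5-turbo"  # model backing the current chain
next_lm_id = "openai-chat:gpt-4"          # model now selected in settings

# Old check: both models come from the same provider class, so the class
# names match and the chain was (incorrectly) left unchanged.
old_check_rebuilds = curr_lm_id.split(":")[0] != next_lm_id.split(":")[0]

# New check: the composite IDs differ, so create_llm_chain() is called.
new_check_rebuilds = curr_lm_id != next_lm_id

print(old_check_rebuilds, new_check_rebuilds)  # False True
```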
59 changes: 45 additions & 14 deletions packages/jupyter-ai/jupyter_ai/actors/default.py
@@ -1,44 +1,75 @@
-from typing import Dict, Type
+from typing import Dict, Type, List
 import ray
 from ray.util.queue import Queue
 
 from langchain import ConversationChain
 from langchain.prompts import (
     ChatPromptTemplate,
     MessagesPlaceholder,
-    SystemMessagePromptTemplate,
-    HumanMessagePromptTemplate
+    HumanMessagePromptTemplate,
+    SystemMessagePromptTemplate
 )
+from langchain.schema import (
+    AIMessage,
+)
 
-from jupyter_ai.actors.base import BaseActor, Logger, ACTOR_TYPE
+from jupyter_ai.actors.base import BaseActor, ACTOR_TYPE
 from jupyter_ai.actors.memory import RemoteMemory
-from jupyter_ai.models import HumanChatMessage
+from jupyter_ai.models import HumanChatMessage, ClearMessage, ChatMessage
 from jupyter_ai_magics.providers import BaseProvider
 
-SYSTEM_PROMPT = "The following is a friendly conversation between a human and an AI, whose name is Jupyter AI. The AI is talkative and provides lots of specific details from its context. If the AI does not know the answer to a question, it truthfully says it does not know."
+SYSTEM_PROMPT = """
+You are Jupyter AI, a conversational assistant living in JupyterLab to help users.
+You are not a language model, but rather an application built on a foundation model from {provider_name} called {local_model_id}.
+You are talkative and provide lots of specific details from your context.
+You may use Markdown to format your response.
+Code blocks must be formatted in Markdown.
+Math should be rendered with inline TeX markup, surrounded by $.
+If you do not know the answer to a question, answer truthfully by responding that you do not know.
+The following is a friendly conversation between you and a human.
+""".strip()
 
 @ray.remote
 class DefaultActor(BaseActor):
-    def __init__(self, reply_queue: Queue, log: Logger):
-        super().__init__(reply_queue=reply_queue, log=log)
+    def __init__(self, chat_history: List[ChatMessage], *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.memory = None
+        self.chat_history = chat_history
 
     def create_llm_chain(self, provider: Type[BaseProvider], provider_params: Dict[str, str]):
         llm = provider(**provider_params)
-        memory = RemoteMemory(actor_name=ACTOR_TYPE.MEMORY)
+        self.memory = RemoteMemory(actor_name=ACTOR_TYPE.MEMORY)
         prompt_template = ChatPromptTemplate.from_messages([
-            SystemMessagePromptTemplate.from_template(SYSTEM_PROMPT),
+            SystemMessagePromptTemplate.from_template(SYSTEM_PROMPT).format(provider_name=llm.name, local_model_id=llm.model_id),
             MessagesPlaceholder(variable_name="history"),
-            HumanMessagePromptTemplate.from_template("{input}")
+            HumanMessagePromptTemplate.from_template("{input}"),
+            AIMessage(content="")
         ])
         self.llm = llm
         self.llm_chain = ConversationChain(
             llm=llm,
             prompt=prompt_template,
             verbose=True,
-            memory=memory
+            memory=self.memory
         )
 
+    def clear_memory(self):
+        if not self.memory:
+            return
+
+        # clear chain memory
+        self.memory.clear()
+
+        # clear transcript for existing chat clients
+        reply_message = ClearMessage()
+        self.reply_queue.put(reply_message)
+
+        # clear transcript for new chat clients
+        self.chat_history.clear()
+
     def _process_message(self, message: HumanChatMessage):
         self.get_llm_chain()
-        response = self.llm_chain.predict(input=message.body)
+        response = self.llm_chain.predict(
+            input=message.body,
+            stop=["\nHuman:"]
+        )
         self.reply(response, message)
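Because the rewritten SYSTEM_PROMPT carries `{provider_name}` and `{local_model_id}` placeholders, it is rendered into a concrete system message once, when the chain is (re)built, rather than on every turn. A hedged sketch of that rendering step, using the pre-1.0 LangChain API imported above and illustrative provider values:

```python
# Sketch only; the provider values below are illustrative, not from this PR.
from langchain.prompts import SystemMessagePromptTemplate

template = SystemMessagePromptTemplate.from_template(
    "You are an application built on a foundation model from "
    "{provider_name} called {local_model_id}."
)

# .format() returns a concrete SystemMessage with the placeholders filled in,
# mirroring what create_llm_chain() does with llm.name and llm.model_id.
system_message = template.format(
    provider_name="AI21",                # e.g. llm.name
    local_model_id="j2-jumbo-instruct",  # e.g. llm.model_id
)
print(system_message.content)
```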
5 changes: 2 additions & 3 deletions packages/jupyter-ai/jupyter_ai/actors/router.py
@@ -2,7 +2,6 @@
 from ray.util.queue import Queue
 
 from jupyter_ai.actors.base import ACTOR_TYPE, COMMANDS, Logger, BaseActor
-from jupyter_ai.models import ClearMessage
 
 @ray.remote
 class Router(BaseActor):
@@ -25,7 +24,7 @@ def _process_message(self, message):
             actor = ray.get_actor(COMMANDS[command].value)
             actor.process_message.remote(message)
             if command == '/clear':
-                reply_message = ClearMessage()
-                self.reply_queue.put(reply_message)
+                actor = ray.get_actor(ACTOR_TYPE.DEFAULT)
+                actor.clear_memory.remote()
         else:
             default.process_message.remote(message)
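After this change the router no longer emits a `ClearMessage` itself; it delegates to `DefaultActor.clear_memory()`, which clears the chain memory, notifies connected clients, and empties the stored transcript in one place. A rough sketch of the dispatch, assuming the default actor is registered under the name used in extension.py below:

```python
# Sketch of the /clear dispatch; assumes an actor named "default" exists
# (ACTOR_TYPE.DEFAULT), as registered in extension.py below.
import ray

def dispatch_clear():
    actor = ray.get_actor("default")
    actor.clear_memory.remote()  # fire-and-forget; the actor does the rest
```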
30 changes: 17 additions & 13 deletions packages/jupyter-ai/jupyter_ai/extension.py
@@ -121,31 +121,35 @@ def initialize_settings(self):
         self.settings["chat_handlers"] = {}
 
         # store chat messages in memory for now
+        # this is only used to render the UI, and is not the conversational
+        # memory object used by the LM chain.
         self.settings["chat_history"] = []
 
+
         reply_queue = Queue()
         self.settings["reply_queue"] = reply_queue
 
-        router = Router.options(name="router").remote(
+        router = Router.options(name=ACTOR_TYPE.ROUTER).remote(
             reply_queue=reply_queue,
-            log=self.log
+            log=self.log,
         )
+        default_actor = DefaultActor.options(name=ACTOR_TYPE.DEFAULT.value).remote(
+            reply_queue=reply_queue,
+            log=self.log,
+            chat_history=self.settings["chat_history"]
+        )
 
         providers_actor = ProvidersActor.options(name=ACTOR_TYPE.PROVIDERS.value).remote(
-            log=self.log
+            log=self.log,
         )
         config_actor = ConfigActor.options(name=ACTOR_TYPE.CONFIG.value).remote(
-            log=self.log
+            log=self.log,
         )
         chat_provider_actor = ChatProviderActor.options(name=ACTOR_TYPE.CHAT_PROVIDER.value).remote(
-            log=self.log
+            log=self.log,
         )
         embeddings_provider_actor = EmbeddingsProviderActor.options(name=ACTOR_TYPE.EMBEDDINGS_PROVIDER.value).remote(
-            log=self.log
-        )
-        default_actor = DefaultActor.options(name=ACTOR_TYPE.DEFAULT.value).remote(
-            reply_queue=reply_queue,
-            log=self.log
+            log=self.log,
         )
         learn_actor = LearnActor.options(name=ACTOR_TYPE.LEARN.value).remote(
             reply_queue=reply_queue,
@@ -154,16 +158,16 @@
         )
         ask_actor = AskActor.options(name=ACTOR_TYPE.ASK.value).remote(
             reply_queue=reply_queue,
-            log=self.log
+            log=self.log,
         )
         memory_actor = MemoryActor.options(name=ACTOR_TYPE.MEMORY.value).remote(
             log=self.log,
-            memory=ConversationBufferWindowMemory(return_messages=True, k=2)
+            memory=ConversationBufferWindowMemory(return_messages=True, k=2),
        )
         generate_actor = GenerateActor.options(name=ACTOR_TYPE.GENERATE.value).remote(
             reply_queue=reply_queue,
             log=self.log,
-            root_dir=self.settings['server_root_dir']
+            root_dir=self.settings['server_root_dir'],
         )
 
         self.settings['router'] = router
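All of the actors registered above follow the same named-actor pattern: create once with `.options(name=...).remote(...)`, then fetch the handle anywhere in the process with `ray.get_actor(name)`. A self-contained sketch of that pattern; the actor class and name are hypothetical:

```python
# Self-contained sketch of Ray's named-actor pattern; names are hypothetical.
import ray
from ray.util.queue import Queue

ray.init(ignore_reinit_error=True)

@ray.remote
class EchoActor:
    def __init__(self, reply_queue: Queue):
        self.reply_queue = reply_queue

    def process_message(self, message: str):
        self.reply_queue.put(f"echo: {message}")

reply_queue = Queue()
# Register under a well-known name, as initialize_settings() does above.
echo_actor = EchoActor.options(name="echo").remote(reply_queue)

# Elsewhere in the process, look the same actor up by name.
handle = ray.get_actor("echo")
handle.process_message.remote("hello")
print(reply_queue.get())  # -> "echo: hello"
```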