Merge pull request #165 from ks6088ts-labs/feature/issue-164_slm
add SLM based chat app
ks6088ts authored Oct 8, 2024
2 parents 54bd03d + 4779767 commit dbd2247
Showing 7 changed files with 2,200 additions and 1,653 deletions.
25 changes: 25 additions & 0 deletions apps/15_streamlit_chat_slm/README.md
@@ -0,0 +1,25 @@
# Streamlit Chat with SLM

## Overview

This app chats with a small language model (phi3) served locally by Ollama, using LangChain's `ChatOllama` integration.

```shell
# Run Ollama server
$ ollama serve

# Pull the phi3 model
$ ollama pull phi3

# Run a simple chat with Ollama
$ poetry run python apps/15_streamlit_chat_slm/chat.py

# Run summarization with SLM
$ poetry run python apps/15_streamlit_chat_slm/summarize.py

# Run streamlit app
$ poetry run python -m streamlit run apps/15_streamlit_chat_slm/main.py
```

## References

- [ChatOllama](https://python.langchain.com/docs/integrations/chat/ollama/)
- [Summarize Text](https://python.langchain.com/docs/tutorials/summarization/)
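
Before running the apps above, it can be worth confirming that the Ollama server is reachable and that the `phi3` model has been pulled. Below is a minimal sketch (not part of this commit) using the same `ChatOllama` client the apps rely on, assuming Ollama is serving on its default local port:

```python
# Sanity check: one round-trip through the locally served phi3 model.
from langchain_ollama import ChatOllama

llm = ChatOllama(model="phi3", temperature=0)
print(llm.invoke("Reply with the single word: ready").content)
```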
43 changes: 43 additions & 0 deletions apps/15_streamlit_chat_slm/chat.py
@@ -0,0 +1,43 @@
import argparse
import logging

from dotenv import load_dotenv
from langchain_core.messages import AIMessage
from langchain_ollama import ChatOllama


def init_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        prog="slm_chat",
        description="Chat with SLM model",
    )
    parser.add_argument("-m", "--model", default="phi3")
    parser.add_argument("-s", "--system", default="You are a helpful assistant.")
    parser.add_argument("-p", "--prompt", default="What is the capital of France?")
    parser.add_argument("-v", "--verbose", action="store_true")
    return parser.parse_args()


if __name__ == "__main__":
    args = init_args()

    # Set verbose mode
    if args.verbose:
        logging.basicConfig(level=logging.DEBUG)

    # Parse .env file and set environment variables
    load_dotenv()

    llm = ChatOllama(
        model=args.model,
        temperature=0,
    )

    ai_msg: AIMessage = llm.invoke(
        input=[
            ("system", args.system),
            ("human", args.prompt),
        ]
    )
    print(ai_msg.model_dump_json(indent=2))
    # print(ai_msg.content)
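
`chat.py` dumps the whole `AIMessage` as JSON. To pull out just the reply text or the metadata Ollama reports, something along these lines would work (a sketch, not part of the committed file; `usage_metadata` availability depends on the installed `langchain-core` version):

```python
# Sketch: read individual fields from the AIMessage returned above.
print(ai_msg.content)            # generated text only
print(ai_msg.response_metadata)  # model name, timing info, etc. reported by Ollama
print(ai_msg.usage_metadata)     # token counts, if populated by this langchain-core version
```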
67 changes: 67 additions & 0 deletions apps/15_streamlit_chat_slm/main.py
@@ -0,0 +1,67 @@
import streamlit as st
from dotenv import load_dotenv
from langchain_ollama import ChatOllama

load_dotenv()

SUPPORTED_MODELS = [
    "phi3",
]
with st.sidebar:
    slm_model = st.selectbox(
        label="Model",
        options=SUPPORTED_MODELS,
        index=0,
    )
    "[Azure Portal](https://portal.azure.com/)"
    "[Azure OpenAI Studio](https://oai.azure.com/resource/overview)"
    "[View the source code](https://github.com/ks6088ts-labs/workshop-azure-openai/blob/main/apps/15_streamlit_chat_slm/main.py)"


def is_configured():
    return slm_model in SUPPORTED_MODELS


st.title("15_streamlit_chat_slm")

if not is_configured():
    st.warning("Please fill in the required fields at the sidebar.")

if "messages" not in st.session_state:
    st.session_state["messages"] = [
        {
            "role": "assistant",
            "content": "Hello! I'm a helpful assistant.",
        }
    ]

# Show chat messages
for msg in st.session_state.messages:
    st.chat_message(msg["role"]).write(msg["content"])

# Receive user input
if prompt := st.chat_input(disabled=not is_configured()):
    client = ChatOllama(
        model=slm_model,
        temperature=0,
    )

    st.session_state.messages.append(
        {
            "role": "user",
            "content": prompt,
        }
    )
    st.chat_message("user").write(prompt)
    with st.spinner("Thinking..."):
        response = client.invoke(
            input=st.session_state.messages,
        )
        msg = response.content
    st.session_state.messages.append(
        {
            "role": "assistant",
            "content": msg,
        }
    )
    st.chat_message("assistant").write(msg)
162 changes: 162 additions & 0 deletions apps/15_streamlit_chat_slm/summarize.py
@@ -0,0 +1,162 @@
import asyncio
import operator
from os import getenv
from typing import Annotated, Literal, TypedDict

from langchain.chains.combine_documents.reduce import acollapse_docs, split_list_of_docs
from langchain_community.document_loaders import WebBaseLoader
from langchain_core.documents import Document
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_ollama import ChatOllama
from langchain_openai import AzureChatOpenAI
from langchain_text_splitters import CharacterTextSplitter
from langgraph.constants import Send
from langgraph.graph import END, START, StateGraph

token_max = 1000
url = "https://lilianweng.github.io/posts/2023-06-23-agent/"

llm_ollama = ChatOllama(
    model="phi3",
    temperature=0,
)
llm_azure_openai = AzureChatOpenAI(
    temperature=0,
    api_key=getenv("AZURE_OPENAI_API_KEY"),
    api_version=getenv("AZURE_OPENAI_API_VERSION"),
    azure_endpoint=getenv("AZURE_OPENAI_ENDPOINT"),
    model=getenv("AZURE_OPENAI_GPT_MODEL"),
)
# Use the Ollama model
llm = llm_ollama


def length_function(documents: list[Document]) -> int:
    """Get number of tokens for input contents."""
    return sum(llm.get_num_tokens(doc.page_content) for doc in documents)


# This will be the overall state of the main graph.
# It will contain the input document contents, corresponding
# summaries, and a final summary.
class OverallState(TypedDict):
    # Notice here we use operator.add
    # This is because we want to combine all the summaries we generate
    # from individual nodes back into one list - this is essentially
    # the "reduce" part
    contents: list[str]
    summaries: Annotated[list, operator.add]
    collapsed_summaries: list[Document]
    final_summary: str


# This will be the state of the node that we will "map" all
# documents to in order to generate summaries
class SummaryState(TypedDict):
    content: str


map_prompt = ChatPromptTemplate.from_messages([("system", "Write a concise summary of the following:\\n\\n{context}")])

map_chain = map_prompt | llm | StrOutputParser()


# Here we generate a summary, given a document
async def generate_summary(state: SummaryState):
    response = await map_chain.ainvoke(state["content"])
    return {"summaries": [response]}


# Here we define the logic to map out over the documents
# We will use this as an edge in the graph
def map_summaries(state: OverallState):
    # We will return a list of `Send` objects
    # Each `Send` object consists of the name of a node in the graph
    # as well as the state to send to that node
    return [Send("generate_summary", {"content": content}) for content in state["contents"]]


def collect_summaries(state: OverallState):
    return {"collapsed_summaries": [Document(summary) for summary in state["summaries"]]}


# Also available via the hub: `hub.pull("rlm/reduce-prompt")`
reduce_template = """
The following is a set of summaries:
{docs}
Take these and distill it into a final, consolidated summary
of the main themes.
"""

reduce_prompt = ChatPromptTemplate([("human", reduce_template)])

reduce_chain = reduce_prompt | llm | StrOutputParser()


# Add node to collapse summaries
async def collapse_summaries(state: OverallState):
    doc_lists = split_list_of_docs(state["collapsed_summaries"], length_function, token_max)
    results = []
    for doc_list in doc_lists:
        results.append(await acollapse_docs(doc_list, reduce_chain.ainvoke))

    return {"collapsed_summaries": results}


# This represents a conditional edge in the graph that determines
# if we should collapse the summaries or not
def should_collapse(
    state: OverallState,
) -> Literal["collapse_summaries", "generate_final_summary"]:
    num_tokens = length_function(state["collapsed_summaries"])
    if num_tokens > token_max:
        return "collapse_summaries"
    else:
        return "generate_final_summary"


# Here we will generate the final summary
async def generate_final_summary(state: OverallState):
    response = await reduce_chain.ainvoke(state["collapsed_summaries"])
    return {"final_summary": response}


async def main():
    # Construct the graph
    # Nodes:
    graph = StateGraph(OverallState)
    graph.add_node("generate_summary", generate_summary)  # same as before
    graph.add_node("collect_summaries", collect_summaries)
    graph.add_node("collapse_summaries", collapse_summaries)
    graph.add_node("generate_final_summary", generate_final_summary)

    # Edges:
    graph.add_conditional_edges(START, map_summaries, ["generate_summary"])
    graph.add_edge("generate_summary", "collect_summaries")
    graph.add_conditional_edges("collect_summaries", should_collapse)
    graph.add_conditional_edges("collapse_summaries", should_collapse)
    graph.add_edge("generate_final_summary", END)

    app = graph.compile()

    # create graph image
    app.get_graph().draw_mermaid_png(output_file_path="docs/images/15_streamlit_chat_slm.summarize_graph.png")

    text_splitter = CharacterTextSplitter.from_tiktoken_encoder(chunk_size=1000, chunk_overlap=0)

    loader = WebBaseLoader(web_path=url)
    docs = loader.load()

    split_docs = text_splitter.split_documents(docs)
    print(f"Generated {len(split_docs)} documents.")

    async for step in app.astream(
        {"contents": [doc.page_content for doc in split_docs]},
        {"recursion_limit": 10},
    ):
        print(list(step.keys()))
        print(step)


asyncio.run(main())
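
If only the end result is needed, the streaming loop inside `main()` could be replaced by a single `ainvoke` call, which returns the final `OverallState`. A sketch (not part of the committed file):

```python
# Sketch: run the graph to completion (inside main()) and read the final
# summary from the returned state instead of printing every streamed step.
result = await app.ainvoke(
    {"contents": [doc.page_content for doc in split_docs]},
    {"recursion_limit": 10},
)
print(result["final_summary"])
```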
(The remaining changed files in this commit could not be displayed in the diff view.)
