From 1cec78964fc9dfe43e95e700b5ac91716ed445ff Mon Sep 17 00:00:00 2001 From: Joshua Carroll Date: Sun, 1 Sep 2024 18:08:54 -0700 Subject: [PATCH] Add linting and format with ruff (#18) * Install ruff and run linting that way * Run ruff format * Fix ruff check * Fix broken test * Add pre-commit hooks * Add pre-commit config (lol) --- .dockerignore | 2 +- .github/workflows/test.yml | 13 ++---- .pre-commit-config.yaml | 13 ++++++ LICENSE | 2 +- README.md | 12 +++-- agent/__init__.py | 4 +- agent/langgraph.json | 2 +- agent/llama_guard.py | 30 +++++++----- agent/research_assistant.py | 25 ++++++---- agent/tools.py | 7 ++- client/client.py | 71 +++++++++++++++++------------ media/agent_architecture.excalidraw | 2 +- pyproject.toml | 52 +++++++++++++++++++++ requirements.txt | 1 - run_client.py | 4 +- schema/__init__.py | 4 +- schema/schema.py | 32 +++++++++---- schema/test_schema.py | 4 ++ service/__init__.py | 4 +- service/service.py | 19 ++++++-- service/test_service.py | 14 ++++-- streamlit_app.py | 52 ++++++++++++--------- test-requirements.txt | 2 - 23 files changed, 249 insertions(+), 122 deletions(-) create mode 100644 .pre-commit-config.yaml create mode 100644 pyproject.toml delete mode 100644 test-requirements.txt diff --git a/.dockerignore b/.dockerignore index 2810362..bc28ffb 100644 --- a/.dockerignore +++ b/.dockerignore @@ -9,4 +9,4 @@ __pycache__ env venv .venv -*.db \ No newline at end of file +*.db diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index b6a7a01..c7d1a27 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -29,17 +29,14 @@ jobs: curl -LsSf https://astral.sh/uv/0.3.2/install.sh | sh - name: Install dependencies with uv run: | - uv pip install flake8 - uv pip install -r requirements.txt - uv pip install -r test-requirements.txt + uv pip install -r pyproject.toml --extra dev env: UV_SYSTEM_PYTHON: 1 - - name: Lint with flake8 + - name: Lint and format with ruff run: | - # stop the build if there are Python syntax errors or undefined names - flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics - # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide - flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics + ruff format --check + ruff check --output-format github + - name: Test with pytest run: | pytest diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..f8b1786 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,13 @@ +repos: +- repo: https://github.com/pre-commit/pre-commit-hooks + rev: v2.3.0 + hooks: + - id: check-yaml + - id: end-of-file-fixer + - id: trailing-whitespace +- repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.6.3 + hooks: + - id: ruff + args: [ --fix ] + - id: ruff-format diff --git a/LICENSE b/LICENSE index 63413f8..34dd27b 100644 --- a/LICENSE +++ b/LICENSE @@ -18,4 +18,4 @@ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. \ No newline at end of file +SOFTWARE. 
diff --git a/README.md b/README.md index 1a4d29a..57b487d 100644 --- a/README.md +++ b/README.md @@ -11,7 +11,7 @@ This project offers a template for you to easily build and run your own agents u ### [Try the app!](https://agent-service-toolkit.streamlit.app/) [![Streamlit App](https://static.streamlit.io/badges/streamlit_badge_black_white.svg)](https://agent-service-toolkit.streamlit.app/) - + @@ -89,7 +89,7 @@ With that said, there are several other interesting projects in this space that # Optional, to enable simple header-based auth on the service AUTH_SECRET=any_string_you_choose - + # Optional, to enable LangSmith tracing LANGCHAIN_TRACING_V2=true LANGCHAIN_ENDPOINT=https://api.smith.langchain.com @@ -153,15 +153,17 @@ The agent supports [LangGraph Studio](https://github.com/langchain-ai/langgraph- You can simply install LangGraph Studio, add your `.env` file to the root directory as described above, and then launch LangGraph studio pointed at the `agent/` directory. Customize `agent/langgraph.json` as needed. -### Running Tests +### Contributing Currently the tests need to be run using the local development without Docker setup. To run the tests for the agent service: 1. Ensure you're in the project root directory and have activated your virtual environment. -2. Install the test dependencies: +2. Install the development dependencies and pre-commit hooks: ``` - pip install -r test-requirements.txt + pip install uv + uv pip install -r pyproject.toml --extra dev + pre-commit install ``` 3. Run the tests using pytest: diff --git a/agent/__init__.py b/agent/__init__.py index c6bc481..ba17529 100644 --- a/agent/__init__.py +++ b/agent/__init__.py @@ -1,5 +1,3 @@ from agent.research_assistant import research_assistant -__all__ = [ - "research_assistant" -] +__all__ = ["research_assistant"] diff --git a/agent/langgraph.json b/agent/langgraph.json index 1c72fd8..adfdd06 100644 --- a/agent/langgraph.json +++ b/agent/langgraph.json @@ -5,4 +5,4 @@ "research_assistant": "./research_assistant.py:research_assistant" }, "env": "../.env" -} \ No newline at end of file +} diff --git a/agent/llama_guard.py b/agent/llama_guard.py index 3b6442c..3a0b440 100644 --- a/agent/llama_guard.py +++ b/agent/llama_guard.py @@ -14,7 +14,9 @@ class SafetyAssessment(Enum): class LlamaGuardOutput(BaseModel): safety_assessment: SafetyAssessment = Field(description="The safety assessment of the content.") - unsafe_categories: List[str] = Field(description="If content is unsafe, the list of unsafe categories.", default=[]) + unsafe_categories: List[str] = Field( + description="If content is unsafe, the list of unsafe categories.", default=[] + ) unsafe_content_categories = { @@ -31,7 +33,7 @@ class LlamaGuardOutput(BaseModel): "S11": "Self-Harm.", "S12": "Sexual Content.", "S13": "Elections.", - "S14": "Code Interpreter Abuse." 
+ "S14": "Code Interpreter Abuse.", } categories_str = "\n".join([f"{k}: {v}" for k, v in unsafe_content_categories.items()]) @@ -70,9 +72,7 @@ def parse_llama_guard_output(output: str) -> LlamaGuardOutput: return LlamaGuardOutput(safety_assessment=SafetyAssessment.ERROR) try: categories = parsed_output[1].split(",") - readable_categories = [ - unsafe_content_categories[c.strip()].strip(".") for c in categories - ] + readable_categories = [unsafe_content_categories[c.strip()].strip(".") for c in categories] return LlamaGuardOutput( safety_assessment=SafetyAssessment.UNSAFE, unsafe_categories=readable_categories, @@ -83,9 +83,13 @@ def parse_llama_guard_output(output: str) -> LlamaGuardOutput: async def llama_guard(role: str, messages: List[AnyMessage]) -> LlamaGuardOutput: role_mapping = {"ai": "Agent", "human": "User"} - messages_str = [f"{role_mapping[m.type]}: {m.content}" for m in messages if m.type in ["ai", "human"]] + messages_str = [ + f"{role_mapping[m.type]}: {m.content}" for m in messages if m.type in ["ai", "human"] + ] conversation_history = "\n\n".join(messages_str) - compiled_prompt = llama_guard_prompt.format(role=role, conversation_history=conversation_history) + compiled_prompt = llama_guard_prompt.format( + role=role, conversation_history=conversation_history + ) result = await model.ainvoke([SystemMessage(content=compiled_prompt)]) return parse_llama_guard_output(result.content) @@ -94,9 +98,13 @@ async def llama_guard(role: str, messages: List[AnyMessage]) -> LlamaGuardOutput import asyncio async def main(): - output = await llama_guard("Agent", [ - HumanMessage(content="Tell me a fun fact?"), - AIMessage(content="Did you know that honey never spoils?"), - ]) + output = await llama_guard( + "Agent", + [ + HumanMessage(content="Tell me a fun fact?"), + AIMessage(content="Did you know that honey never spoils?"), + ], + ) print(output) + asyncio.run(main()) diff --git a/agent/research_assistant.py b/agent/research_assistant.py index 2f96f44..8eea701 100644 --- a/agent/research_assistant.py +++ b/agent/research_assistant.py @@ -9,7 +9,7 @@ from langgraph.managed import IsLastStep from langgraph.prebuilt import ToolNode -from agent.tools import arxiv_search, calculator, web_search +from agent.tools import calculator, web_search from agent.llama_guard import llama_guard, LlamaGuardOutput @@ -17,11 +17,12 @@ class AgentState(MessagesState): safety: LlamaGuardOutput is_last_step: IsLastStep + # NOTE: models with streaming=True will send tokens as they are generated # if the /stream endpoint is called with stream_tokens=True (the default) models = { "gpt-4o-mini": ChatOpenAI(model="gpt-4o-mini", temperature=0.5, streaming=True), - "llama-3.1-70b": ChatGroq(model="llama-3.1-70b-versatile", temperature=0.5) + "llama-3.1-70b": ChatGroq(model="llama-3.1-70b-versatile", temperature=0.5), } tools = [web_search, calculator] @@ -29,7 +30,7 @@ class AgentState(MessagesState): instructions = f""" You are a helpful research assistant with the ability to search the web for information. Today's date is {current_date}. - + NOTE: THE USER CAN'T SEE THE TOOL RESPONSE. A few things to remember: @@ -39,6 +40,7 @@ class AgentState(MessagesState): so for the final response, use human readable format - e.g. "300 * 200", not "(300 \\times 200)". 
""" + def wrap_model(model: BaseChatModel): model = model.bind_tools(tools) preprocessor = RunnableLambda( @@ -47,6 +49,7 @@ def wrap_model(model: BaseChatModel): ) return preprocessor | model + async def acall_model(state: AgentState, config: RunnableConfig): m = models[config["configurable"].get("model", "gpt-4o-mini")] model_runnable = wrap_model(m) @@ -68,6 +71,7 @@ async def llama_guard_input(state: AgentState, config: RunnableConfig): safety_output = await llama_guard("User", state["messages"]) return {"safety": safety_output} + async def block_unsafe_content(state: AgentState, config: RunnableConfig): safety: LlamaGuardOutput = state["safety"] output_messages = [] @@ -77,10 +81,13 @@ async def block_unsafe_content(state: AgentState, config: RunnableConfig): if last_message.type == "ai": output_messages.append(RemoveMessage(id=last_message.id)) - content_warning = f"This conversation was flagged for unsafe content: {', '.join(safety.unsafe_categories)}" + content_warning = ( + f"This conversation was flagged for unsafe content: {', '.join(safety.unsafe_categories)}" + ) output_messages.append(AIMessage(content=content_warning)) return {"messages": output_messages} + # Define the graph agent = StateGraph(AgentState) agent.add_node("model", acall_model) @@ -105,11 +112,12 @@ async def block_unsafe_content(state: AgentState, config: RunnableConfig): # ) # Always END after blocking unsafe content -#agent.add_edge("block_unsafe_content", END) +# agent.add_edge("block_unsafe_content", END) # Always run "model" after "tools" agent.add_edge("tools", "model") + # After "model", if there are tool calls, run "tools". Otherwise END. def pending_tool_calls(state: AgentState): last_message = state["messages"][-1] @@ -117,6 +125,8 @@ def pending_tool_calls(state: AgentState): return "tools" else: return END + + agent.add_conditional_edges("model", pending_tool_calls, {"tools": "tools", END: END}) research_assistant = agent.compile( @@ -130,7 +140,7 @@ def pending_tool_calls(state: AgentState): from dotenv import load_dotenv load_dotenv() - + async def main(): inputs = {"messages": [("user", "Find me a recipe for chocolate chip cookies")]} result = await research_assistant.ainvoke( @@ -145,8 +155,7 @@ async def main(): # export CFLAGS="-I $(brew --prefix graphviz)/include" # export LDFLAGS="-L $(brew --prefix graphviz)/lib" # pip install pygraphviz - # + # # researcH_assistant.get_graph().draw_png("agent_diagram.png") - asyncio.run(main()) diff --git a/agent/tools.py b/agent/tools.py index e7ffa78..fd25233 100644 --- a/agent/tools.py +++ b/agent/tools.py @@ -2,16 +2,14 @@ import numexpr import re from langchain_core.tools import tool, BaseTool -from langchain_community.tools import DuckDuckGoSearchResults, ArxivQueryRun +from langchain_community.tools import DuckDuckGoSearchResults web_search = DuckDuckGoSearchResults(name="WebSearch") -# Kinda busted since it doesn't return links -arxiv_search = ArxivQueryRun(name="ArxivSearch") def calculator_func(expression: str) -> str: """Calculates a math expression using numexpr. - + Useful for when you need to answer questions about math using numexpr. This tool is only for math questions and nothing else. Only input math expressions. 
@@ -39,5 +37,6 @@ def calculator_func(expression: str) -> str: " Please try again with a valid numerical expression" ) + calculator: BaseTool = tool(calculator_func) calculator.name = "Calculator" diff --git a/client/client.py b/client/client.py index 460292d..3ee798b 100644 --- a/client/client.py +++ b/client/client.py @@ -5,6 +5,7 @@ import requests from schema import ChatMessage, UserInput, StreamInput, Feedback + class AgentClient: """Client for interacting with the agent service.""" @@ -25,7 +26,9 @@ def _headers(self): headers["Authorization"] = f"Bearer {self.auth_secret}" return headers - async def ainvoke(self, message: str, model: str|None = None, thread_id: str|None = None) -> ChatMessage: + async def ainvoke( + self, message: str, model: str | None = None, thread_id: str | None = None + ) -> ChatMessage: """ Invoke the agent asynchronously. Only the final message is returned. @@ -43,14 +46,18 @@ async def ainvoke(self, message: str, model: str|None = None, thread_id: str|Non request.thread_id = thread_id if model: request.model = model - async with session.post(f"{self.base_url}/invoke", json=request.dict(), headers=self._headers) as response: + async with session.post( + f"{self.base_url}/invoke", json=request.dict(), headers=self._headers + ) as response: if response.status == 200: result = await response.json() return ChatMessage.parse_obj(result) else: raise Exception(f"Error: {response.status} - {await response.text()}") - def invoke(self, message: str, model: str|None = None, thread_id: str|None = None) -> ChatMessage: + def invoke( + self, message: str, model: str | None = None, thread_id: str | None = None + ) -> ChatMessage: """ Invoke the agent synchronously. Only the final message is returned. @@ -67,14 +74,16 @@ def invoke(self, message: str, model: str|None = None, thread_id: str|None = Non request.thread_id = thread_id if model: request.model = model - response = requests.post(f"{self.base_url}/invoke", json=request.dict(), headers=self._headers) + response = requests.post( + f"{self.base_url}/invoke", json=request.dict(), headers=self._headers + ) if response.status_code == 200: return ChatMessage.parse_obj(response.json()) else: raise Exception(f"Error: {response.status_code} - {response.text}") def _parse_stream_line(self, line: str) -> ChatMessage | str | None: - line = line.decode('utf-8').strip() + line = line.decode("utf-8").strip() if line.startswith("data: "): data = line[6:] if data == "[DONE]": @@ -97,15 +106,15 @@ def _parse_stream_line(self, line: str) -> ChatMessage | str | None: raise Exception(parsed["content"]) def stream( - self, - message: str, - model: str|None = None, - thread_id: str|None = None, - stream_tokens: bool = True - ) -> Generator[ChatMessage | str, None, None]: + self, + message: str, + model: str | None = None, + thread_id: str | None = None, + stream_tokens: bool = True, + ) -> Generator[ChatMessage | str, None, None]: """ Stream the agent's response synchronously. - + Each intermediate message of the agent process is yielded as a ChatMessage. If stream_tokens is True (the default value), the response will also yield content tokens from streaming models as they are generated. 
@@ -125,10 +134,12 @@ def stream( request.thread_id = thread_id if model: request.model = model - response = requests.post(f"{self.base_url}/stream", json=request.dict(), headers=self._headers, stream=True) + response = requests.post( + f"{self.base_url}/stream", json=request.dict(), headers=self._headers, stream=True + ) if response.status_code != 200: raise Exception(f"Error: {response.status_code} - {response.text}") - + for line in response.iter_lines(): if line: parsed = self._parse_stream_line(line) @@ -137,15 +148,15 @@ def stream( yield parsed async def astream( - self, - message: str, - model: str|None = None, - thread_id: str|None = None, - stream_tokens: bool = True - ) -> AsyncGenerator[ChatMessage | str, None]: + self, + message: str, + model: str | None = None, + thread_id: str | None = None, + stream_tokens: bool = True, + ) -> AsyncGenerator[ChatMessage | str, None]: """ Stream the agent's response asynchronously. - + Each intermediate message of the agent process is yielded as an AnyMessage. If stream_tokens is True (the default value), the response will also yield content tokens from streaming modelsas they are generated. @@ -166,24 +177,22 @@ async def astream( request.thread_id = thread_id if model: request.model = model - async with session.post(f"{self.base_url}/stream", json=request.dict(), headers=self._headers) as response: + async with session.post( + f"{self.base_url}/stream", json=request.dict(), headers=self._headers + ) as response: if response.status != 200: raise Exception(f"Error: {response.status} - {await response.text()}") # Parse incoming events with the SSE protocol async for line in response.content: - if line.decode('utf-8').strip(): + if line.decode("utf-8").strip(): parsed = self._parse_stream_line(line) if parsed is None: break yield parsed async def acreate_feedback( - self, - run_id: str, - key: str, - score: float, - kwargs: Dict[str, Any] = {} - ): + self, run_id: str, key: str, score: float, kwargs: Dict[str, Any] = {} + ): """ Create a feedback record for a run. 
@@ -193,7 +202,9 @@ async def acreate_feedback( """ async with aiohttp.ClientSession() as session: request = Feedback(run_id=run_id, key=key, score=score, kwargs=kwargs) - async with session.post(f"{self.base_url}/feedback", json=request.dict(), headers=self._headers) as response: + async with session.post( + f"{self.base_url}/feedback", json=request.dict(), headers=self._headers + ) as response: if response.status != 200: raise Exception(f"Error: {response.status} - {await response.text()}") await response.json() diff --git a/media/agent_architecture.excalidraw b/media/agent_architecture.excalidraw index c41c8db..ff41ec9 100644 --- a/media/agent_architecture.excalidraw +++ b/media/agent_architecture.excalidraw @@ -1030,4 +1030,4 @@ "lastRetrieved": 1722877915754 } } -} \ No newline at end of file +} diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..4024b31 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,52 @@ +[project] +name = "agent-service-toolkit" +version = "0.1.0" +description = "Full toolkit for running an AI agent service built with LangGraph, FastAPI and Streamlit" +readme = "README.md" +authors = [ + {name = "Joshua Carroll", email = "carroll.joshk@gmail.com"}, +] +classifiers = [ + "Development Status :: 4 - Beta", + "License :: OSI Approved :: MIT License", + "Framework :: FastAPI", + "Programming Language :: Python :: 3.12", +] + +requires-python = ">=3.9, <= 3.12.3" + +# NOTE: FastAPI < 0.100.0 and Pydantic v1 is required until langchain has full pydantic v2 compatibility +# https://python.langchain.com/v0.1/docs/guides/development/pydantic_compatibility/ +# https://github.com/langchain-ai/langchain/discussions/24405 +# https://github.com/langchain-ai/langchain/discussions/9337 +# IMPORTANT: This also requires using python < 3.12.4 +dependencies = [ + "aiohttp ~=3.10.0", + "duckduckgo-search ~=6.2.6", + "fastapi <0.100.0", + "langchain-core ~=0.2.26", + "langchain-community ~=0.2.11", + "langchain-openai ~=0.1.20", + "langchain-groq ~=0.1.9", + "langgraph ~=0.2.3", + "langgraph-checkpoint ~=1.0.2", + "langgraph-checkpoint-sqlite ~=1.0.0", + "langsmith ~=0.1.96", + "numexpr ~=2.10.1", + "pydantic ~=1.10.17", + "python-dotenv ~=1.0.1", + "streamlit ~=1.37.0", + "uvicorn ~=0.30.5", +] + +[project.optional-dependencies] +dev = [ + "httpx~=0.26.0", + "pre-commit", + "pytest", + "ruff", +] + +[tool.ruff] +line-length = 100 +target-version = "py39" diff --git a/requirements.txt b/requirements.txt index d6fb1ff..0423bb4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,7 +5,6 @@ # IMPORTANT: This also requires using python < 3.12.4 aiohttp~=3.10.0 -arxiv~=2.1.3 duckduckgo-search~=6.2.6 fastapi<0.100.0 langchain-core~=0.2.26 diff --git a/run_client.py b/run_client.py index 809c66b..a949e92 100644 --- a/run_client.py +++ b/run_client.py @@ -1,9 +1,10 @@ - from client import AgentClient from schema import ChatMessage #### ASYNC #### import asyncio + + async def amain(): client = AgentClient() @@ -20,6 +21,7 @@ async def amain(): else: print(f"ERROR: Unknown type - {type(message)}") + asyncio.run(amain()) #### SYNC #### diff --git a/schema/__init__.py b/schema/__init__.py index 7b03335..2bd457b 100644 --- a/schema/__init__.py +++ b/schema/__init__.py @@ -1,3 +1,3 @@ -from schema.schema import * +from schema.schema import UserInput, AgentResponse, ChatMessage, StreamInput, Feedback -__all__ = ["UserInput", "AgentResponse", "StreamInput", "Feedback"] +__all__ = ["UserInput", "AgentResponse", "ChatMessage", "StreamInput", "Feedback"] diff --git 
a/schema/schema.py b/schema/schema.py index 321ae30..e066025 100644 --- a/schema/schema.py +++ b/schema/schema.py @@ -1,14 +1,19 @@ from typing import Dict, Any, List, Literal from langchain_core.messages import ( - BaseMessage, HumanMessage, AIMessage, - ToolMessage, ToolCall, - message_to_dict, messages_from_dict, + BaseMessage, + HumanMessage, + AIMessage, + ToolMessage, + ToolCall, + message_to_dict, + messages_from_dict, ) from pydantic import BaseModel, Field class UserInput(BaseModel): """Basic user input for the agent.""" + message: str = Field( description="User input to the agent.", examples=["What is the weather in Tokyo?"], @@ -27,6 +32,7 @@ class UserInput(BaseModel): class StreamInput(UserInput): """User input for streaming the agent's response.""" + stream_tokens: bool = Field( description="Whether to stream LLM tokens to the client.", default=True, @@ -35,18 +41,23 @@ class StreamInput(UserInput): class AgentResponse(BaseModel): """Response from the agent when called via /invoke.""" + message: Dict[str, Any] = Field( description="Final response from the agent, as a serialized LangChain message.", - examples=[{'message': - {'type': 'ai', 'data': - {'content': 'The weather in Tokyo is 70 degrees.', 'type': 'ai'} - } - }], + examples=[ + { + "message": { + "type": "ai", + "data": {"content": "The weather in Tokyo is 70 degrees.", "type": "ai"}, + } + } + ], ) class ChatMessage(BaseModel): """Message in a chat.""" + type: Literal["human", "ai", "tool"] = Field( description="Role of the message.", examples=["human", "ai", "tool"], @@ -97,7 +108,7 @@ def from_langchain(cls, message: BaseMessage) -> "ChatMessage": return tool_message case _: raise ValueError(f"Unsupported message type: {message.__class__.__name__}") - + def to_langchain(self) -> BaseMessage: """Convert the ChatMessage to a LangChain message.""" if self.original: @@ -116,6 +127,7 @@ def pretty_print(self) -> None: class Feedback(BaseModel): """Feedback for a run, to record to LangSmith.""" + run_id: str = Field( description="Run ID to record feedback for.", examples=["847c6285-8fc9-4560-a83f-4e6285809254"], @@ -131,5 +143,5 @@ class Feedback(BaseModel): kwargs: Dict[str, Any] = Field( description="Additional feedback kwargs, passed to LangSmith.", default={}, - examples=[{'comment': 'In-line human feedback'}], + examples=[{"comment": "In-line human feedback"}], ) diff --git a/schema/test_schema.py b/schema/test_schema.py index 37c4a5e..f0b5c91 100644 --- a/schema/test_schema.py +++ b/schema/test_schema.py @@ -1,6 +1,7 @@ from langchain_core.messages import HumanMessage, AIMessage, ToolMessage, SystemMessage, ToolCall from schema import ChatMessage + def test_messages_to_langchain(): human_message = ChatMessage(type="human", content="Hello, world!") lc_message = human_message.to_langchain() @@ -8,6 +9,7 @@ def test_messages_to_langchain(): assert lc_message.type == "human" assert lc_message.content == "Hello, world!" 
+ def test_messages_from_langchain(): lc_human_message = HumanMessage(content="Hello, world!") human_message = ChatMessage.from_langchain(lc_human_message) @@ -34,6 +36,7 @@ def test_messages_from_langchain(): except ValueError as e: assert str(e) == "Unsupported message type: SystemMessage" + def test_message_run_id_usage(): run_id = "847c6285-8fc9-4560-a83f-4e6285809254" lc_message = AIMessage(content="Hello, world!") @@ -41,6 +44,7 @@ def test_message_run_id_usage(): ai_message.run_id = run_id assert ai_message.run_id == run_id + def test_messages_tool_calls(): tool_call = ToolCall(name="test_tool", args={"x": 1, "y": 2}, id="call_Jja7") lc_ai_message = AIMessage(content="", tool_calls=[tool_call]) diff --git a/service/__init__.py b/service/__init__.py index 00c98b1..bac9e18 100644 --- a/service/__init__.py +++ b/service/__init__.py @@ -1,5 +1,3 @@ from service.service import app -__all__ = [ - "app" -] \ No newline at end of file +__all__ = ["app"] diff --git a/service/service.py b/service/service.py index b8df69f..5ff0102 100644 --- a/service/service.py +++ b/service/service.py @@ -18,6 +18,7 @@ class TokenQueueStreamingHandler(AsyncCallbackHandler): """LangChain callback handler for streaming LLM tokens to an asyncio queue.""" + def __init__(self, queue: asyncio.Queue): self.queue = queue @@ -35,18 +36,21 @@ async def lifespan(app: FastAPI): yield # context manager will clean up the AsyncSqliteSaver on exit + app = FastAPI(lifespan=lifespan) + @app.middleware("http") async def check_auth_header(request: Request, call_next): if auth_secret := os.getenv("AUTH_SECRET"): - auth_header = request.headers.get('Authorization') + auth_header = request.headers.get("Authorization") if not auth_header or not auth_header.startswith("Bearer "): return Response(status_code=401, content="Missing or invalid token") if auth_header[7:] != auth_secret: return Response(status_code=401, content="Invalid token") return await call_next(request) + def _parse_input(user_input: UserInput) -> Tuple[Dict[str, Any], str]: run_id = uuid4() thread_id = user_input.thread_id or str(uuid4()) @@ -60,11 +64,12 @@ def _parse_input(user_input: UserInput) -> Tuple[Dict[str, Any], str]: ) return kwargs, run_id + @app.post("/invoke") async def invoke(user_input: UserInput) -> ChatMessage: """ Invoke the agent with user input to retrieve a final response. - + Use thread_id to persist and continue a multi-turn conversation. run_id kwarg is also attached to messages for recording feedback. """ @@ -78,6 +83,7 @@ async def invoke(user_input: UserInput) -> ChatMessage: except Exception as e: raise HTTPException(status_code=500, detail=str(e)) + async def message_generator(user_input: StreamInput) -> AsyncGenerator[str, None]: """ Generate a stream of messages from the agent. @@ -92,13 +98,14 @@ async def message_generator(user_input: StreamInput) -> AsyncGenerator[str, None output_queue = asyncio.Queue(maxsize=10) if user_input.stream_tokens: kwargs["config"]["callbacks"] = [TokenQueueStreamingHandler(queue=output_queue)] - + # Pass the agent's stream of messages to the queue in a separate task, so # we can yield the messages to the client in the main thread. async def run_agent_stream(): async for s in agent.astream(**kwargs, stream_mode="updates"): await output_queue.put(s) await output_queue.put(None) + stream_task = asyncio.create_task(run_agent_stream()) # Process the queue and yield messages over the SSE stream. 
@@ -124,20 +131,22 @@ async def run_agent_stream(): if chat_message.type == "human" and chat_message.content == user_input.message: continue yield f"data: {json.dumps({'type': 'message', 'content': chat_message.dict()})}\n\n" - + await stream_task yield "data: [DONE]\n\n" + @app.post("/stream") async def stream_agent(user_input: StreamInput): """ Stream the agent's response to a user input, including intermediate messages and tokens. - + Use thread_id to persist and continue a multi-turn conversation. run_id kwarg is also attached to all messages for recording feedback. """ return StreamingResponse(message_generator(user_input), media_type="text/event-stream") + @app.post("/feedback") async def feedback(feedback: Feedback): """ diff --git a/service/test_service.py b/service/test_service.py index 06d9fb5..700cf9e 100644 --- a/service/test_service.py +++ b/service/test_service.py @@ -7,30 +7,36 @@ client = TestClient(app) + @patch("service.service.research_assistant") def test_invoke(mock_agent): QUESTION = "What is the weather in Tokyo?" ANSWER = "The weather in Tokyo is 70 degrees." agent_response = {"messages": [AIMessage(content=ANSWER)]} mock_agent.ainvoke = AsyncMock(return_value=agent_response) - + with client as c: response = c.post("/invoke", json={"message": QUESTION}) assert response.status_code == 200 - + mock_agent.ainvoke.assert_awaited_once() input_message = mock_agent.ainvoke.await_args.kwargs["input"]["messages"][0] assert input_message.content == QUESTION - + output = ChatMessage.parse_obj(response.json()) assert output.type == "ai" assert output.content == ANSWER + @patch("service.service.LangsmithClient") def test_feedback(mock_client): ls_instance = mock_client.return_value ls_instance.create_feedback.return_value = None - body = {"run_id": "847c6285-8fc9-4560-a83f-4e6285809254", "key": "human-feedback-stars", "score": 0.8} + body = { + "run_id": "847c6285-8fc9-4560-a83f-4e6285809254", + "key": "human-feedback-stars", + "score": 0.8, + } response = client.post("/feedback", json=body) assert response.status_code == 200 assert response.json() == {"status": "success"} diff --git a/streamlit_app.py b/streamlit_app.py index dee5801..64be91d 100644 --- a/streamlit_app.py +++ b/streamlit_app.py @@ -22,6 +22,7 @@ APP_TITLE = "Agent Service Toolkit" APP_ICON = "🧰" + @st.cache_resource def get_agent_client(): agent_url = os.getenv("AGENT_URL", "http://localhost") @@ -65,21 +66,29 @@ async def main(): m = st.radio("LLM to use", options=models.keys()) model = models[m] use_streaming = st.toggle("Stream results", value=True) - + @st.dialog("Architecture") def architecture_dialog(): - st.image("https://github.com/JoshuaC215/agent-service-toolkit/blob/main/media/agent_architecture.png?raw=true") + st.image( + "https://github.com/JoshuaC215/agent-service-toolkit/blob/main/media/agent_architecture.png?raw=true" + ) "[View full size on Github](https://github.com/JoshuaC215/agent-service-toolkit/blob/main/media/agent_architecture.png)" - st.caption("App hosted on [Streamlit Cloud](https://share.streamlit.io/) with FastAPI service running in [Azure](https://learn.microsoft.com/en-us/azure/app-service/)") + st.caption( + "App hosted on [Streamlit Cloud](https://share.streamlit.io/) with FastAPI service running in [Azure](https://learn.microsoft.com/en-us/azure/app-service/)" + ) if st.button(":material/schema: Architecture", use_container_width=True): architecture_dialog() with st.popover(":material/policy: Privacy", use_container_width=True): - st.write("Prompts, responses and feedback in 
this app are anonymously recorded and saved to LangSmith for product evaluation and improvement purposes only.") + st.write( + "Prompts, responses and feedback in this app are anonymously recorded and saved to LangSmith for product evaluation and improvement purposes only." + ) "[View the source code](https://github.com/JoshuaC215/agent-service-toolkit)" - st.caption("Made with :material/favorite: by [Joshua](https://www.linkedin.com/in/joshua-k-carroll/) in Oakland") + st.caption( + "Made with :material/favorite: by [Joshua](https://www.linkedin.com/in/joshua-k-carroll/) in Oakland" + ) # Draw existing messages if "messages" not in st.session_state: @@ -93,7 +102,9 @@ def architecture_dialog(): # draw_messages() expects an async iterator over messages async def amessage_iter(): - for m in messages: yield m + for m in messages: + yield m + await draw_messages(amessage_iter()) # Generate new message if the user provided new input @@ -116,7 +127,7 @@ async def amessage_iter(): ) messages.append(response) st.chat_message("ai").write(response.content) - st.rerun() # Clear stale containers + st.rerun() # Clear stale containers # If messages have been generated, show feedback widget if len(messages) > 0: @@ -125,9 +136,9 @@ async def amessage_iter(): async def draw_messages( - messages_agen: AsyncGenerator[ChatMessage | str, None], - is_new=False, - ): + messages_agen: AsyncGenerator[ChatMessage | str, None], + is_new=False, +): """ Draws a set of chat messages - either replaying existing messages or streaming new ones. @@ -136,7 +147,7 @@ async def draw_messages( - Use a placeholder container to render streaming tokens as they arrive. - Use a status container to render tool calls. Track the tool inputs and outputs and update the status container accordingly. - + The function also needs to track the last message container in session state since later messages can draw to the same container. This is also used for drawing the feedback widget in the latest chat message. @@ -166,7 +177,7 @@ async def draw_messages( st.session_state.last_message = st.chat_message("ai") with st.session_state.last_message: streaming_placeholder = st.empty() - + streaming_content += msg streaming_placeholder.write(streaming_content) continue @@ -186,12 +197,12 @@ async def draw_messages( # If we're rendering new messages, store the message in session state if is_new: st.session_state.messages.append(msg) - + # If the last message type was not AI, create a new chat message if last_message_type != "ai": last_message_type = "ai" st.session_state.last_message = st.chat_message("ai") - + with st.session_state.last_message: # If the message has content, write it out. # Reset the streaming variables to prepare for the next message. 
@@ -210,9 +221,9 @@ async def draw_messages( call_results = {} for tool_call in msg.tool_calls: status = st.status( - f"""Tool Call: {tool_call["name"]}""", - state="running" if is_new else "complete", - ) + f"""Tool Call: {tool_call["name"]}""", + state="running" if is_new else "complete", + ) call_results[tool_call["id"]] = status status.write("Input:") status.write(tool_call["args"]) @@ -224,7 +235,7 @@ async def draw_messages( st.error(f"Unexpected ChatMessage type: {tool_result.type}") st.write(tool_result) st.stop() - + # Record the message if it's new, and update the correct # status container with the result if is_new: @@ -235,7 +246,7 @@ async def draw_messages( status.update(state="complete") # In case of an unexpected message type, log an error and stop - case _: + case _: st.error(f"Unexpected ChatMessage type: {msg.type}") st.write(msg) st.stop() @@ -247,13 +258,12 @@ async def handle_feedback(): # Keep track of last feedback sent to avoid sending duplicates if "last_feedback" not in st.session_state: st.session_state.last_feedback = (None, None) - + latest_run_id = st.session_state.messages[-1].run_id feedback = st.feedback("stars", key=latest_run_id) # If the feedback value or run ID has changed, send a new feedback record if feedback and (latest_run_id, feedback) != st.session_state.last_feedback: - # Normalize the feedback value (an index) to a score between 0 and 1 normalized_score = (feedback + 1) / 5.0 diff --git a/test-requirements.txt b/test-requirements.txt deleted file mode 100644 index 4ef4a38..0000000 --- a/test-requirements.txt +++ /dev/null @@ -1,2 +0,0 @@ -pytest -httpx==0.26.0 \ No newline at end of file
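--
Local usage sketch (not part of the patch): the contributor workflow implied by
the README and CI changes above. Commands are taken from this diff; the only
added flags are ruff's --fix (already passed by the pre-commit hook) and
pre-commit's run --all-files, both standard CLI options.

    pip install uv
    uv pip install -r pyproject.toml --extra dev   # install runtime + dev dependencies
    pre-commit install                             # register the hooks from .pre-commit-config.yaml
    ruff format --check                            # formatting check, mirroring the CI step
    ruff check --fix                               # lint locally with autofixes
    pre-commit run --all-files                     # run every configured hook once across the repo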