diff --git a/README.md b/README.md
index 20416b7..5518006 100644
--- a/README.md
+++ b/README.md
@@ -52,6 +52,7 @@ docker compose watch
 1. **Advanced Streaming**: A novel approach to support both token-based and message-based streaming.
 1. **Content Moderation**: Implements LlamaGuard for content moderation (requires Groq API key).
 1. **Streamlit Interface**: Provides a user-friendly chat interface for interacting with the agent.
+1. **Multiple Agent Support**: Run multiple agents in the service and call them by URL path.
 1. **Asynchronous Design**: Utilizes async/await for efficient handling of concurrent requests.
 1. **Feedback Mechanism**: Includes a star-based feedback system integrated with LangSmith.
 1. **Docker Support**: Includes Dockerfiles and a docker compose file for easy development and deployment.
@@ -60,10 +61,12 @@ docker compose watch
 
 The repository is structured as follows:
 
-- `src/agent/research_assistant.py`: Defines the LangGraph agent
-- `src/agent/llama_guard.py`: Defines the LlamaGuard content moderation
-- `src/schema/schema.py`: Defines the service schema
-- `src/service/service.py`: FastAPI service to serve the agent
+- `src/agents/research_assistant.py`: Defines the main LangGraph agent
+- `src/agents/llama_guard.py`: Defines the LlamaGuard content moderation
+- `src/agents/models.py`: Configures available models based on ENV
+- `src/agents/agents.py`: Mapping of all agents provided by the service
+- `src/schema/schema.py`: Defines the protocol schema
+- `src/service/service.py`: FastAPI service to serve the agents
 - `src/client/client.py`: Client to interact with the agent service
 - `src/streamlit_app.py`: Streamlit app providing a chat interface
 
@@ -208,8 +211,9 @@ Currently the tests need to be run using the local development without Docker se
 
 To customize the agent for your own use case:
 
-1. Modify the `src/agent/research_assistant.py` file to change the agent's behavior and tools. Or, build a new agent from scratch.
-2. Adjust the Streamlit interface in `src/streamlit_app.py` to match your agent's capabilities.
+1. Add your new agent to the `src/agents` directory. You can copy `research_assistant.py` or `chatbot.py` and modify it to change the agent's behavior and tools.
+1. Import and add your new agent to the `agents` dictionary in `src/agents/agents.py`. Your agent can be called by `/<agent_id>/invoke` or `/<agent_id>/stream`.
+1. Adjust the Streamlit interface in `src/streamlit_app.py` to match your agent's capabilities.
 
 ## Building other apps on the AgentClient
 
@@ -239,7 +243,7 @@ Contributions are welcome! Please feel free to submit a Pull Request.
 - [x] Get LlamaGuard working for content moderation (anyone know a reliable and fast hosted version?)
 - [x] Add more sophisticated tools for the research assistant
 - [x] Increase test coverage and add CI pipeline
-- [ ] Add support for multiple agents running on the same service, including non-chat agent
+- [x] Add support for multiple agents running on the same service, including non-chat agent
 - [ ] Deployment instructions and configuration for cloud providers
 - [ ] More ideas? File an issue or create a discussion!
 
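Note on the URL-path routing documented above: each agent registered in `src/agents/agents.py` is served under its own path prefix, while the bare `/invoke` and `/stream` routes keep hitting the default agent. A minimal sketch of calling the per-agent endpoint directly — it assumes the service is running locally on port 80 with at least one model API key configured (see the `agent` argument added to `AgentClient` in `src/client/client.py` further down for the client-library equivalent):

```python
import httpx

# Hit the "chatbot" agent registered in src/agents/agents.py via its URL path.
# The default agent ("research-assistant") remains available at the bare /invoke route.
response = httpx.post(
    "http://localhost:80/chatbot/invoke",
    json={"message": "Tell me a brief joke?"},
    timeout=30.0,
)
response.raise_for_status()
print(response.json())
```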
diff --git a/compose.yaml b/compose.yaml
index c26bfe9..5ff954d 100644
--- a/compose.yaml
+++ b/compose.yaml
@@ -9,9 +9,9 @@ services:
       - .env
     develop:
       watch:
-        - path: src/agent/
+        - path: src/agents/
           action: sync+restart
-          target: /app/agent/
+          target: /app/agents/
         - path: src/schema/
           action: sync+restart
           target: /app/schema/
diff --git a/docker/Dockerfile.service b/docker/Dockerfile.service
index ca28542..0a0e5ff 100644
--- a/docker/Dockerfile.service
+++ b/docker/Dockerfile.service
@@ -10,7 +10,7 @@ COPY uv.lock .
 RUN pip install --no-cache-dir uv
 RUN uv sync --frozen --no-install-project --no-dev
 
-COPY src/agent/ ./agent/
+COPY src/agents/ ./agents/
 COPY src/schema/ ./schema/
 COPY src/service/ ./service/
 COPY src/run_service.py .
diff --git a/langgraph.json b/langgraph.json
index 8a3888c..5699b04 100644
--- a/langgraph.json
+++ b/langgraph.json
@@ -2,7 +2,7 @@
   "python_version": "3.12",
   "dependencies": ["."],
   "graphs": {
-    "research_assistant": "./src/agent/research_assistant.py:research_assistant"
+    "research_assistant": "./src/agents/research_assistant.py:research_assistant"
   },
   "env": "./.env"
 }
diff --git a/pyproject.toml b/pyproject.toml
index 9d8a448..a303d25 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -53,5 +53,10 @@ target-version = "py310"
 [tool.ruff.lint]
 extend-select = ["I", "U"]
 
+[tool.pytest.ini_options]
+pythonpath = [
+    "src"
+]
+
 [tool.pytest_env]
 OPENAI_API_KEY = "sk-fake-openai-key"
diff --git a/src/agent/__init__.py b/src/agent/__init__.py
deleted file mode 100644
index ba17529..0000000
--- a/src/agent/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-from agent.research_assistant import research_assistant
-
-__all__ = ["research_assistant"]
diff --git a/src/agents/__init__.py b/src/agents/__init__.py
new file mode 100644
index 0000000..e7d276d
--- /dev/null
+++ b/src/agents/__init__.py
@@ -0,0 +1,3 @@
+from agents.agents import DEFAULT_AGENT, agents
+
+__all__ = ["agents", "DEFAULT_AGENT"]
diff --git a/src/agents/agents.py b/src/agents/agents.py
new file mode 100644
index 0000000..0d22601
--- /dev/null
+++ b/src/agents/agents.py
@@ -0,0 +1,12 @@
+from langgraph.graph.state import CompiledStateGraph
+
+from agents.chatbot import chatbot
+from agents.research_assistant import research_assistant
+
+DEFAULT_AGENT = "research-assistant"
+
+
+agents: dict[str, CompiledStateGraph] = {
+    "chatbot": chatbot,
+    "research-assistant": research_assistant,
+}
diff --git a/src/agents/chatbot.py b/src/agents/chatbot.py
new file mode 100644
index 0000000..fb20efd
--- /dev/null
+++ b/src/agents/chatbot.py
@@ -0,0 +1,44 @@
+from langchain_core.language_models.chat_models import BaseChatModel
+from langchain_core.messages import AIMessage
+from langchain_core.runnables import RunnableConfig, RunnableLambda, RunnableSerializable
+from langgraph.checkpoint.memory import MemorySaver
+from langgraph.graph import END, MessagesState, StateGraph
+
+from agents.models import models
+
+
+class AgentState(MessagesState, total=False):
+    """`total=False` is PEP 589 specs.
+
+    documentation: https://typing.readthedocs.io/en/latest/spec/typeddict.html#totality
+    """
+
+
+def wrap_model(model: BaseChatModel) -> RunnableSerializable[AgentState, AIMessage]:
+    preprocessor = RunnableLambda(
+        lambda state: state["messages"],
+        name="StateModifier",
+    )
+    return preprocessor | model
+
+
+async def acall_model(state: AgentState, config: RunnableConfig) -> AgentState:
+    m = models[config["configurable"].get("model", "gpt-4o-mini")]
+    model_runnable = wrap_model(m)
+    response = await model_runnable.ainvoke(state, config)
+
+    # We return a list, because this will get added to the existing list
+    return {"messages": [response]}
+
+
+# Define the graph
+agent = StateGraph(AgentState)
+agent.add_node("model", acall_model)
+agent.set_entry_point("model")
+
+# Always END after the single model call
+agent.add_edge("model", END)
+
+chatbot = agent.compile(
+    checkpointer=MemorySaver(),
+)
diff --git a/src/agent/llama_guard.py b/src/agents/llama_guard.py
similarity index 100%
rename from src/agent/llama_guard.py
rename to src/agents/llama_guard.py
diff --git a/src/agents/models.py b/src/agents/models.py
new file mode 100644
index 0000000..1b9e3d3
--- /dev/null
+++ b/src/agents/models.py
@@ -0,0 +1,29 @@
+import os
+
+from langchain_anthropic import ChatAnthropic
+from langchain_core.language_models.chat_models import BaseChatModel
+from langchain_google_genai import ChatGoogleGenerativeAI
+from langchain_groq import ChatGroq
+from langchain_openai import ChatOpenAI
+
+# NOTE: models with streaming=True will send tokens as they are generated
+# if the /stream endpoint is called with stream_tokens=True (the default)
+models: dict[str, BaseChatModel] = {}
+if os.getenv("OPENAI_API_KEY") is not None:
+    models["gpt-4o-mini"] = ChatOpenAI(model="gpt-4o-mini", temperature=0.5, streaming=True)
+if os.getenv("GROQ_API_KEY") is not None:
+    models["llama-3.1-70b"] = ChatGroq(model="llama-3.1-70b-versatile", temperature=0.5)
+if os.getenv("GOOGLE_API_KEY") is not None:
+    models["gemini-1.5-flash"] = ChatGoogleGenerativeAI(
+        model="gemini-1.5-flash", temperature=0.5, streaming=True
+    )
+if os.getenv("ANTHROPIC_API_KEY") is not None:
+    models["claude-3-haiku"] = ChatAnthropic(
+        model="claude-3-haiku-20240307", temperature=0.5, streaming=True
+    )
+
+if not models:
+    print("No LLM available. Please set API keys to enable at least one LLM.")
+    if os.getenv("MODE") == "dev":
+        print("FastAPI initialization failed. Please use Ctrl + C to exit uvicorn.")
+        exit(1)
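As a usage note for `models.py` above: the registry is keyed by short model names, and each graph looks the model up from `config["configurable"]` at call time (see `acall_model` in `chatbot.py`). A minimal sketch of selecting a model per request, assuming `OPENAI_API_KEY` is set so that `gpt-4o-mini` is registered and `src` is on the Python path:

```python
import asyncio
from uuid import uuid4

from langchain_core.runnables import RunnableConfig

from agents import agents


async def main() -> None:
    chatbot = agents["chatbot"]
    result = await chatbot.ainvoke(
        {"messages": [("user", "Hi there!")]},
        # "model" falls back to "gpt-4o-mini" when omitted; "thread_id" keys the checkpointer.
        config=RunnableConfig(configurable={"thread_id": uuid4(), "model": "gpt-4o-mini"}),
    )
    result["messages"][-1].pretty_print()


asyncio.run(main())
```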
Please use Ctrl + C to exit uvicorn.") + exit(1) diff --git a/src/agent/research_assistant.py b/src/agents/research_assistant.py similarity index 70% rename from src/agent/research_assistant.py rename to src/agents/research_assistant.py index 29ce288..e91d6e1 100644 --- a/src/agent/research_assistant.py +++ b/src/agents/research_assistant.py @@ -2,21 +2,18 @@ from datetime import datetime from typing import Literal -from langchain_anthropic import ChatAnthropic from langchain_community.tools import DuckDuckGoSearchResults, OpenWeatherMapQueryRun from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.messages import AIMessage, SystemMessage from langchain_core.runnables import RunnableConfig, RunnableLambda, RunnableSerializable -from langchain_google_genai import ChatGoogleGenerativeAI -from langchain_groq import ChatGroq -from langchain_openai import ChatOpenAI from langgraph.checkpoint.memory import MemorySaver from langgraph.graph import END, MessagesState, StateGraph from langgraph.managed import IsLastStep from langgraph.prebuilt import ToolNode -from agent.llama_guard import LlamaGuard, LlamaGuardOutput, SafetyAssessment -from agent.tools import calculator +from agents.llama_guard import LlamaGuard, LlamaGuardOutput, SafetyAssessment +from agents.models import models +from agents.tools import calculator class AgentState(MessagesState, total=False): @@ -29,29 +26,6 @@ class AgentState(MessagesState, total=False): is_last_step: IsLastStep -# NOTE: models with streaming=True will send tokens as they are generated -# if the /stream endpoint is called with stream_tokens=True (the default) -models: dict[str, BaseChatModel] = {} -if os.getenv("OPENAI_API_KEY") is not None: - models["gpt-4o-mini"] = ChatOpenAI(model="gpt-4o-mini", temperature=0.5, streaming=True) -if os.getenv("GROQ_API_KEY") is not None: - models["llama-3.1-70b"] = ChatGroq(model="llama-3.1-70b-versatile", temperature=0.5) -if os.getenv("GOOGLE_API_KEY") is not None: - models["gemini-1.5-flash"] = ChatGoogleGenerativeAI( - model="gemini-1.5-flash", temperature=0.5, streaming=True - ) -if os.getenv("ANTHROPIC_API_KEY") is not None: - models["claude-3-haiku"] = ChatAnthropic( - model="claude-3-haiku-20240307", temperature=0.5, streaming=True - ) - -if not models: - print("No LLM available. Please set API keys to enable at least one LLM.") - if os.getenv("MODE") == "dev": - print("FastAPI initialized failed. 
Please use Ctrl + C to exit uvicorn.") - exit(1) - - web_search = DuckDuckGoSearchResults(name="WebSearch") tools = [web_search, calculator] @@ -171,31 +145,3 @@ def pending_tool_calls(state: AgentState) -> Literal["tools", "done"]: research_assistant = agent.compile( checkpointer=MemorySaver(), ) - - -if __name__ == "__main__": - import asyncio - from uuid import uuid4 - - from dotenv import load_dotenv - - load_dotenv() - - async def main() -> None: - inputs = {"messages": [("user", "Find me a recipe for chocolate chip cookies")]} - result = await research_assistant.ainvoke( - inputs, - config=RunnableConfig(configurable={"thread_id": uuid4()}), - ) - result["messages"][-1].pretty_print() - - # Draw the agent graph as png - # requires: - # brew install graphviz - # export CFLAGS="-I $(brew --prefix graphviz)/include" - # export LDFLAGS="-L $(brew --prefix graphviz)/lib" - # pip install pygraphviz - # - # research_assistant.get_graph().draw_png("agent_diagram.png") - - asyncio.run(main()) diff --git a/src/agent/tools.py b/src/agents/tools.py similarity index 100% rename from src/agent/tools.py rename to src/agents/tools.py diff --git a/src/client/client.py b/src/client/client.py index 3640237..96e7957 100644 --- a/src/client/client.py +++ b/src/client/client.py @@ -11,7 +11,12 @@ class AgentClient: """Client for interacting with the agent service.""" - def __init__(self, base_url: str = "http://localhost:80", timeout: float | None = None) -> None: + def __init__( + self, + base_url: str = "http://localhost:80", + agent: str = "research-assistant", + timeout: float | None = None, + ) -> None: """ Initialize the client. @@ -19,6 +24,7 @@ def __init__(self, base_url: str = "http://localhost:80", timeout: float | None base_url (str): The base URL of the agent service. 
""" self.base_url = base_url + self.agent = agent self.auth_secret = os.getenv("AUTH_SECRET") self.timeout = timeout @@ -50,7 +56,7 @@ async def ainvoke( request.model = model async with httpx.AsyncClient() as client: response = await client.post( - f"{self.base_url}/invoke", + f"{self.base_url}/{self.agent}/invoke", json=request.model_dump(), headers=self._headers, timeout=self.timeout, @@ -79,7 +85,7 @@ def invoke( if model: request.model = model response = httpx.post( - f"{self.base_url}/invoke", + f"{self.base_url}/{self.agent}/invoke", json=request.model_dump(), headers=self._headers, timeout=self.timeout, @@ -143,7 +149,7 @@ def stream( request.model = model with httpx.stream( "POST", - f"{self.base_url}/stream", + f"{self.base_url}/{self.agent}/stream", json=request.model_dump(), headers=self._headers, timeout=self.timeout, @@ -189,7 +195,7 @@ async def astream( async with httpx.AsyncClient() as client: async with client.stream( "POST", - f"{self.base_url}/stream", + f"{self.base_url}/{self.agent}/stream", json=request.model_dump(), headers=self._headers, timeout=self.timeout, diff --git a/src/run_agent.py b/src/run_agent.py new file mode 100644 index 0000000..d3d6bbd --- /dev/null +++ b/src/run_agent.py @@ -0,0 +1,32 @@ +import asyncio +from uuid import uuid4 + +from dotenv import load_dotenv +from langchain_core.runnables import RunnableConfig + +load_dotenv() + +from agents import DEFAULT_AGENT, agents # noqa: E402 + +agent = agents[DEFAULT_AGENT] + + +async def main() -> None: + inputs = {"messages": [("user", "Find me a recipe for chocolate chip cookies")]} + result = await agent.ainvoke( + inputs, + config=RunnableConfig(configurable={"thread_id": uuid4()}), + ) + result["messages"][-1].pretty_print() + + # Draw the agent graph as png + # requires: + # brew install graphviz + # export CFLAGS="-I $(brew --prefix graphviz)/include" + # export LDFLAGS="-L $(brew --prefix graphviz)/lib" + # pip install pygraphviz + # + # agent.get_graph().draw_png("agent_diagram.png") + + +asyncio.run(main()) diff --git a/src/service/service.py b/src/service/service.py index 599918d..06cfea1 100644 --- a/src/service/service.py +++ b/src/service/service.py @@ -17,7 +17,7 @@ from langgraph.graph.state import CompiledStateGraph from langsmith import Client as LangsmithClient -from agent import research_assistant +from agents import DEFAULT_AGENT, agents from schema import ( ChatHistory, ChatHistoryInput, @@ -53,9 +53,10 @@ def verify_bearer( @asynccontextmanager async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: # Construct agent with Sqlite checkpointer + # TODO: It's probably dangerous to share the same checkpointer on multiple agents async with AsyncSqliteSaver.from_conn_string("checkpoints.db") as saver: - research_assistant.checkpointer = saver - app.state.agent = research_assistant + for a in agents.values(): + a.checkpointer = saver yield # context manager will clean up the AsyncSqliteSaver on exit @@ -76,15 +77,8 @@ def _parse_input(user_input: UserInput) -> tuple[dict[str, Any], str]: return kwargs, run_id -@router.post("/invoke") -async def invoke(user_input: UserInput) -> ChatMessage: - """ - Invoke the agent with user input to retrieve a final response. - - Use thread_id to persist and continue a multi-turn conversation. run_id kwarg - is also attached to messages for recording feedback. 
- """ - agent: CompiledStateGraph = app.state.agent +async def ainvoke(user_input: UserInput, agent_id: str = DEFAULT_AGENT) -> ChatMessage: + agent: CompiledStateGraph = agents[agent_id] kwargs, run_id = _parse_input(user_input) try: response = await agent.ainvoke(**kwargs) @@ -96,13 +90,37 @@ async def invoke(user_input: UserInput) -> ChatMessage: raise HTTPException(status_code=500, detail="Unexpected error") -async def message_generator(user_input: StreamInput) -> AsyncGenerator[str, None]: +@router.post("/invoke") +async def invoke(user_input: UserInput) -> ChatMessage: + """ + Invoke the default agent with user input to retrieve a final response. + + Use thread_id to persist and continue a multi-turn conversation. run_id kwarg + is also attached to messages for recording feedback. + """ + return await ainvoke(user_input=user_input) + + +@router.post("/{agent_id}/invoke") +async def agent_invoke(user_input: UserInput, agent_id: str) -> ChatMessage: + """ + Invoke an agent with user input to retrieve a final response. + + Use thread_id to persist and continue a multi-turn conversation. run_id kwarg + is also attached to messages for recording feedback. + """ + return await ainvoke(user_input=user_input, agent_id=agent_id) + + +async def message_generator( + user_input: StreamInput, agent_id: str = DEFAULT_AGENT +) -> AsyncGenerator[str, None]: """ Generate a stream of messages from the agent. This is the workhorse method for the /stream endpoint. """ - agent: CompiledStateGraph = app.state.agent + agent: CompiledStateGraph = agents[agent_id] kwargs, run_id = _parse_input(user_input) # Process streamed events from the graph and yield messages over the SSE stream. @@ -166,7 +184,7 @@ def _sse_response_example() -> dict[int, Any]: @router.post("/stream", response_class=StreamingResponse, responses=_sse_response_example()) async def stream(user_input: StreamInput) -> StreamingResponse: """ - Stream the agent's response to a user input, including intermediate messages and tokens. + Stream the default agent's response to a user input, including intermediate messages and tokens. Use thread_id to persist and continue a multi-turn conversation. run_id kwarg is also attached to all messages for recording feedback. @@ -176,6 +194,23 @@ async def stream(user_input: StreamInput) -> StreamingResponse: return StreamingResponse(message_generator(user_input), media_type="text/event-stream") +@router.post( + "/{agent_id}/stream", response_class=StreamingResponse, responses=_sse_response_example() +) +async def agent_stream(user_input: StreamInput, agent_id: str) -> StreamingResponse: + """ + Stream an agent's response to a user input, including intermediate messages and tokens. + + Use thread_id to persist and continue a multi-turn conversation. run_id kwarg + is also attached to all messages for recording feedback. + + Set `stream_tokens=false` to return intermediate messages but not token-by-token. + """ + return StreamingResponse( + message_generator(user_input, agent_id=agent_id), media_type="text/event-stream" + ) + + @router.post("/feedback") async def feedback(feedback: Feedback) -> FeedbackResponse: """ @@ -201,7 +236,8 @@ def history(input: ChatHistoryInput) -> ChatHistory: """ Get chat history. 
""" - agent: CompiledStateGraph = app.state.agent + # TODO: Hard-coding DEFAULT_AGENT here is wonky + agent: CompiledStateGraph = agents[DEFAULT_AGENT] try: state_snapshot = agent.get_state( config=RunnableConfig( diff --git a/src/service/test_service.py b/tests/service/test_service.py similarity index 67% rename from src/service/test_service.py rename to tests/service/test_service.py index 8c5f222..1c8bddf 100644 --- a/src/service/test_service.py +++ b/tests/service/test_service.py @@ -3,28 +3,29 @@ import langsmith from fastapi.testclient import TestClient from langchain_core.messages import AIMessage, HumanMessage -from langgraph.graph.state import CompiledStateGraph from langgraph.pregel.types import StateSnapshot +from agents import DEFAULT_AGENT from schema import ChatHistory, ChatMessage from service import app -client = TestClient(app) +test_client = TestClient(app) -@patch("service.service.research_assistant") -def test_invoke(mock_agent: CompiledStateGraph) -> None: +def test_invoke() -> None: QUESTION = "What is the weather in Tokyo?" ANSWER = "The weather in Tokyo is 70 degrees." agent_response = {"messages": [AIMessage(content=ANSWER)]} - mock_agent.ainvoke = AsyncMock(return_value=agent_response) + agent_mock = AsyncMock() + agent_mock.ainvoke = AsyncMock(return_value=agent_response) - with client as c: - response = c.post("/invoke", json={"message": QUESTION}) - assert response.status_code == 200 + with patch.dict("service.service.agents", {DEFAULT_AGENT: agent_mock}): + with test_client as c: + response = c.post("/invoke", json={"message": QUESTION}) + assert response.status_code == 200 - mock_agent.ainvoke.assert_awaited_once() - input_message = mock_agent.ainvoke.await_args.kwargs["input"]["messages"][0] + agent_mock.ainvoke.assert_awaited_once() + input_message = agent_mock.ainvoke.await_args.kwargs["input"]["messages"][0] assert input_message.content == QUESTION output = ChatMessage.model_validate(response.json()) @@ -41,7 +42,7 @@ def test_feedback(mock_client: langsmith.Client) -> None: "key": "human-feedback-stars", "score": 0.8, } - response = client.post("/feedback", json=body) + response = test_client.post("/feedback", json=body) assert response.status_code == 200 assert response.json() == {"status": "success"} ls_instance.create_feedback.assert_called_once_with( @@ -51,13 +52,13 @@ def test_feedback(mock_client: langsmith.Client) -> None: ) -@patch("service.service.research_assistant") -def test_history(mock_agent: CompiledStateGraph) -> None: +def test_history() -> None: QUESTION = "What is the weather in Tokyo?" ANSWER = "The weather in Tokyo is 70 degrees." 
@@ -51,13 +52,13 @@ def test_feedback(mock_client: langsmith.Client) -> None:
     )
 
 
-@patch("service.service.research_assistant")
-def test_history(mock_agent: CompiledStateGraph) -> None:
+def test_history() -> None:
     QUESTION = "What is the weather in Tokyo?"
     ANSWER = "The weather in Tokyo is 70 degrees."
     user_question = HumanMessage(content=QUESTION)
     agent_response = AIMessage(content=ANSWER)
-    mock_agent.get_state = Mock(
+    agent_mock = AsyncMock()
+    agent_mock.get_state = Mock(
         return_value=StateSnapshot(
             values={"messages": [user_question, agent_response]},
             next=(),
@@ -69,9 +70,12 @@ def test_history(mock_agent: CompiledStateGraph) -> None:
         )
     )
 
-    with client as c:
-        response = c.post("/history", json={"thread_id": "7bcc7cc1-99d7-4b1d-bdb5-e6f90ed44de6"})
-        assert response.status_code == 200
+    with patch.dict("service.service.agents", {DEFAULT_AGENT: agent_mock}):
+        with test_client as c:
+            response = c.post(
+                "/history", json={"thread_id": "7bcc7cc1-99d7-4b1d-bdb5-e6f90ed44de6"}
+            )
+            assert response.status_code == 200
 
     output = ChatHistory.model_validate(response.json())
     assert output.messages[0].type == "human"
diff --git a/src/service/test_utils.py b/tests/service/test_utils.py
similarity index 100%
rename from src/service/test_utils.py
rename to tests/service/test_utils.py
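To round out the customization flow from the README changes: a new graph dropped into `src/agents/` only needs to be registered in the `agents` dict to become reachable at its own URL path. A hypothetical sketch — the `my_agent` module and graph name are placeholders, not part of this change:

```python
# src/agents/agents.py (hypothetical extension)
from langgraph.graph.state import CompiledStateGraph

from agents.chatbot import chatbot
from agents.my_agent import my_agent  # placeholder: your new CompiledStateGraph
from agents.research_assistant import research_assistant

DEFAULT_AGENT = "research-assistant"

agents: dict[str, CompiledStateGraph] = {
    "chatbot": chatbot,
    "research-assistant": research_assistant,
    # Served by src/service/service.py at /my-agent/invoke and /my-agent/stream
    "my-agent": my_agent,
}
```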