
Add support for calling multiple agents on one service #75

Merged Oct 31, 2024 · 4 commits · Changes from all commits
18 changes: 11 additions & 7 deletions README.md
@@ -52,6 +52,7 @@ docker compose watch
1. **Advanced Streaming**: A novel approach to support both token-based and message-based streaming.
1. **Content Moderation**: Implements LlamaGuard for content moderation (requires Groq API key).
1. **Streamlit Interface**: Provides a user-friendly chat interface for interacting with the agent.
+1. **Multiple Agent Support**: Run multiple agents in the service and call them by URL path.
1. **Asynchronous Design**: Utilizes async/await for efficient handling of concurrent requests.
1. **Feedback Mechanism**: Includes a star-based feedback system integrated with LangSmith.
1. **Docker Support**: Includes Dockerfiles and a docker compose file for easy development and deployment.
@@ -60,10 +61,12 @@ docker compose watch

The repository is structured as follows:

-- `src/agent/research_assistant.py`: Defines the LangGraph agent
-- `src/agent/llama_guard.py`: Defines the LlamaGuard content moderation
-- `src/schema/schema.py`: Defines the service schema
-- `src/service/service.py`: FastAPI service to serve the agent
+- `src/agents/research_assistant.py`: Defines the main LangGraph agent
+- `src/agents/llama_guard.py`: Defines the LlamaGuard content moderation
+- `src/agents/models.py`: Configures available models based on ENV
+- `src/agents/agents.py`: Mapping of all agents provided by the service
+- `src/schema/schema.py`: Defines the protocol schema
+- `src/service/service.py`: FastAPI service to serve the agents
- `src/client/client.py`: Client to interact with the agent service
- `src/streamlit_app.py`: Streamlit app providing a chat interface

@@ -208,8 +211,9 @@ Currently the tests need to be run using the local development without Docker se

To customize the agent for your own use case:

-1. Modify the `src/agent/research_assistant.py` file to change the agent's behavior and tools. Or, build a new agent from scratch.
-2. Adjust the Streamlit interface in `src/streamlit_app.py` to match your agent's capabilities.
+1. Add your new agent to the `src/agents` directory. You can copy `research_assistant.py` or `chatbot.py` and modify it to change the agent's behavior and tools.
+1. Import and add your new agent to the `agents` dictionary in `src/agents/agents.py`. Your agent can be called by `/<your_agent_name>/invoke` or `/<your_agent_name>/stream`.
+1. Adjust the Streamlit interface in `src/streamlit_app.py` to match your agent's capabilities.
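
To make these steps concrete, here is a minimal sketch of a new agent and its registration. The agent name `echo` and the file `src/agents/echo.py` are hypothetical examples, not part of this PR:

```python
# src/agents/echo.py (hypothetical example, not part of this PR)
from langchain_core.messages import AIMessage
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, MessagesState, StateGraph


async def respond(state: MessagesState) -> MessagesState:
    # Trivial behavior for illustration: echo the user's last message back.
    last = state["messages"][-1]
    return {"messages": [AIMessage(content=f"Echo: {last.content}")]}


agent = StateGraph(MessagesState)
agent.add_node("respond", respond)
agent.set_entry_point("respond")
agent.add_edge("respond", END)
echo = agent.compile(checkpointer=MemorySaver())
```

Importing `echo` in `src/agents/agents.py` and adding `"echo": echo` to the `agents` dictionary would then expose it at `/echo/invoke` and `/echo/stream`.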

## Building other apps on the AgentClient

@@ -239,7 +243,7 @@ Contributions are welcome! Please feel free to submit a Pull Request.
- [x] Get LlamaGuard working for content moderation (anyone know a reliable and fast hosted version?)
- [x] Add more sophisticated tools for the research assistant
- [x] Increase test coverage and add CI pipeline
-- [ ] Add support for multiple agents running on the same service, including non-chat agent
+- [x] Add support for multiple agents running on the same service, including non-chat agent
- [ ] Deployment instructions and configuration for cloud providers
- [ ] More ideas? File an issue or create a discussion!

4 changes: 2 additions & 2 deletions compose.yaml
@@ -9,9 +9,9 @@ services:
- .env
develop:
watch:
-        - path: src/agent/
+        - path: src/agents/
          action: sync+restart
-          target: /app/agent/
+          target: /app/agents/
- path: src/schema/
action: sync+restart
target: /app/schema/
Expand Down
2 changes: 1 addition & 1 deletion docker/Dockerfile.service
@@ -10,7 +10,7 @@ COPY uv.lock .
RUN pip install --no-cache-dir uv
RUN uv sync --frozen --no-install-project --no-dev

-COPY src/agent/ ./agent/
+COPY src/agents/ ./agents/
COPY src/schema/ ./schema/
COPY src/service/ ./service/
COPY src/run_service.py .
2 changes: 1 addition & 1 deletion langgraph.json
@@ -2,7 +2,7 @@
"python_version": "3.12",
"dependencies": ["."],
"graphs": {
"research_assistant": "./src/agent/research_assistant.py:research_assistant"
"research_assistant": "./src/agents/research_assistant.py:research_assistant"
},
"env": "./.env"
}
5 changes: 5 additions & 0 deletions pyproject.toml
@@ -53,5 +53,10 @@ target-version = "py310"
[tool.ruff.lint]
extend-select = ["I", "U"]

+[tool.pytest.ini_options]
+pythonpath = [
+    "src"
+]
+
[tool.pytest_env]
OPENAI_API_KEY = "sk-fake-openai-key"
3 changes: 0 additions & 3 deletions src/agent/__init__.py

This file was deleted.

3 changes: 3 additions & 0 deletions src/agents/__init__.py
@@ -0,0 +1,3 @@
from agents.agents import DEFAULT_AGENT, agents

__all__ = ["agents", "DEFAULT_AGENT"]
12 changes: 12 additions & 0 deletions src/agents/agents.py
@@ -0,0 +1,12 @@
from langgraph.graph.state import CompiledStateGraph

from agents.chatbot import chatbot
from agents.research_assistant import research_assistant

DEFAULT_AGENT = "research-assistant"


agents: dict[str, CompiledStateGraph] = {
    "chatbot": chatbot,
    "research-assistant": research_assistant,
}
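
The accompanying `service.py` changes are not shown in this view. As a rough sketch only (the request and response shapes here are assumptions, not the PR's actual implementation), a FastAPI service can route on this dictionary like so:

```python
from uuid import uuid4

from fastapi import FastAPI, HTTPException

from agents import agents

app = FastAPI()


@app.post("/{agent_name}/invoke")
async def invoke(agent_name: str, user_input: dict) -> dict:
    # The URL path segment selects which compiled graph to run.
    agent = agents.get(agent_name)
    if agent is None:
        raise HTTPException(status_code=404, detail=f"Unknown agent: {agent_name}")
    result = await agent.ainvoke(
        {"messages": [("user", user_input["message"])]},
        # MemorySaver-checkpointed graphs require a thread_id.
        config={"configurable": {"thread_id": str(uuid4())}},
    )
    return {"content": result["messages"][-1].content}
```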
44 changes: 44 additions & 0 deletions src/agents/chatbot.py
@@ -0,0 +1,44 @@
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import AIMessage
from langchain_core.runnables import RunnableConfig, RunnableLambda, RunnableSerializable
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, MessagesState, StateGraph

from agents.models import models


class AgentState(MessagesState, total=False):
    """`total=False` is PEP 589 totality: all declared keys are optional.

    Documentation: https://typing.readthedocs.io/en/latest/spec/typeddict.html#totality
    """


def wrap_model(model: BaseChatModel) -> RunnableSerializable[AgentState, AIMessage]:
    preprocessor = RunnableLambda(
        lambda state: state["messages"],
        name="StateModifier",
    )
    return preprocessor | model


async def acall_model(state: AgentState, config: RunnableConfig) -> AgentState:
    m = models[config["configurable"].get("model", "gpt-4o-mini")]
    model_runnable = wrap_model(m)
    response = await model_runnable.ainvoke(state, config)

    # We return a list, because this will get added to the existing list
    return {"messages": [response]}


# Define the graph
agent = StateGraph(AgentState)
agent.add_node("model", acall_model)
agent.set_entry_point("model")

# Always END after the model responds (this simple chatbot has no tools or moderation)
agent.add_edge("model", END)

chatbot = agent.compile(
    checkpointer=MemorySaver(),
)
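
Because `acall_model` reads the model name from `config["configurable"]`, callers can select any registered model per request. A minimal invocation sketch (assumes `OPENAI_API_KEY` is set so `"gpt-4o-mini"` exists in `models`):

```python
import asyncio
from uuid import uuid4

from langchain_core.runnables import RunnableConfig

from agents.chatbot import chatbot


async def demo() -> None:
    result = await chatbot.ainvoke(
        {"messages": [("user", "Hello!")]},
        # thread_id is required by the MemorySaver checkpointer;
        # "model" overrides the gpt-4o-mini default in acall_model.
        config=RunnableConfig(configurable={"thread_id": str(uuid4()), "model": "gpt-4o-mini"}),
    )
    result["messages"][-1].pretty_print()


asyncio.run(demo())
```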
src/agent/llama_guard.py → src/agents/llama_guard.py: File renamed without changes.
29 changes: 29 additions & 0 deletions src/agents/models.py
@@ -0,0 +1,29 @@
import os

from langchain_anthropic import ChatAnthropic
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_groq import ChatGroq
from langchain_openai import ChatOpenAI

# NOTE: models with streaming=True will send tokens as they are generated
# if the /stream endpoint is called with stream_tokens=True (the default)
models: dict[str, BaseChatModel] = {}
if os.getenv("OPENAI_API_KEY") is not None:
    models["gpt-4o-mini"] = ChatOpenAI(model="gpt-4o-mini", temperature=0.5, streaming=True)
if os.getenv("GROQ_API_KEY") is not None:
    models["llama-3.1-70b"] = ChatGroq(model="llama-3.1-70b-versatile", temperature=0.5)
if os.getenv("GOOGLE_API_KEY") is not None:
    models["gemini-1.5-flash"] = ChatGoogleGenerativeAI(
        model="gemini-1.5-flash", temperature=0.5, streaming=True
    )
if os.getenv("ANTHROPIC_API_KEY") is not None:
    models["claude-3-haiku"] = ChatAnthropic(
        model="claude-3-haiku-20240307", temperature=0.5, streaming=True
    )

if not models:
    print("No LLM available. Please set API keys to enable at least one LLM.")
    if os.getenv("MODE") == "dev":
        print("FastAPI initialization failed. Please use Ctrl + C to exit uvicorn.")
    exit(1)
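
One caveat: the agents above default to `"gpt-4o-mini"`, so if only a non-OpenAI key is set, the dictionary lookup in `acall_model` raises `KeyError`. A defensive lookup helper (a suggestion, not part of this PR) could fall back to whichever model is available:

```python
from langchain_core.language_models.chat_models import BaseChatModel

from agents.models import models


def get_model(name: str | None = None) -> BaseChatModel:
    # Fall back to the first registered model when the requested
    # (or default) name is not configured.
    key = name if name in models else next(iter(models))
    return models[key]
```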
src/agent/research_assistant.py → src/agents/research_assistant.py
@@ -2,21 +2,18 @@
from datetime import datetime
from typing import Literal

-from langchain_anthropic import ChatAnthropic
from langchain_community.tools import DuckDuckGoSearchResults, OpenWeatherMapQueryRun
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import AIMessage, SystemMessage
from langchain_core.runnables import RunnableConfig, RunnableLambda, RunnableSerializable
-from langchain_google_genai import ChatGoogleGenerativeAI
-from langchain_groq import ChatGroq
-from langchain_openai import ChatOpenAI
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, MessagesState, StateGraph
from langgraph.managed import IsLastStep
from langgraph.prebuilt import ToolNode

-from agent.llama_guard import LlamaGuard, LlamaGuardOutput, SafetyAssessment
-from agent.tools import calculator
+from agents.llama_guard import LlamaGuard, LlamaGuardOutput, SafetyAssessment
+from agents.models import models
+from agents.tools import calculator


class AgentState(MessagesState, total=False):
@@ -29,29 +26,6 @@ class AgentState(MessagesState, total=False):
    is_last_step: IsLastStep


-# NOTE: models with streaming=True will send tokens as they are generated
-# if the /stream endpoint is called with stream_tokens=True (the default)
-models: dict[str, BaseChatModel] = {}
-if os.getenv("OPENAI_API_KEY") is not None:
-    models["gpt-4o-mini"] = ChatOpenAI(model="gpt-4o-mini", temperature=0.5, streaming=True)
-if os.getenv("GROQ_API_KEY") is not None:
-    models["llama-3.1-70b"] = ChatGroq(model="llama-3.1-70b-versatile", temperature=0.5)
-if os.getenv("GOOGLE_API_KEY") is not None:
-    models["gemini-1.5-flash"] = ChatGoogleGenerativeAI(
-        model="gemini-1.5-flash", temperature=0.5, streaming=True
-    )
-if os.getenv("ANTHROPIC_API_KEY") is not None:
-    models["claude-3-haiku"] = ChatAnthropic(
-        model="claude-3-haiku-20240307", temperature=0.5, streaming=True
-    )
-
-if not models:
-    print("No LLM available. Please set API keys to enable at least one LLM.")
-    if os.getenv("MODE") == "dev":
-        print("FastAPI initialized failed. Please use Ctrl + C to exit uvicorn.")
-    exit(1)
-
-
web_search = DuckDuckGoSearchResults(name="WebSearch")
tools = [web_search, calculator]

@@ -171,31 +145,3 @@ def pending_tool_calls(state: AgentState) -> Literal["tools", "done"]:
research_assistant = agent.compile(
    checkpointer=MemorySaver(),
)
-
-
-if __name__ == "__main__":
-    import asyncio
-    from uuid import uuid4
-
-    from dotenv import load_dotenv
-
-    load_dotenv()
-
-    async def main() -> None:
-        inputs = {"messages": [("user", "Find me a recipe for chocolate chip cookies")]}
-        result = await research_assistant.ainvoke(
-            inputs,
-            config=RunnableConfig(configurable={"thread_id": uuid4()}),
-        )
-        result["messages"][-1].pretty_print()
-
-        # Draw the agent graph as png
-        # requires:
-        # brew install graphviz
-        # export CFLAGS="-I $(brew --prefix graphviz)/include"
-        # export LDFLAGS="-L $(brew --prefix graphviz)/lib"
-        # pip install pygraphviz
-        #
-        # research_assistant.get_graph().draw_png("agent_diagram.png")
-
-    asyncio.run(main())
src/agent/tools.py → src/agents/tools.py: File renamed without changes.
16 changes: 11 additions & 5 deletions src/client/client.py
@@ -11,14 +11,20 @@
class AgentClient:
"""Client for interacting with the agent service."""

-    def __init__(self, base_url: str = "http://localhost:80", timeout: float | None = None) -> None:
+    def __init__(
+        self,
+        base_url: str = "http://localhost:80",
+        agent: str = "research-assistant",
+        timeout: float | None = None,
+    ) -> None:
        """
        Initialize the client.

        Args:
            base_url (str): The base URL of the agent service.
        """
        self.base_url = base_url
+        self.agent = agent
        self.auth_secret = os.getenv("AUTH_SECRET")
        self.timeout = timeout
@@ -50,7 +56,7 @@ async def ainvoke(
            request.model = model
        async with httpx.AsyncClient() as client:
            response = await client.post(
-                f"{self.base_url}/invoke",
+                f"{self.base_url}/{self.agent}/invoke",
                json=request.model_dump(),
                headers=self._headers,
                timeout=self.timeout,
@@ -79,7 +85,7 @@ def invoke(
        if model:
            request.model = model
        response = httpx.post(
-            f"{self.base_url}/invoke",
+            f"{self.base_url}/{self.agent}/invoke",
            json=request.model_dump(),
            headers=self._headers,
            timeout=self.timeout,
@@ -143,7 +149,7 @@ def stream(
            request.model = model
        with httpx.stream(
            "POST",
-            f"{self.base_url}/stream",
+            f"{self.base_url}/{self.agent}/stream",
            json=request.model_dump(),
            headers=self._headers,
            timeout=self.timeout,
@@ -189,7 +195,7 @@ async def astream(
        async with httpx.AsyncClient() as client:
            async with client.stream(
                "POST",
-                f"{self.base_url}/stream",
+                f"{self.base_url}/{self.agent}/stream",
                json=request.model_dump(),
                headers=self._headers,
                timeout=self.timeout,
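
With the new `agent` parameter, several clients can target different agents on one service. A usage sketch (the `invoke` argument shape follows the existing client; its full signature is not visible in this diff):

```python
from client.client import AgentClient

# Each client is pinned to one agent; requests go to /chatbot/invoke,
# /chatbot/stream, etc. on the same service.
chatbot_client = AgentClient(base_url="http://localhost:80", agent="chatbot")
research_client = AgentClient(agent="research-assistant")

response = chatbot_client.invoke("Tell me a joke.")  # assumed message-first call
```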
32 changes: 32 additions & 0 deletions src/run_agent.py
@@ -0,0 +1,32 @@
import asyncio
from uuid import uuid4

from dotenv import load_dotenv
from langchain_core.runnables import RunnableConfig

load_dotenv()

from agents import DEFAULT_AGENT, agents # noqa: E402

agent = agents[DEFAULT_AGENT]


async def main() -> None:
    inputs = {"messages": [("user", "Find me a recipe for chocolate chip cookies")]}
    result = await agent.ainvoke(
        inputs,
        config=RunnableConfig(configurable={"thread_id": uuid4()}),
    )
    result["messages"][-1].pretty_print()

    # Draw the agent graph as png
    # requires:
    # brew install graphviz
    # export CFLAGS="-I $(brew --prefix graphviz)/include"
    # export LDFLAGS="-L $(brew --prefix graphviz)/lib"
    # pip install pygraphviz
    #
    # agent.get_graph().draw_png("agent_diagram.png")


asyncio.run(main())