Add linting and format with ruff (#18)
* Install ruff and run linting that way

* Run ruff format

* Fix ruff check

* Fix broken test

* Add pre-commit hooks

* Add pre-commit config (lol)
JoshuaC215 authored Sep 2, 2024
1 parent 8f09eec commit 1cec789
Showing 23 changed files with 249 additions and 122 deletions.
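Taken together, these changes replace the flake8/requirements-file setup with ruff, a dev extra in pyproject.toml, and pre-commit hooks. A rough local equivalent of what the commit sets up might look like the following; the exact contents of the dev extra are not visible in this view, so treat it as a sketch:

```
# Install the dev tooling the same way the updated CI workflow does
pip install uv
uv pip install -r pyproject.toml --extra dev

# Lint (with autofix) and format the repository with ruff
ruff check --fix .
ruff format .
```
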
2 changes: 1 addition & 1 deletion .dockerignore
@@ -9,4 +9,4 @@ __pycache__
 env
 venv
 .venv
-*.db
\ No newline at end of file
+*.db
13 changes: 5 additions & 8 deletions .github/workflows/test.yml
@@ -29,17 +29,14 @@ jobs:
           curl -LsSf https://astral.sh/uv/0.3.2/install.sh | sh
       - name: Install dependencies with uv
         run: |
-          uv pip install flake8
-          uv pip install -r requirements.txt
-          uv pip install -r test-requirements.txt
+          uv pip install -r pyproject.toml --extra dev
         env:
           UV_SYSTEM_PYTHON: 1
-      - name: Lint with flake8
+      - name: Lint and format with ruff
        run: |
-          # stop the build if there are Python syntax errors or undefined names
-          flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
-          # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
-          flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
+          ruff format --check
+          ruff check --output-format github
       - name: Test with pytest
         run: |
           pytest
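The same two ruff commands can be run locally before pushing; `--output-format github` only changes how violations are annotated in the Actions log, so a plain invocation is enough on a developer machine (a sketch, not part of the diff):

```
ruff format --check .
ruff check .
```
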
13 changes: 13 additions & 0 deletions .pre-commit-config.yaml
@@ -0,0 +1,13 @@
+repos:
+  - repo: https://github.com/pre-commit/pre-commit-hooks
+    rev: v2.3.0
+    hooks:
+      - id: check-yaml
+      - id: end-of-file-fixer
+      - id: trailing-whitespace
+  - repo: https://github.com/astral-sh/ruff-pre-commit
+    rev: v0.6.3
+    hooks:
+      - id: ruff
+        args: [ --fix ]
+      - id: ruff-format
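With this file in place, the hooks run automatically on each `git commit` once installed; they can also be run across the whole repository on demand. Standard pre-commit usage (not shown in the diff) would be:

```
pre-commit install          # register the git hook
pre-commit run --all-files  # run check-yaml, end-of-file-fixer, trailing-whitespace, ruff and ruff-format everywhere
```
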
2 changes: 1 addition & 1 deletion LICENSE
@@ -18,4 +18,4 @@ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-SOFTWARE.
\ No newline at end of file
+SOFTWARE.
12 changes: 7 additions & 5 deletions README.md
@@ -11,7 +11,7 @@ This project offers a template for you to easily build and run your own agents u
 ### [Try the app!](https://agent-service-toolkit.streamlit.app/)
 
 [![Streamlit App](https://static.streamlit.io/badges/streamlit_badge_black_white.svg)](https://agent-service-toolkit.streamlit.app/)
-
+
 <a href="https://agent-service-toolkit.streamlit.app/"><img src="media/app_screenshot.png" width="600"></a>
 
@@ -89,7 +89,7 @@ With that said, there are several other interesting projects in this space that
 # Optional, to enable simple header-based auth on the service
 AUTH_SECRET=any_string_you_choose
 # Optional, to enable LangSmith tracing
 LANGCHAIN_TRACING_V2=true
 LANGCHAIN_ENDPOINT=https://api.smith.langchain.com
@@ -153,15 +153,17 @@ The agent supports [LangGraph Studio](https://github.com/langchain-ai/langgraph-
 
 You can simply install LangGraph Studio, add your `.env` file to the root directory as described above, and then launch LangGraph studio pointed at the `agent/` directory. Customize `agent/langgraph.json` as needed.
 
-### Running Tests
+### Contributing
 
 Currently the tests need to be run using the local development without Docker setup. To run the tests for the agent service:
 
 1. Ensure you're in the project root directory and have activated your virtual environment.
 
-2. Install the test dependencies:
+2. Install the development dependencies and pre-commit hooks:
    ```
-   pip install -r test-requirements.txt
+   pip install uv
+   uv pip install -r pyproject.toml --extra dev
+   pre-commit install
    ```
 
 3. Run the tests using pytest:
4 changes: 1 addition & 3 deletions agent/__init__.py
@@ -1,5 +1,3 @@
 from agent.research_assistant import research_assistant
 
-__all__ = [
-    "research_assistant"
-]
+__all__ = ["research_assistant"]
2 changes: 1 addition & 1 deletion agent/langgraph.json
@@ -5,4 +5,4 @@
     "research_assistant": "./research_assistant.py:research_assistant"
   },
   "env": "../.env"
-}
\ No newline at end of file
+}
30 changes: 19 additions & 11 deletions agent/llama_guard.py
@@ -14,7 +14,9 @@ class SafetyAssessment(Enum):
 
 class LlamaGuardOutput(BaseModel):
     safety_assessment: SafetyAssessment = Field(description="The safety assessment of the content.")
-    unsafe_categories: List[str] = Field(description="If content is unsafe, the list of unsafe categories.", default=[])
+    unsafe_categories: List[str] = Field(
+        description="If content is unsafe, the list of unsafe categories.", default=[]
+    )
 
 
 unsafe_content_categories = {
@@ -31,7 +33,7 @@ class LlamaGuardOutput(BaseModel):
     "S11": "Self-Harm.",
     "S12": "Sexual Content.",
     "S13": "Elections.",
-    "S14": "Code Interpreter Abuse."
+    "S14": "Code Interpreter Abuse.",
 }
 
 categories_str = "\n".join([f"{k}: {v}" for k, v in unsafe_content_categories.items()])
@@ -70,9 +72,7 @@ def parse_llama_guard_output(output: str) -> LlamaGuardOutput:
         return LlamaGuardOutput(safety_assessment=SafetyAssessment.ERROR)
     try:
         categories = parsed_output[1].split(",")
-        readable_categories = [
-            unsafe_content_categories[c.strip()].strip(".") for c in categories
-        ]
+        readable_categories = [unsafe_content_categories[c.strip()].strip(".") for c in categories]
         return LlamaGuardOutput(
             safety_assessment=SafetyAssessment.UNSAFE,
             unsafe_categories=readable_categories,
@@ -83,9 +83,13 @@ def parse_llama_guard_output(output: str) -> LlamaGuardOutput:
 
 async def llama_guard(role: str, messages: List[AnyMessage]) -> LlamaGuardOutput:
     role_mapping = {"ai": "Agent", "human": "User"}
-    messages_str = [f"{role_mapping[m.type]}: {m.content}" for m in messages if m.type in ["ai", "human"]]
+    messages_str = [
+        f"{role_mapping[m.type]}: {m.content}" for m in messages if m.type in ["ai", "human"]
+    ]
     conversation_history = "\n\n".join(messages_str)
-    compiled_prompt = llama_guard_prompt.format(role=role, conversation_history=conversation_history)
+    compiled_prompt = llama_guard_prompt.format(
+        role=role, conversation_history=conversation_history
+    )
     result = await model.ainvoke([SystemMessage(content=compiled_prompt)])
     return parse_llama_guard_output(result.content)
 
@@ -94,9 +98,13 @@ async def llama_guard(role: str, messages: List[AnyMessage]) -> LlamaGuardOutput
     import asyncio
 
     async def main():
-        output = await llama_guard("Agent", [
-            HumanMessage(content="Tell me a fun fact?"),
-            AIMessage(content="Did you know that honey never spoils?"),
-        ])
+        output = await llama_guard(
+            "Agent",
+            [
+                HumanMessage(content="Tell me a fun fact?"),
+                AIMessage(content="Did you know that honey never spoils?"),
+            ],
+        )
         print(output)
 
     asyncio.run(main())
25 changes: 17 additions & 8 deletions agent/research_assistant.py
@@ -9,27 +9,28 @@
 from langgraph.managed import IsLastStep
 from langgraph.prebuilt import ToolNode
 
-from agent.tools import arxiv_search, calculator, web_search
+from agent.tools import calculator, web_search
 from agent.llama_guard import llama_guard, LlamaGuardOutput
 
 
 class AgentState(MessagesState):
     safety: LlamaGuardOutput
     is_last_step: IsLastStep
 
 
 # NOTE: models with streaming=True will send tokens as they are generated
 # if the /stream endpoint is called with stream_tokens=True (the default)
 models = {
     "gpt-4o-mini": ChatOpenAI(model="gpt-4o-mini", temperature=0.5, streaming=True),
-    "llama-3.1-70b": ChatGroq(model="llama-3.1-70b-versatile", temperature=0.5)
+    "llama-3.1-70b": ChatGroq(model="llama-3.1-70b-versatile", temperature=0.5),
 }
 
 tools = [web_search, calculator]
 current_date = datetime.now().strftime("%B %d, %Y")
 instructions = f"""
 You are a helpful research assistant with the ability to search the web for information.
 Today's date is {current_date}.
 NOTE: THE USER CAN'T SEE THE TOOL RESPONSE.
 A few things to remember:
@@ -39,6 +40,7 @@ class AgentState(MessagesState):
 so for the final response, use human readable format - e.g. "300 * 200", not "(300 \\times 200)".
 """
 
+
 def wrap_model(model: BaseChatModel):
     model = model.bind_tools(tools)
     preprocessor = RunnableLambda(
@@ -47,6 +49,7 @@ def wrap_model(model: BaseChatModel):
     )
     return preprocessor | model
 
+
 async def acall_model(state: AgentState, config: RunnableConfig):
     m = models[config["configurable"].get("model", "gpt-4o-mini")]
     model_runnable = wrap_model(m)
@@ -68,6 +71,7 @@ async def llama_guard_input(state: AgentState, config: RunnableConfig):
     safety_output = await llama_guard("User", state["messages"])
     return {"safety": safety_output}
 
+
 async def block_unsafe_content(state: AgentState, config: RunnableConfig):
     safety: LlamaGuardOutput = state["safety"]
     output_messages = []
@@ -77,10 +81,13 @@ async def block_unsafe_content(state: AgentState, config: RunnableConfig):
     if last_message.type == "ai":
         output_messages.append(RemoveMessage(id=last_message.id))
 
-    content_warning = f"This conversation was flagged for unsafe content: {', '.join(safety.unsafe_categories)}"
+    content_warning = (
+        f"This conversation was flagged for unsafe content: {', '.join(safety.unsafe_categories)}"
+    )
     output_messages.append(AIMessage(content=content_warning))
     return {"messages": output_messages}
 
+
 # Define the graph
 agent = StateGraph(AgentState)
 agent.add_node("model", acall_model)
@@ -105,18 +112,21 @@ async def block_unsafe_content(state: AgentState, config: RunnableConfig):
 # )
 
 # Always END after blocking unsafe content
-#agent.add_edge("block_unsafe_content", END)
+# agent.add_edge("block_unsafe_content", END)
 
 # Always run "model" after "tools"
 agent.add_edge("tools", "model")
 
+
 # After "model", if there are tool calls, run "tools". Otherwise END.
 def pending_tool_calls(state: AgentState):
     last_message = state["messages"][-1]
     if last_message.tool_calls:
         return "tools"
     else:
         return END
 
+
 agent.add_conditional_edges("model", pending_tool_calls, {"tools": "tools", END: END})
 
 research_assistant = agent.compile(
@@ -130,7 +140,7 @@ def pending_tool_calls(state: AgentState):
     from dotenv import load_dotenv
 
     load_dotenv()
 
     async def main():
         inputs = {"messages": [("user", "Find me a recipe for chocolate chip cookies")]}
         result = await research_assistant.ainvoke(
@@ -145,8 +155,7 @@ async def main():
         # export CFLAGS="-I $(brew --prefix graphviz)/include"
         # export LDFLAGS="-L $(brew --prefix graphviz)/lib"
         # pip install pygraphviz
-        #
+        #
         # researcH_assistant.get_graph().draw_png("agent_diagram.png")
 
 
     asyncio.run(main())
7 changes: 3 additions & 4 deletions agent/tools.py
@@ -2,16 +2,14 @@
 import numexpr
 import re
 from langchain_core.tools import tool, BaseTool
-from langchain_community.tools import DuckDuckGoSearchResults, ArxivQueryRun
+from langchain_community.tools import DuckDuckGoSearchResults
 
 web_search = DuckDuckGoSearchResults(name="WebSearch")
 
-# Kinda busted since it doesn't return links
-arxiv_search = ArxivQueryRun(name="ArxivSearch")
 
 def calculator_func(expression: str) -> str:
     """Calculates a math expression using numexpr.
     Useful for when you need to answer questions about math using numexpr.
     This tool is only for math questions and nothing else. Only input
     math expressions.
@@ -39,5 +37,6 @@ def calculator_func(expression: str) -> str:
         " Please try again with a valid numerical expression"
     )
 
+
 calculator: BaseTool = tool(calculator_func)
 calculator.name = "Calculator"
[Diffs for the remaining 13 changed files are not loaded in this view.]