Skip to content

Commit

Permalink
Make get_tools sync, add separate async initialize.
Browse files Browse the repository at this point in the history
  • Loading branch information
rectalogic committed Nov 28, 2024
1 parent 94fb24e commit be403f0
Show file tree
Hide file tree
Showing 5 changed files with 82 additions and 35 deletions.
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ dev = [
"ruff~=0.8.0",
"mypy~=1.13.0",
"typing-extensions~=4.12.2",
"langchain-groq~=0.2.1",
]

[project.urls]
Expand Down Expand Up @@ -81,4 +82,4 @@ warn_unused_ignores = true
strict_equality = true
no_implicit_optional = true
show_error_codes = true
files = "src/**/*.py"
files = ["src/**/*.py", "tests/demo.py"]
28 changes: 17 additions & 11 deletions src/langchain_mcp/toolkit.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import pydantic_core
import typing_extensions as t
from langchain_core.tools.base import BaseTool, BaseToolkit, ToolException
from mcp import ClientSession
from mcp import ClientSession, ListToolsResult


class MCPToolkit(BaseToolkit):
Expand All @@ -20,25 +20,30 @@ class MCPToolkit(BaseToolkit):
session: ClientSession
"""The MCP session used to obtain the tools"""

_initialized: bool = False
_tools: ListToolsResult | None = None

model_config = pydantic.ConfigDict(arbitrary_types_allowed=True)

@t.override
async def get_tools(self) -> list[BaseTool]: # type: ignore[override]
if not self._initialized:
async def initialize(self) -> None:
"""Initialize the session and retrieve tools list"""
if self._tools is None:
await self.session.initialize()
self._initialized = True
self._tools = await self.session.list_tools()

@t.override
def get_tools(self) -> list[BaseTool]:
if self._tools is None:
raise RuntimeError("Must initialize the toolkit first")

return [
MCPTool(
toolkit=self,
session=self.session,
name=tool.name,
description=tool.description or "",
args_schema=create_schema_model(tool.inputSchema),
)
# list_tools returns a PaginatedResult, but I don't see a way to pass the cursor to retrieve more tools
for tool in (await self.session.list_tools()).tools
for tool in self._tools.tools
]


Expand Down Expand Up @@ -67,19 +72,20 @@ class MCPTool(BaseTool):
MCP server tool
"""

toolkit: MCPToolkit
session: ClientSession
handle_tool_error: bool | str | Callable[[ToolException], str] | None = True

@t.override
def _run(self, *args: t.Any, **kwargs: t.Any) -> t.Any:
warnings.warn(
"Invoke this tool asynchronousely using `ainvoke`. This method exists only to satisfy tests.", stacklevel=1
"Invoke this tool asynchronousely using `ainvoke`. This method exists only to satisfy standard tests.",
stacklevel=1,
)
return asyncio.run(self._arun(*args, **kwargs))

@t.override
async def _arun(self, *args: t.Any, **kwargs: t.Any) -> t.Any:
result = await self.toolkit.session.call_tool(self.name, arguments=kwargs)
result = await self.session.call_tool(self.name, arguments=kwargs)
content = pydantic_core.to_json(result.content).decode()
if result.isError:
raise ToolException(content)
Expand Down
3 changes: 2 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def mcptoolkit(request):

@pytest.fixture(scope="class")
async def mcptool(request, mcptoolkit):
tool = (await mcptoolkit.get_tools())[0]
await mcptoolkit.initialize()
tool = mcptoolkit.get_tools()[0]
request.cls.tool = tool
yield tool
42 changes: 20 additions & 22 deletions tests/demo.py
Original file line number Diff line number Diff line change
@@ -1,48 +1,46 @@
# Copyright (C) 2024 Andrew Wason
# SPDX-License-Identifier: MIT

# /// script
# requires-python = ">=3.10"
# dependencies = [
# "langchain-mcp",
# "langchain-groq",
# ]
# ///


import asyncio
import pathlib
import sys
import typing as t

from langchain_core.messages import HumanMessage
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.tools import BaseTool
from langchain_groq import ChatGroq
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client

from langchain_mcp import MCPToolkit


async def run(tools: list[BaseTool], prompt: str) -> str:
model = ChatGroq(model="llama-3.1-8b-instant", stop_sequences=None) # requires GROQ_API_KEY
tools_map = {tool.name: tool for tool in tools}
tools_model = model.bind_tools(tools)
messages: list[BaseMessage] = [HumanMessage(prompt)]
ai_message = t.cast(AIMessage, await tools_model.ainvoke(messages))
messages.append(ai_message)
for tool_call in ai_message.tool_calls:
selected_tool = tools_map[tool_call["name"].lower()]
tool_msg = await selected_tool.ainvoke(tool_call)
messages.append(tool_msg)
return await (tools_model | StrOutputParser()).ainvoke(messages)


async def main(prompt: str) -> None:
model = ChatGroq(model="llama-3.1-8b-instant") # requires GROQ_API_KEY
server_params = StdioServerParameters(
command="npx",
args=["-y", "@modelcontextprotocol/server-filesystem", str(pathlib.Path(__file__).parent.parent)],
)
async with stdio_client(server_params) as (read, write):
async with ClientSession(read, write) as session:
toolkit = MCPToolkit(session=session)
tools = await toolkit.get_tools()
tools_map = {tool.name: tool for tool in tools}
tools_model = model.bind_tools(tools)
messages = [HumanMessage(prompt)]
messages.append(await tools_model.ainvoke(messages))
for tool_call in messages[-1].tool_calls:
selected_tool = tools_map[tool_call["name"].lower()]
tool_msg = await selected_tool.ainvoke(tool_call)
messages.append(tool_msg)
result = await (tools_model | StrOutputParser()).ainvoke(messages)
print(result)
await toolkit.initialize()
response = await run(toolkit.get_tools(), prompt)
print(response)


if __name__ == "__main__":
Expand Down
41 changes: 41 additions & 0 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit be403f0

Please sign in to comment.