Skip to content

Commit

Permalink
feat: add support for client-tools (Python SDK)
Browse files Browse the repository at this point in the history
  • Loading branch information
louisjoecodes committed Dec 17, 2024
1 parent 97d77ed commit 57f85b6
Show file tree
Hide file tree
Showing 2 changed files with 216 additions and 10 deletions.
147 changes: 137 additions & 10 deletions src/elevenlabs/conversational_ai/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
import base64
import json
import threading
from typing import Callable, Optional
from typing import Callable, Optional, Awaitable, Union, Any
import asyncio
from concurrent.futures import ThreadPoolExecutor

from websockets.sync.client import connect

Expand Down Expand Up @@ -52,22 +54,133 @@ def interrupt(self):
"""
pass


class ClientTools:
    """Registry and executor for client-side tools that the agent can call.

    Handlers are registered by name together with a flag telling whether they
    are coroutine functions. Tool calls are dispatched onto a private asyncio
    event loop running in its own daemon thread: async handlers are awaited on
    that loop, sync handlers are handed to a thread pool so they do not block
    it. The thread calling ``execute_tool`` never waits for the tool.
    """

    def __init__(self):
        # Maps tool name -> (handler, is_async).
        self.tools: dict[str, tuple[Union[Callable[[dict], Any], Callable[[dict], Awaitable[Any]]], bool]] = {}
        # Guards every access to ``self.tools``.
        self.lock = threading.Lock()
        self._loop = None
        self._thread = None
        # Set while the event loop thread is up; cleared when it exits.
        self._running = threading.Event()
        self.thread_pool = ThreadPoolExecutor()

    def start(self):
        """Start the event loop in a separate thread for handling async operations.

        Does nothing if the loop is already running. Returns only once the
        loop thread has created its loop and flagged itself as running.
        """
        if self._running.is_set():
            return

        def run_event_loop():
            self._loop = asyncio.new_event_loop()
            asyncio.set_event_loop(self._loop)
            self._running.set()
            try:
                self._loop.run_forever()
            finally:
                self._running.clear()
                self._loop.close()
                self._loop = None

        self._thread = threading.Thread(target=run_event_loop, daemon=True, name="ClientTools-EventLoop")
        self._thread.start()
        # Wait for loop to be ready
        self._running.wait()

    def stop(self):
        """Gracefully stop the event loop and clean up resources.

        The instance may be started again afterwards.
        """
        # Read the attribute once: the loop thread resets it to None on exit.
        loop = self._loop
        if loop is not None and self._running.is_set():
            loop.call_soon_threadsafe(loop.stop)
            self._thread.join()
        self.thread_pool.shutdown(wait=False)
        # A shut-down executor rejects new work forever, which would make every
        # sync tool fail after a stop()/start() cycle. Executors create their
        # worker threads lazily, so an unused replacement holds no threads.
        self.thread_pool = ThreadPoolExecutor()

    def register(
        self,
        tool_name: str,
        handler: Union[Callable[[dict], Any], Callable[[dict], Awaitable[Any]]],
        is_async: bool = False,
    ) -> None:
        """Register a new tool that can be called by the AI agent.

        Args:
            tool_name: Unique identifier for the tool
            handler: Function that implements the tool's logic
            is_async: Whether the handler is an async function

        Raises:
            ValueError: If ``handler`` is not callable or ``tool_name`` is
                already registered.
        """
        with self.lock:
            if not callable(handler):
                raise ValueError("Handler must be callable")
            if tool_name in self.tools:
                raise ValueError(f"Tool '{tool_name}' is already registered")
            self.tools[tool_name] = (handler, is_async)

    async def handle(self, tool_name: str, parameters: dict) -> Any:
        """Execute a registered tool with the given parameters.

        Async handlers are awaited directly; sync handlers run in the thread
        pool and are awaited from the running loop.

        Returns:
            Whatever the tool handler returns.

        Raises:
            ValueError: If no tool named ``tool_name`` is registered.
        """
        with self.lock:
            if tool_name not in self.tools:
                raise ValueError(f"Tool '{tool_name}' is not registered")
            handler, is_async = self.tools[tool_name]

        if is_async:
            return await handler(parameters)
        return await asyncio.get_running_loop().run_in_executor(self.thread_pool, handler, parameters)

    def execute_tool(self, tool_name: str, parameters: dict, callback: Callable[[dict], None]):
        """Execute a tool and send its result via the provided callback.

        This method is non-blocking and handles both sync and async tools. The
        callback is invoked from the event loop thread with a
        ``client_tool_result`` dict. Any exception raised by the tool is
        reported through that dict (``is_error`` True, ``result`` the message)
        instead of propagating.

        Raises:
            RuntimeError: If the event loop has not been started.
        """
        loop = self._loop
        if loop is None or not self._running.is_set():
            raise RuntimeError("ClientTools event loop is not running")

        async def _execute_and_callback():
            try:
                result = await self.handle(tool_name, parameters)
                response = {
                    "type": "client_tool_result",
                    "tool_call_id": parameters.get("tool_call_id"),
                    # Only a missing result (None) gets the generic message;
                    # falsy values such as 0 or False are real answers.
                    "result": (
                        result
                        if result is not None
                        else f"Client tool: {tool_name} called successfully."
                    ),
                    "is_error": False,
                }
            except Exception as e:
                response = {
                    "type": "client_tool_result",
                    "tool_call_id": parameters.get("tool_call_id"),
                    "result": str(e),
                    "is_error": True,
                }
            callback(response)

        asyncio.run_coroutine_threadsafe(_execute_and_callback(), loop)


class ConversationConfig:
    """Container for the optional settings of a Conversation.

    Attributes are ``extra_body`` and ``conversation_config_override``; each
    is always a dict after construction.
    """

    def __init__(
        self,
        extra_body: Optional[dict] = None,
        conversation_config_override: Optional[dict] = None,
    ):
        # A falsy argument (None or an empty dict) becomes a fresh empty dict.
        self.extra_body = extra_body if extra_body else {}
        self.conversation_config_override = (
            conversation_config_override if conversation_config_override else {}
        )



class Conversation:
client: BaseElevenLabs
agent_id: str
requires_auth: bool
config: ConversationConfig
audio_interface: AudioInterface
client_tools: Optional[ClientTools]
callback_agent_response: Optional[Callable[[str], None]]
callback_agent_response_correction: Optional[Callable[[str, str], None]]
callback_user_transcript: Optional[Callable[[str], None]]
Expand All @@ -86,7 +199,7 @@ def __init__(
requires_auth: bool,
audio_interface: AudioInterface,
config: Optional[ConversationConfig] = None,

client_tools: Optional[ClientTools] = None,
callback_agent_response: Optional[Callable[[str], None]] = None,
callback_agent_response_correction: Optional[Callable[[str, str], None]] = None,
callback_user_transcript: Optional[Callable[[str], None]] = None,
Expand All @@ -101,6 +214,7 @@ def __init__(
agent_id: The ID of the agent to converse with.
requires_auth: Whether the agent requires authentication.
audio_interface: The audio interface to use for input and output.
client_tools: The client tools to use for the conversation.
callback_agent_response: Callback for agent responses.
callback_agent_response_correction: Callback for agent response corrections.
First argument is the original response (previously given to
Expand All @@ -112,14 +226,16 @@ def __init__(
self.client = client
self.agent_id = agent_id
self.requires_auth = requires_auth

self.audio_interface = audio_interface
self.callback_agent_response = callback_agent_response
self.config = config or ConversationConfig()
self.client_tools = client_tools or ClientTools()
self.callback_agent_response_correction = callback_agent_response_correction
self.callback_user_transcript = callback_user_transcript
self.callback_latency_measurement = callback_latency_measurement

self.client_tools.start()

self._thread = None
self._should_stop = threading.Event()
self._conversation_id = None
Expand All @@ -135,8 +251,9 @@ def start_session(self):
self._thread.start()

def end_session(self):
    """Ends the conversation session and cleans up resources.

    Stops the audio interface, then the client tools, and finally sets the
    internal stop flag.
    """
    self.audio_interface.stop()
    self.client_tools.stop()
    self._should_stop.set()

def wait_for_session_end(self) -> Optional[str]:
Expand All @@ -155,10 +272,10 @@ def _run(self, ws_url: str):
with connect(ws_url) as ws:
ws.send(
json.dumps(
{
"type": "conversation_initiation_client_data",
"custom_llm_extra_body": self.config.extra_body,
"conversation_config_override": self.config.conversation_config_override,
{
"type": "conversation_initiation_client_data",
"custom_llm_extra_body": self.config.extra_body,
"conversation_config_override": self.config.conversation_config_override,
}
)
)
Expand Down Expand Up @@ -210,7 +327,7 @@ def _handle_message(self, message, ws):
self.callback_user_transcript(event["user_transcript"].strip())
elif message["type"] == "interruption":
event = message["interruption_event"]
self.last_interrupt_id = int(event["event_id"])
self._last_interrupt_id = int(event["event_id"])
self.audio_interface.interrupt()
elif message["type"] == "ping":
event = message["ping_event"]
Expand All @@ -224,6 +341,16 @@ def _handle_message(self, message, ws):
)
if self.callback_latency_measurement and event["ping_ms"]:
self.callback_latency_measurement(int(event["ping_ms"]))
elif message["type"] == "client_tool_call":
tool_call = message.get("client_tool_call", {})
tool_name = tool_call.get("tool_name")
parameters = {"tool_call_id": tool_call["tool_call_id"], **tool_call.get("parameters", {})}

def send_response(response):
if not self._should_stop.is_set():
ws.send(json.dumps(response))

self.client_tools.execute_tool(tool_name, parameters, send_response)
else:
pass # Ignore all other message types.

Expand Down
79 changes: 79 additions & 0 deletions tests/e2e_test_convai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import os
import time
import asyncio

import pytest
from elevenlabs import ElevenLabs
from elevenlabs.conversational_ai.conversation import Conversation, ClientTools
from elevenlabs.conversational_ai.default_audio_interface import DefaultAudioInterface


@pytest.mark.skipif(os.getenv("CI") == "true", reason="Skip live conversation test in CI environment")
def test_live_conversation():
    """Test a live conversation with actual audio I/O.

    Requires the ELEVENLABS_API_KEY and AGENT_ID environment variables.
    Registers one sync and one async client tool, runs a session for 100
    seconds and prints the conversation ID at the end.

    Raises:
        ValueError: If either required environment variable is missing.
    """

    api_key = os.getenv("ELEVENLABS_API_KEY")
    if not api_key:
        raise ValueError("ELEVENLABS_API_KEY environment variable missing.")

    agent_id = os.getenv("AGENT_ID")
    if not agent_id:
        raise ValueError("AGENT_ID environment variable missing.")

    client = ElevenLabs(api_key=api_key)

    # Create conversation handlers
    def on_agent_response(text: str):
        print(f"Agent: {text}")

    def on_user_transcript(text: str):
        print(f"You: {text}")

    def on_latency(ms: int):
        print(f"Latency: {ms}ms")

    # Initialize client tools
    client_tools = ClientTools()

    def test(parameters):
        print("Sync tool called with parameters:", parameters)
        return "Tool called successfully"

    async def test_async(parameters):
        # Simulate some async work
        await asyncio.sleep(10)
        print("Async tool called with parameters:", parameters)
        return "Tool called successfully"

    client_tools.register("test", test)
    client_tools.register("test_async", test_async, is_async=True)

    # Initialize conversation
    conversation = Conversation(
        client=client,
        agent_id=agent_id,
        requires_auth=False,
        audio_interface=DefaultAudioInterface(),
        callback_agent_response=on_agent_response,
        callback_user_transcript=on_user_transcript,
        callback_latency_measurement=on_latency,
        client_tools=client_tools,
    )

    # Start the conversation
    conversation.start_session()

    try:
        # Let it run for 100 seconds
        time.sleep(100)
    finally:
        # End the conversation even if the wait is interrupted (e.g. Ctrl+C),
        # so the audio interface and client tools are always released.
        conversation.end_session()
        conversation.wait_for_session_end()

    # Get the conversation ID for reference
    conversation_id = conversation._conversation_id
    print(f"Conversation ID: {conversation_id}")


if __name__ == "__main__":
    # Executing this file as a script calls the live test function directly.
    test_live_conversation()

0 comments on commit 57f85b6

Please sign in to comment.