Skip to content

Commit

Permalink
ai functions improvements follow up (livekit#393)
Browse files Browse the repository at this point in the history
* Update assistant.py

* cleanup voice assistant & make it faster (needs improvements on reliability)
  • Loading branch information
theomonnom authored Jul 3, 2024
1 parent 39a5959 commit 70bc061
Show file tree
Hide file tree
Showing 17 changed files with 1,186 additions and 904 deletions.
16 changes: 6 additions & 10 deletions examples/voice-assistant/minimal_assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,12 @@


async def entrypoint(ctx: JobContext):
initial_ctx = ChatContext(
messages=[
ChatMessage(
role=ChatRole.SYSTEM,
text=(
"You are a voice assistant created by LiveKit. Your interface with users will be voice. "
"You should use short and concise responses, and avoiding usage of unpronouncable punctuation."
),
)
]
initial_ctx = ChatContext().append(
role="system",
text=(
"You are a voice assistant created by LiveKit. Your interface with users will be voice. "
"You should use short and concise responses, and avoiding usage of unpronouncable punctuation."
),
)

assistant = VoiceAssistant(
Expand Down
24 changes: 7 additions & 17 deletions livekit-agents/livekit/agents/llm/_oai_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,12 @@
]


def create_ai_function_task(
def create_ai_function_info(
fnc_ctx: function_context.FunctionContext,
tool_call_id: str,
fnc_name: str,
raw_arguments: str, # JSON string
) -> tuple[asyncio.Task[Any], function_context.CalledFunction]:
) -> function_context.FunctionCallInfo:
if fnc_name not in fnc_ctx.ai_functions:
raise ValueError(f"AI function {fnc_name} not found")

Expand Down Expand Up @@ -80,21 +80,11 @@ def create_ai_function_task(

sanitized_arguments[arg_info.name] = sanitized_value

func = functools.partial(fnc_info.callable, **sanitized_arguments)
if asyncio.iscoroutinefunction(fnc_info.callable):
task = asyncio.create_task(func())
else:
task = asyncio.create_task(asyncio.to_thread(func))

return (
task,
function_context.CalledFunction(
tool_call_id=tool_call_id,
raw_arguments=raw_arguments,
function_info=fnc_info,
arguments=sanitized_arguments,
task=task,
),
return function_context.FunctionCallInfo(
tool_call_id=tool_call_id,
raw_arguments=raw_arguments,
function_info=fnc_info,
arguments=sanitized_arguments,
)


Expand Down
15 changes: 8 additions & 7 deletions livekit-agents/livekit/agents/llm/chat_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class ChatMessage:
role: ChatRole
name: str | None = None
content: str | list[str | ChatImage] | None = None
tool_calls: list[function_context.CalledFunction] | None = None
tool_calls: list[function_context.FunctionCallInfo] | None = None
tool_call_id: str | None = None

@staticmethod
Expand All @@ -50,20 +50,21 @@ def create_tool_from_called_function(
if not called_function.task.done():
raise ValueError("cannot create a tool result from a running ai function")

content = called_function.task.result()
if called_function.task.exception() is not None:
content = f"Error: {called_function.task.exception}"
try:
content = called_function.task.result()
except BaseException as e:
content = f"Error: {e}"

return ChatMessage(
role="tool",
name=called_function.function_info.name,
name=called_function.call_info.function_info.name,
content=content,
tool_call_id=called_function.tool_call_id,
tool_call_id=called_function.call_info.tool_call_id,
)

@staticmethod
def create_tool_calls(
called_functions: list[function_context.CalledFunction],
called_functions: list[function_context.FunctionCallInfo],
) -> "ChatMessage":
return ChatMessage(
role="assistant",
Expand Down
29 changes: 28 additions & 1 deletion livekit-agents/livekit/agents/llm/function_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from __future__ import annotations

import asyncio
import functools
import enum
import inspect
import typing
Expand Down Expand Up @@ -57,12 +58,38 @@ class FunctionInfo:


@dataclass
class CalledFunction:
class FunctionCallInfo:
tool_call_id: str
function_info: FunctionInfo
raw_arguments: str
arguments: dict[str, Any]

def execute(self) -> CalledFunction:
    """Start the AI function described by this call and return a handle to it.

    The callable from ``self.function_info`` is bound to ``self.arguments`` as
    keyword arguments and scheduled as an asyncio task: coroutine functions are
    scheduled directly, any other callable goes through ``asyncio.to_thread``.

    Returns:
        A ``CalledFunction`` holding this call info and the created task. Its
        ``result`` or ``exception`` field is filled in by a done callback once
        the task finishes.

    Raises:
        RuntimeError: from ``asyncio.create_task`` when no event loop is running.
    """
    function_info = self.function_info
    func = functools.partial(function_info.callable, **self.arguments)
    if asyncio.iscoroutinefunction(function_info.callable):
        task = asyncio.create_task(func())
    else:
        # non-coroutine callables are executed via asyncio.to_thread
        task = asyncio.create_task(asyncio.to_thread(func))

    called_fnc = CalledFunction(call_info=self, task=task)

    def _on_done(fut):
        # BaseException (not Exception) so that a cancelled task, whose
        # result() raises CancelledError, is recorded as well
        try:
            called_fnc.result = fut.result()
        except BaseException as e:
            called_fnc.exception = e

    task.add_done_callback(_on_done)
    return called_fnc


@dataclass
class CalledFunction:
    """A started AI function call: the call description plus the task running it."""

    # description of the call this object was created from
    call_info: FunctionCallInfo
    # asyncio task running the function
    task: asyncio.Task[Any]
    # value returned by the function; stays None until it is set
    # (FunctionCallInfo.execute sets it from a task done callback)
    result: Any | None = None
    # exception raised while running the function; None when not set
    exception: BaseException | None = None


def ai_callable(
Expand Down
49 changes: 37 additions & 12 deletions livekit-agents/livekit/agents/llm/llm.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import asyncio
import abc
from dataclasses import dataclass, field
from typing import AsyncIterator
Expand All @@ -12,7 +13,7 @@
class ChoiceDelta:
role: ChatRole
content: str | None = None
tool_calls: list[function_context.CalledFunction] | None = None
tool_calls: list[function_context.FunctionCallInfo] | None = None


@dataclass
Expand All @@ -28,7 +29,7 @@ class ChatChunk:

class LLM(abc.ABC):
@abc.abstractmethod
async def chat(
def chat(
self,
*,
chat_ctx: ChatContext,
Expand All @@ -39,24 +40,48 @@ async def chat(


class LLMStream(abc.ABC):
def __init__(self) -> None:
self._called_functions: list[function_context.CalledFunction] = []
def __init__(
self, *, chat_ctx: ChatContext, fnc_ctx: function_context.FunctionContext | None
) -> None:
self._function_calls_info: list[function_context.FunctionCallInfo] = []
self._tasks = set[asyncio.Task]()
self._chat_ctx = chat_ctx
self._fnc_ctx = fnc_ctx

@property
def called_functions(self) -> list[function_context.CalledFunction]:
def function_calls(self) -> list[function_context.FunctionCallInfo]:
"""List of called functions from this stream."""
return self._called_functions
return self._function_calls_info

@abc.abstractmethod
async def gather_function_results(
@property
def chat_ctx(self) -> ChatContext:
"""The initial chat context of this stream."""
return self._chat_ctx

@property
def fnc_ctx(self) -> function_context.FunctionContext | None:
"""The function context of this stream."""
return self._fnc_ctx

def execute_functions(
self,
) -> list[function_context.CalledFunction]: ...
) -> list[function_context.CalledFunction]:
"""Run all functions in this stream."""
called_functions = []
for fnc_info in self._function_calls_info:
called_fnc = fnc_info.execute()
called_functions.append(called_fnc)

return called_functions

async def aclose(self) -> None:
    """Cancel every task tracked in ``self._tasks`` and wait for them to finish.

    ``gather`` is called with ``return_exceptions=True``, so errors raised by
    the tasks (including their cancellation) are collected instead of being
    propagated to the caller.
    """
    for task in self._tasks:
        task.cancel()

    await asyncio.gather(*self._tasks, return_exceptions=True)

def __aiter__(self) -> AsyncIterator[ChatChunk]:
return self

@abc.abstractmethod
async def __anext__(self) -> ChatChunk: ...

@abc.abstractmethod
async def aclose(self) -> None: ...
12 changes: 6 additions & 6 deletions livekit-agents/livekit/agents/utils/event_emitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,15 @@

class EventEmitter(Generic[T]):
def __init__(self) -> None:
self._events: Dict[T, Set[Callable[[Any], None]]] = dict()
self._events: Dict[T, Set[Callable[..., Any]]] = dict()

def emit(self, event: T, *args: Any, **kwargs: Any) -> None:
    """Call every callback registered for ``event`` with the given arguments.

    Does nothing when no callback is registered for ``event``.

    Args:
        event: key of the event to fire.
        *args: positional arguments forwarded to each callback.
        **kwargs: keyword arguments forwarded to each callback.
    """
    if event in self._events:
        # iterate over a copy so a callback may add or remove handlers
        # for this event while it is being called
        callables = self._events[event].copy()
        for callback in callables:
            callback(*args, **kwargs)

def once(self, event: T, callback: Optional[Callable[[Any], None]] = None):
def once(self, event: T, callback: Optional[Callable[..., Any]] = None):
if callback is not None:

def once_callback(*args: Any, **kwargs: Any):
Expand All @@ -23,26 +23,26 @@ def once_callback(*args: Any, **kwargs: Any):
return self.on(event, once_callback)
else:

def decorator(callback: Callable[[Any], None]):
def decorator(callback: Callable[..., Any]):
self.once(event, callback)
return callback

return decorator

def on(self, event: T, callback: Optional[Callable[[Any], None]] = None):
def on(self, event: T, callback: Optional[Callable[..., Any]] = None):
if callback is not None:
if event not in self._events:
self._events[event] = set()
self._events[event].add(callback)
return callback
else:

def decorator(callback: Callable[[Any], None]):
def decorator(callback: Callable[..., Any]):
self.on(event, callback)
return callback

return decorator

def off(self, event: T, callback: Callable[[Any], None]) -> None:
def off(self, event: T, callback: Callable[..., Any]) -> None:
if event in self._events:
self._events[event].remove(callback)
23 changes: 16 additions & 7 deletions livekit-agents/livekit/agents/vad.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,26 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from enum import Enum
from enum import Enum, unique
from typing import AsyncIterator, List

from livekit import rtc


@unique
class VADEventType(Enum):
START_OF_SPEECH = 1
INFERENCE_DONE = 2
END_OF_SPEECH = 3
START_OF_SPEECH = "start_of_speech"
INFERENCE_DONE = "inference_done"
END_OF_SPEECH = "end_of_speech"


@dataclass
class VADEvent:
type: VADEventType
"""type of the event"""
samples_index: int
"""index of the samples of the event (when the event was fired)"""
duration: float = 0.0
"""duration of the speech in seconds (only for END_SPEAKING event)"""
"""index of the samples when the event was fired"""
duration: float
"""duration of the speech in seconds"""
frames: List[rtc.AudioFrame] = field(default_factory=list)
"""list of audio frames of the speech"""
probability: float = 0.0
Expand All @@ -31,6 +32,14 @@ class VADEvent:


class VAD(ABC):
def __init__(self, *, update_interval: float) -> None:
self._update_interval = update_interval

@property
def update_interval(self) -> float:
"""interval in seconds to update the VAD model"""
return self._update_interval

@abstractmethod
def stream(
self,
Expand Down
4 changes: 2 additions & 2 deletions livekit-agents/livekit/agents/voice_assistant/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .assistant import AssistantCallContext, VoiceAssistant
from .assistant import VoiceAssistant

__all__ = ["VoiceAssistant", "AssistantCallContext"]
__all__ = ["VoiceAssistant"]
Loading

0 comments on commit 70bc061

Please sign in to comment.