Skip to content

Commit

Permalink
ipc: use our own duplex instead of mp.Queue (#634)
Browse files Browse the repository at this point in the history
  • Loading branch information
theomonnom authored Aug 16, 2024
1 parent 7b611cd commit 482ed3b
Show file tree
Hide file tree
Showing 6 changed files with 89 additions and 44 deletions.
5 changes: 5 additions & 0 deletions .changeset/silent-shoes-drop.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"livekit-agents": patch
---

ipc: use our own duplex instead of mp.Queue
43 changes: 34 additions & 9 deletions livekit-agents/livekit/agents/ipc/proc_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,12 @@
import contextlib
import copy
import logging
import multiprocessing as mp
import pickle
import queue
import socket
import threading
from dataclasses import dataclass
from typing import Optional

from livekit import rtc

Expand All @@ -18,9 +21,29 @@


class LogQueueHandler(logging.Handler):
def __init__(self, queue: mp.Queue) -> None:
_sentinal = None

def __init__(self, duplex: utils.aio.duplex_unix._Duplex) -> None:
super().__init__()
self._q = queue
self._duplex = duplex
self._send_q = queue.SimpleQueue[Optional[logging.LogRecord]]()
self._send_thread = threading.Thread(
target=self._forward_logs, name="ipc_log_forwarder"
)
self._send_thread.start()

def _forward_logs(self):
while True:
record = self._send_q.get()
if record is None:
break

try:
self._duplex.send_bytes(pickle.dumps(record))
except duplex_unix.DuplexClosed:
break

self._duplex.close()

def emit(self, record: logging.LogRecord) -> None:
try:
Expand All @@ -31,10 +54,14 @@ def emit(self, record: logging.LogRecord) -> None:
record.args = None
record.exc_info = None
record.exc_text = None
self._q.put_nowait(record)
self._send_q.put_nowait(record)
except Exception:
self.handleError(record)

def close(self) -> None:
super().close()
self._send_q.put_nowait(self._sentinal)


@dataclass
class _ShutdownInfo:
Expand Down Expand Up @@ -213,9 +240,8 @@ def main(args: proto.ProcStartArgs) -> None:
root_logger = logging.getLogger()
root_logger.setLevel(logging.NOTSET)

log_q = args.log_q
log_q.cancel_join_thread()
log_handler = LogQueueHandler(log_q)
log_cch = utils.aio.duplex_unix._Duplex.open(args.log_cch)
log_handler = LogQueueHandler(log_cch)
root_logger.addHandler(log_handler)

loop = asyncio.new_event_loop()
Expand Down Expand Up @@ -250,7 +276,6 @@ def main(args: proto.ProcStartArgs) -> None:
except duplex_unix.DuplexClosed:
pass
finally:
log_handler.close()
log_q.close()
cch.close()
log_handler.close()
loop.run_until_complete(loop.shutdown_default_executor())
3 changes: 1 addition & 2 deletions livekit-agents/livekit/agents/ipc/proto.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

import io
import multiprocessing as mp
import socket
from dataclasses import dataclass, field
from typing import Any, Callable, ClassVar
Expand All @@ -21,7 +20,7 @@
class ProcStartArgs:
initialize_process_fnc: Callable[[JobProcess], Any]
job_entrypoint_fnc: Callable[[JobContext], Any]
log_q: mp.Queue
log_cch: socket.socket
mp_cch: socket.socket
asyncio_debug: bool
user_arguments: Any | None = None
Expand Down
37 changes: 19 additions & 18 deletions livekit-agents/livekit/agents/ipc/supervised_proc.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import asyncio
import contextlib
import logging
import multiprocessing as mp
import pickle
import socket
import sys
import threading
Expand All @@ -19,25 +19,24 @@


class LogQueueListener:
_sentinel = None

def __init__(
self, queue: mp.Queue, prepare_fnc: Callable[[logging.LogRecord], None]
self,
duplex: utils.aio.duplex_unix._Duplex,
prepare_fnc: Callable[[logging.LogRecord], None],
):
self._thread: threading.Thread | None = None
self._q = queue
self._duplex = duplex
self._prepare_fnc = prepare_fnc

def start(self) -> None:
self._thread = t = threading.Thread(
target=self._monitor, daemon=True, name="log_listener"
)
t.start()
self._thread = threading.Thread(target=self._monitor, name="ipc_log_listener")
self._thread.start()

def stop(self) -> None:
if self._thread is None:
return
self._q.put_nowait(self._sentinel)

self._duplex.close()
self._thread.join()
self._thread = None

Expand All @@ -52,10 +51,12 @@ def handle(self, record: logging.LogRecord) -> None:

def _monitor(self):
while True:
record = self._q.get()
if record is self._sentinel:
try:
data = self._duplex.recv_bytes()
except utils.aio.duplex_unix.DuplexClosed:
break

record = pickle.loads(data)
self.handle(record)


Expand Down Expand Up @@ -145,19 +146,19 @@ def _add_proc_ctx_log(record: logging.LogRecord) -> None:
setattr(record, key, value)

async with self._lock:
log_q = self._opts.mp_ctx.Queue()
log_q.cancel_join_thread()

mp_pch, mp_cch = socket.socketpair()
mp_log_pch, mp_log_cch = socket.socketpair()

self._pch = await duplex_unix._AsyncDuplex.open(mp_pch)
log_listener = LogQueueListener(log_q, _add_proc_ctx_log)

log_pch = duplex_unix._Duplex.open(mp_log_pch)
log_listener = LogQueueListener(log_pch, _add_proc_ctx_log)
log_listener.start()

self._proc_args = proto.ProcStartArgs(
initialize_process_fnc=self._opts.initialize_process_fnc,
job_entrypoint_fnc=self._opts.job_entrypoint_fnc,
log_q=log_q,
log_cch=mp_log_cch,
mp_cch=mp_cch,
asyncio_debug=self._loop.get_debug(),
user_arguments=self._user_args,
Expand All @@ -168,6 +169,7 @@ def _add_proc_ctx_log(record: logging.LogRecord) -> None:
)

self._proc.start()
mp_log_cch.close()
mp_cch.close()

self._pid = self._proc.pid
Expand All @@ -176,7 +178,6 @@ def _add_proc_ctx_log(record: logging.LogRecord) -> None:
def _sync_run():
self._proc.join()
log_listener.stop()
log_q.close()
try:
self._loop.call_soon_threadsafe(self._join_fut.set_result, None)
except RuntimeError:
Expand Down
25 changes: 15 additions & 10 deletions livekit-agents/livekit/agents/utils/aio/duplex_unix.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,7 @@ async def recv_bytes(self) -> bytes:
len = struct.unpack("!I", len_bytes)[0]
return await self._reader.readexactly(len)
except (
BrokenPipeError,
ConnectionResetError,
OSError,
EOFError,
asyncio.IncompleteReadError,
):
Expand All @@ -49,15 +48,15 @@ async def send_bytes(self, data: bytes) -> None:
self._writer.write(len_bytes)
self._writer.write(data)
await self._writer.drain()
except (ConnectionResetError, BrokenPipeError):
except OSError:
raise DuplexClosed()

async def aclose(self) -> None:
try:
self._writer.close()
await self._writer.wait_closed()
self._sock.close()
except (BrokenPipeError, ConnectionResetError):
except OSError:
raise DuplexClosed()


Expand All @@ -80,25 +79,31 @@ def open(sock: socket.socket) -> _Duplex:
return _Duplex(sock)

def recv_bytes(self) -> bytes:
assert self._sock is not None
if self._sock is None:
raise DuplexClosed()

try:
len_bytes = _read_exactly(self._sock, 4)
len = struct.unpack("!I", len_bytes)[0]
return _read_exactly(self._sock, len)
except (BrokenPipeError, ConnectionResetError, EOFError):
except (OSError, EOFError):
raise DuplexClosed()

def send_bytes(self, data: bytes) -> None:
assert self._sock is not None
if self._sock is None:
raise DuplexClosed()

try:
len_bytes = struct.pack("!I", len(data))
self._sock.sendall(len_bytes)
self._sock.sendall(data)
except (BrokenPipeError, ConnectionResetError):
except OSError:
raise DuplexClosed()

def detach(self) -> socket.socket:
assert self._sock is not None
if self._sock is None:
raise DuplexClosed()

sock = self._sock
self._sock = None
return sock
Expand All @@ -108,5 +113,5 @@ def close(self) -> None:
if self._sock is not None:
self._sock.close()
self._sock = None
except (BrokenPipeError, ConnectionResetError):
except OSError:
raise DuplexClosed()
20 changes: 15 additions & 5 deletions tests/test_llm.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import asyncio
from enum import Enum
from typing import Annotated
Expand Down Expand Up @@ -62,13 +64,13 @@ async def toggle_light(
self._toggle_light_cancelled = True

# used to test arrays as arguments
@ai_callable(description="Currencies of a specific country")
@ai_callable(description="Currencies of a specific continent")
def select_currencies(
self,
currencies: Annotated[
list[str],
TypeInfo(
description="The currency to select",
description="The currencies to select",
choices=["usd", "eur", "gbp", "jpy", "sek"],
),
],
Expand Down Expand Up @@ -165,7 +167,10 @@ async def test_calls_arrays():
llm = openai.LLM(model="gpt-4o")

stream = await _request_fnc_call(
llm, "Can you select all currencies in Europe at once?", fnc_ctx
llm,
"Can you select all currencies in Europe at once?",
fnc_ctx,
temperature=0.5,
)
fns = stream.execute_functions()
await asyncio.gather(*[f.task for f in fns])
Expand Down Expand Up @@ -194,10 +199,15 @@ async def test_calls_choices():


async def _request_fnc_call(
model: llm.LLM, request: str, fnc_ctx: FncCtx
model: llm.LLM,
request: str,
fnc_ctx: FncCtx,
temperature: float | None = None,
) -> llm.LLMStream:
stream = model.chat(
chat_ctx=ChatContext().append(text=request, role="user"), fnc_ctx=fnc_ctx
chat_ctx=ChatContext().append(text=request, role="user"),
fnc_ctx=fnc_ctx,
temperature=temperature,
)

async for _ in stream:
Expand Down

0 comments on commit 482ed3b

Please sign in to comment.