Skip to content

Commit

Permalink
refactor/improve_readwritestream
Browse files Browse the repository at this point in the history
improve the ReadWriteStream class used by audio transformers

make it threadsafe and more performant by using a deque

add a max size to ensure it doesnt grow forever

maybe fixes undiagnosed issue OpenVoiceOS/ovos-dinkum-listener#98
  • Loading branch information
JarbasAl committed Jun 8, 2024
1 parent efac638 commit 51b833f
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 25 deletions.
58 changes: 35 additions & 23 deletions ovos_plugin_manager/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@
#
"""Common functions for loading plugins."""
import time
from collections import deque
from enum import Enum
from threading import Event
from threading import Event, Lock
from typing import Optional

import pkg_resources

from ovos_utils.log import LOG


Expand Down Expand Up @@ -185,41 +185,53 @@ def normalize_lang(lang):
class ReadWriteStream:
"""
Class used to support writing binary audio data at any pace,
optionally chopping when the buffer gets too large
with an optional maximum buffer size
"""

def __init__(self, s=b'', chop_samples=-1):
self.buffer = s
def __init__(self, s=b'', max_size=None):
self.buffer = deque(s)
self.write_event = Event()
self.chop_samples = chop_samples
self.lock = Lock()
self.max_size = max_size # Introduce max size

def __len__(self):
return len(self.buffer)
with self.lock:
return len(self.buffer)

def read(self, n=-1, timeout=None):
if n == -1:
n = len(self.buffer)
if 0 < self.chop_samples < len(self.buffer):
samples_left = len(self.buffer) % self.chop_samples
self.buffer = self.buffer[-samples_left:]
return_time = 1e10 if timeout is None else (
timeout + time.time()
)
while len(self.buffer) < n:
with self.lock:
if n == -1 or n > len(self.buffer):
n = len(self.buffer)

end_time = time.time() + timeout if timeout is not None else float('inf')

while True:
with self.lock:
if len(self.buffer) >= n:
chunk = bytes([self.buffer.popleft() for _ in range(n)])
return chunk

remaining_time = None
if timeout is not None:
remaining_time = end_time - time.time()
if remaining_time <= 0:
return b''

self.write_event.clear()
if not self.write_event.wait(return_time - time.time()):
return b''
chunk = self.buffer[:n]
self.buffer = self.buffer[n:]
return chunk
self.write_event.wait(remaining_time)

def write(self, s):
self.buffer += s
with self.lock:
self.buffer.extend(s)
if self.max_size is not None:
while len(self.buffer) > self.max_size:
self.buffer.popleft() # Discard oldest data to maintain max size
self.write_event.set()

def flush(self):
"""Makes compatible with sys.stdout"""
pass

def clear(self):
self.buffer = b''
with self.lock:
self.buffer.clear()
52 changes: 50 additions & 2 deletions test/unittests/test_audio_transformers.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,60 @@
import unittest
import time
from unittest.mock import patch

from ovos_plugin_manager.utils import PluginTypes, PluginConfigTypes
from ovos_plugin_manager.utils import PluginTypes, PluginConfigTypes, ReadWriteStream


class TestReadWriteStream(unittest.TestCase):
def test_write_and_read(self):
# Initialize the stream
stream = ReadWriteStream()

# Write some data to the stream
stream.write(b'1234567890abcdefghijklmnopqrstuvwxyz')

# Read some data from the stream
self.assertEqual(stream.read(10), b'1234567890')

# Read more data with a timeout
self.assertEqual(stream.read(5, timeout=1), b'abcde')

def test_clear_buffer(self):
# Initialize the stream
stream = ReadWriteStream()

# Write some data to the stream
stream.write(b'1234567890abcdefghijklmnopqrstuvwxyz')

# Clear the buffer
stream.clear()
self.assertEqual(len(stream), 0)

def test_write_with_max_size(self):
# Initialize the stream with a max size of 20 bytes
stream = ReadWriteStream(max_size=20)

# Write some data to the stream
stream.write(b'1234567890abcdefghijklmnopqrstuvwxyz')

# The buffer should have been trimmed to the last 20 bytes
self.assertEqual(stream.read(20), b'ghijklmnopqrstuvwxyz')

def test_clear_buffer_with_max_size(self):
# Initialize the stream with a max size of 20 bytes
stream = ReadWriteStream(max_size=20)

# Write some data to the stream
stream.write(b'1234567890abcdefghijklmnopqrstuvwxyz')

# Clear the buffer
stream.clear()
self.assertEqual(len(stream), 0)


class TestAudioTransformersTemplate(unittest.TestCase):
def test_audio_transformer(self):
from ovos_plugin_manager.templates.transformers import AudioTransformer
pass
# TODO


Expand Down

0 comments on commit 51b833f

Please sign in to comment.