Skip to content

Commit

Permalink
fix(tools.serve): resolve coroutine
Browse files Browse the repository at this point in the history
  • Loading branch information
jourdain committed Sep 9, 2024
1 parent 58a2c7f commit 11795b6
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 1 deletion.
2 changes: 1 addition & 1 deletion trame/tools/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ async def on_msg_from_server(binary, content):
try:
await ws_network.prepare(request)
ws_app = app.server._server.ws
connection = ws_app.connect()
connection = await ws_app.connect()
connection.on_message(on_msg_from_server)
async for msg in ws_network:
await connection.send(msg.type == aiohttp.WSMsgType.BINARY, msg)
Expand Down
Empty file added trame/utils/__init__.py
Empty file.
62 changes: 62 additions & 0 deletions trame/utils/exec.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import asyncio


class Throttle:
"""
Helper class that wrap a function with a given max execution rate.
By default the rate is set to execute no more than once a second.
:param fn: the function to call.
:type fn: function
:param ts: Number of seconds to wait before the next execution.
:type ts: float
"""

def __init__(self, fn, ts=1):
self._ts = ts
self._fn = fn
self._requests = 0
self._pending = False
self._last_args = []
self._last_kwargs = {}

@property
def rate(self):
"""Number of maximum executions per second"""
return 1.0 / self._ts

@rate.setter
def rate(self, rate):
"""Update the maximum number of executions per seconds"""
self._ts = 1.0 / rate

@property
def delta_t(self):
"""Number of seconds to wait between execution"""
return self._ts

@delta_t.setter
def delta_t(self, seconds):
"""Update the number of seconds to wait between execution"""
self._ts = seconds

async def _trottle(self):
self._pending = True
if self._requests:
self._fn(*self._last_args, **self._last_kwargs)
self._requests = 0

await asyncio.sleep(self._ts)
if self._requests > 0:
await self._trottle()
self._pending = False

def __call__(self, *args, **kwargs):
"""Function call wrapper that will throttle the actual function provided at construction"""
self._requests += 1
self._last_args = args
self._last_kwargs = kwargs

if not self._pending:
asyncio.create_task(self._trottle())

0 comments on commit 11795b6

Please sign in to comment.