Skip to content

Commit

Permalink
Re-implementing requests as dataclasses
Browse files Browse the repository at this point in the history
  • Loading branch information
josiah-wolf-oberholtzer committed Feb 15, 2023
1 parent 6fe9748 commit 36e13e1
Show file tree
Hide file tree
Showing 11 changed files with 2,160 additions and 222 deletions.
245 changes: 131 additions & 114 deletions supriya/contexts/core.py

Large diffs are not rendered by default.

8 changes: 4 additions & 4 deletions supriya/contexts/nonrealtime.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
from typing import Dict, Iterator, List, Optional, SupportsInt, Type, Union

from ..commands import Request, RequestBundle, Requestable
from ..enums import CalculationRate
from ..osc import OscBundle
from ..scsynth import Options
from ..typing import SupportsOsc
from .core import Context, ContextError, ContextObject, Node
from .requests import RequestBundle, Requestable


class NonrealtimeContext(Context):

### INITIALIZER ###

def __init__(self, options: Optional[Options] = None, **kwargs):
super().__init__(options=options, **kwargs)
self._requests: Dict[float, List[Request]] = {}
self._requests: Dict[float, List[Requestable]] = {}

### PRIVATE METHODS ###

Expand Down Expand Up @@ -50,7 +50,7 @@ def iterate_request_bundles(self) -> Iterator[RequestBundle]:
continue
yield RequestBundle(timestamp=timestamp, contents=requests)

def send(self, requestable: Requestable) -> None:
def send(self, requestable: SupportsOsc) -> None:
if not isinstance(requestable, RequestBundle):
raise ContextError
elif requestable.timestamp is None:
Expand Down
190 changes: 108 additions & 82 deletions supriya/contexts/realtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import dataclasses
import enum
import logging
import threading
from typing import (
TYPE_CHECKING,
Awaitable,
Expand All @@ -20,19 +19,6 @@
from uqbar.objects import new

from ..assets.synthdefs import system_synthdefs
from ..commands import (
BufferInfo,
BufferQueryRequest,
DoneResponse,
FailResponse,
GroupQueryTreeRequest,
NodeInfo,
NodeQueryRequest,
NotifyRequest,
QuitRequest,
Requestable,
SyncRequest,
)
from ..enums import CalculationRate
from ..exceptions import (
OwnedServerShutdown,
Expand All @@ -45,16 +31,41 @@
from ..osc import (
AsyncOscProtocol,
HealthCheck,
OscBundle,
OscMessage,
OscProtocol,
OscProtocolOffline,
ThreadedOscProtocol,
)
from ..querytree import QueryTreeGroup
from ..querytree import QueryTreeGroup, QueryTreeSynth
from ..scsynth import AsyncProcessProtocol, Options, SyncProcessProtocol
from ..synthdefs import SynthDef
from .core import Buffer, Bus, Context, ContextObject, Group, Node
from ..typing import SupportsOsc
from .core import (
Buffer,
Bus,
Context,
ContextObject,
Group,
InvalidCalculationRate,
Node,
)
from .requests import (
GetControlBus,
QueryBuffer,
QueryNode,
QueryTree,
Quit,
Sync,
ToggleNotifications,
)
from .responses import (
BufferInfo,
DoneInfo,
FailInfo,
GetControlBusInfo,
NodeInfo,
QueryTreeInfo,
)

if TYPE_CHECKING:
from ..realtime.shm import ServerSHM
Expand All @@ -80,7 +91,6 @@ class BootStatus(enum.IntEnum):


class RealtimeContext(Context):

### CLASS VARIABLES ###

_contexts: Set["RealtimeContext"] = set()
Expand All @@ -94,7 +104,6 @@ def __init__(
self._is_owner = False
self._boot_status = BootStatus.OFFLINE
self._buffers: Set[int] = set()
self._lock = threading.RLock()
self._maximum_logins = 1
self._node_active: Dict[int, bool] = {}
self._node_children: Dict[int, List[int]] = {}
Expand Down Expand Up @@ -137,14 +146,6 @@ def _free_id(
) -> None:
self._get_allocator(type_, calculation_rate).free(id_)

def _get_next_sync_id(self) -> int:
with self._lock:
sync_id = self._sync_id
self._sync_id += 1
if self._sync_id > self._sync_id_maximum:
self._sync_id = self._sync_id_minimum
return sync_id

def _handle_osc_callbacks(self, message: OscMessage) -> None:
def _handle_done(message: OscMessage) -> None:
if message.contents[0] in (
Expand Down Expand Up @@ -296,17 +297,22 @@ def query_buffer(self, buffer: Buffer) -> Union[Awaitable[BufferInfo], BufferInf
def query_node(self, node: Node) -> Union[Awaitable[NodeInfo], NodeInfo]:
raise NotImplementedError

def send(self, message: Union[OscBundle, OscMessage, Requestable]) -> None:
@abc.abstractmethod
def query_tree(
self,
) -> Union[
Awaitable[Union[QueryTreeGroup, QueryTreeSynth]],
Union[QueryTreeGroup, QueryTreeSynth],
]:
raise NotImplementedError

def send(self, message: SupportsOsc) -> None:
if self._boot_status not in (BootStatus.BOOTING, BootStatus.ONLINE):
raise ServerOffline
self._osc_protocol.send(
message.to_osc() if isinstance(message, Requestable) else message
message.to_osc() if hasattr(message, "to_osc") else message
)

@abc.abstractmethod
def query_tree(self) -> Union[Awaitable[QueryTreeGroup], QueryTreeGroup]:
raise NotImplementedError

### PUBLIC PROPERTIES ###

@property
Expand All @@ -327,7 +333,6 @@ def osc_protocol(self) -> OscProtocol:


class Server(RealtimeContext):

### INITIALIZER ###

def __init__(self, options: Optional[Options] = None, **kwargs):
Expand All @@ -337,7 +342,7 @@ def __init__(self, options: Optional[Options] = None, **kwargs):
### PRIVATE METHODS ###

def _connect(self) -> None:
logger.info("connecting")
logger.info("Connecting")
cast(ThreadedOscProtocol, self._osc_protocol).connect(
ip_address=self._options.ip_address,
port=self._options.port,
Expand All @@ -353,10 +358,10 @@ def _connect(self) -> None:
self._setup_system()
self.sync()
self._boot_status = BootStatus.ONLINE
logger.info("connected")
logger.info("Connected")

def _disconnect(self) -> None:
logger.info("disconnecting")
logger.info("Disconnecting")
self._boot_status = BootStatus.QUITTING
self._teardown_shm()
cast(ThreadedOscProtocol, self._osc_protocol).disconnect()
Expand All @@ -366,22 +371,22 @@ def _disconnect(self) -> None:
self._contexts.remove(self)
self._is_owner = False
self._boot_status = BootStatus.OFFLINE
logger.info("disconnected")
logger.info("Disconnected")

def _setup_notifications(self) -> None:
response: Union[DoneResponse, FailResponse] = NotifyRequest(True).communicate(
server=self
)
if isinstance(response, FailResponse):
logger.info("Setting up notifications")
response = ToggleNotifications(True).communicate(server=self)
if response is None or not isinstance(response, (DoneInfo, FailInfo)):
raise RuntimeError
if isinstance(response, FailInfo):
self._shutdown()
raise TooManyClients
if len(response.action) == 2: # supernova doesn't provide a max logins value
self._client_id, self._maximum_logins = (
response.action[1],
self._options.maximum_logins,
)
if len(response.other) == 1: # supernova doesn't provide a max logins value
self._client_id = int(response.other[0])
self._maximum_logins = self._options.maximum_logins
else:
self._client_id, self._maximum_logins = response.action[1:3]
self._client_id = int(response.other[0])
self._maximum_logins = int(response.other[1])

def _shutdown(self):
if self.is_owner:
Expand All @@ -396,6 +401,7 @@ def boot(self, *, options: Optional[Options] = None, **kwargs) -> "Server":
raise ServerOnline
self._boot_status = BootStatus.BOOTING
self._options = new(options or self._options, **kwargs)
logger.debug(f"Options: {self._options}")
try:
self._process_protocol.boot(self._options)
except ServerCannotBoot:
Expand Down Expand Up @@ -423,21 +429,30 @@ def disconnect(self) -> "Server":
self._disconnect()
return self

async def get_bus(self, bus: Bus) -> float:
raise NotImplementedError
def get_bus(self, bus: Bus) -> float:
if bus.calculation_rate != CalculationRate.CONTROL:
raise InvalidCalculationRate
return cast(
GetControlBusInfo, GetControlBus(bus_ids=[bus.id_]).communicate(server=self)
).items[0][-1]

def query_buffer(self, buffer: Buffer) -> BufferInfo:
request = BufferQueryRequest(buffer_ids=[buffer.id_])
response = request.communicate(server=self)
return response
return cast(
BufferInfo, QueryBuffer(buffer_ids=[buffer.id_]).communicate(server=self)
)

def query_node(self, node: Node) -> NodeInfo:
return NodeQueryRequest(node_id=node).communicate(server=self)
return cast(NodeInfo, QueryNode(node_ids=[node.id_]).communicate(server=self))

def query_tree(self) -> Union[Awaitable[QueryTreeGroup], QueryTreeGroup]:
request = GroupQueryTreeRequest(node_id=0, include_controls=True)
response = request.communicate(server=self)
return response.query_tree_group
def query_tree(
self,
) -> Union[
Awaitable[Union[QueryTreeGroup, QueryTreeSynth]],
Union[QueryTreeGroup, QueryTreeSynth],
]:
return QueryTreeGroup.from_query_tree_info(
cast(QueryTreeInfo, QueryTree(items=[(0, True)]).communicate(server=self))
)

def quit(self, force: bool = False) -> "Server":
if self._boot_status != BootStatus.ONLINE:
Expand All @@ -447,7 +462,7 @@ def quit(self, force: bool = False) -> "Server":
"Cannot quit unowned server without force flag."
)
try:
QuitRequest().communicate(server=self)
Quit().communicate(server=self)
except OscProtocolOffline:
pass
self._teardown_shm()
Expand All @@ -458,14 +473,13 @@ def quit(self, force: bool = False) -> "Server":
def sync(self, sync_id: Optional[int] = None) -> "Server":
if self._boot_status not in (BootStatus.BOOTING, BootStatus.ONLINE):
raise ServerOffline
SyncRequest(
Sync(
sync_id=sync_id if sync_id is not None else self._get_next_sync_id()
).communicate(server=self)
return self


class AsyncServer(RealtimeContext):

### INITIALIZER ###

def __init__(self, options: Optional[Options] = None, **kwargs):
Expand Down Expand Up @@ -507,19 +521,18 @@ async def _disconnect(self) -> None:

async def _setup_notifications(self) -> None:
logger.info("Setting up notifications")
response: Union[DoneResponse, FailResponse] = await NotifyRequest(
True
).communicate_async(server=self)
if isinstance(response, FailResponse):
response = await ToggleNotifications(True).communicate_async(server=self)
if response is None or not isinstance(response, (DoneInfo, FailInfo)):
raise RuntimeError
if isinstance(response, FailInfo):
await self._shutdown()
raise TooManyClients
if len(response.action) == 2: # supernova doesn't provide a max logins value
self._client_id, self._maximum_logins = (
response.action[1],
self._options.maximum_logins,
)
if len(response.other) == 1: # supernova doesn't provide a max logins value
self._client_id = int(response.other[0])
self._maximum_logins = self._options.maximum_logins
else:
self._client_id, self._maximum_logins = response.action[1:3]
self._client_id = int(response.other[0])
self._maximum_logins = int(response.other[1])

async def _shutdown(self):
if self.is_owner:
Expand All @@ -536,6 +549,7 @@ async def boot(
raise ServerOnline
self._boot_status = BootStatus.BOOTING
self._options = new(options or self._options, **kwargs)
logger.debug(f"Options: {self._options}")
await self._process_protocol.boot(self._options)
if not await self._process_protocol.boot_future:
self._boot_status = BootStatus.OFFLINE
Expand Down Expand Up @@ -564,20 +578,32 @@ async def disconnect(self) -> "AsyncServer":
return self

async def get_bus(self, bus: Bus) -> float:
raise NotImplementedError
if bus.calculation_rate != CalculationRate.CONTROL:
raise InvalidCalculationRate
return cast(
GetControlBusInfo,
await GetControlBus(bus_ids=[bus.id_]).communicate_async(server=self),
).items[0][-1]

async def query_buffer(self, buffer: Buffer) -> BufferInfo:
request = BufferQueryRequest(buffer_ids=[buffer.id_])
response = await request.communicate_async(server=self)
return response
return cast(
BufferInfo,
await QueryBuffer(buffer_ids=[buffer.id_]).communicate_async(server=self),
)

async def query_node(self, node: Node) -> NodeInfo:
return await NodeQueryRequest(node_id=node).communicate_async(server=self)
return cast(
NodeInfo,
await QueryNode(node_ids=[node.id_]).communicate_async(server=self),
)

async def query_tree(self) -> QueryTreeGroup:
request = GroupQueryTreeRequest(node_id=0, include_controls=True)
response = await request.communicate_async(server=self)
return response.query_tree_group
async def query_tree(self) -> Union[QueryTreeGroup, QueryTreeSynth]:
return QueryTreeGroup.from_query_tree_info(
cast(
QueryTreeInfo,
await QueryTree(items=[(0, True)]).communicate_async(server=self),
)
)

async def quit(self, force: bool = False) -> "AsyncServer":
if self._boot_status != BootStatus.ONLINE:
Expand All @@ -587,7 +613,7 @@ async def quit(self, force: bool = False) -> "AsyncServer":
"Cannot quit unowned server without force flag."
)
try:
await QuitRequest().communicate_async(server=self, sync=True, timeout=1)
await Quit().communicate_async(server=self, sync=True, timeout=1)
except (OscProtocolOffline, asyncio.TimeoutError):
pass
self._process_protocol.quit()
Expand All @@ -597,7 +623,7 @@ async def quit(self, force: bool = False) -> "AsyncServer":
async def sync(self, sync_id: Optional[int] = None) -> "AsyncServer":
if self._boot_status not in (BootStatus.BOOTING, BootStatus.ONLINE):
raise ServerOffline
await SyncRequest(
await Sync(
sync_id=sync_id if sync_id is not None else self._get_next_sync_id()
).communicate_async(server=self)
return self
Loading

0 comments on commit 36e13e1

Please sign in to comment.