diff --git a/nightwatch/__init__.py b/nightwatch/__init__.py index bc8c296..777f190 100644 --- a/nightwatch/__init__.py +++ b/nightwatch/__init__.py @@ -1 +1 @@ -__version__ = "0.7.2" +__version__ = "0.8.0" diff --git a/nightwatch/bot/client.py b/nightwatch/bot/client.py index 6ea48aa..47a6de0 100644 --- a/nightwatch/bot/client.py +++ b/nightwatch/bot/client.py @@ -46,13 +46,3 @@ def connected(self, func: Awaitable) -> None: def on_message(self, func: Awaitable) -> None: self._on_message = func - -""" -def hello(): - with connect("ws://localhost:8765") as websocket: - websocket.send("Hello world!") - message = websocket.recv() - print(f"Received: {message}") - -hello() -""" diff --git a/nightwatch/config.py b/nightwatch/config.py index bda0db3..a52c6e1 100644 --- a/nightwatch/config.py +++ b/nightwatch/config.py @@ -8,7 +8,7 @@ from getpass import getuser # Initialization -config_path = Path(os.path.expanduser("~")) / ".config/nightwatch/config.json" +config_path = Path.home() / ".config/nightwatch/config.json" if os.name == "nt": config_path = Path(f"C:\\Users\\{getuser()}\\AppData\\Local\\Nightwatch\\config.json") diff --git a/nightwatch/server/__init__.py b/nightwatch/server/__init__.py index 0c0373d..c5ad645 100644 --- a/nightwatch/server/__init__.py +++ b/nightwatch/server/__init__.py @@ -2,6 +2,7 @@ # Modules import orjson +from pydantic import ValidationError from websockets import WebSocketCommonProtocol from websockets.exceptions import ConnectionClosedError @@ -14,7 +15,6 @@ class NightwatchStateManager(): def __init__(self) -> None: self.clients = {} - self.message_buffer = [] def add_client(self, client: WebSocketCommonProtocol) -> None: self.clients[client] = None @@ -35,8 +35,20 @@ async def connection(websocket: WebSocketCommonProtocol) -> None: await client.send("error", text = "Specified command type does not exist or is missing.") continue + callback = message.get("callback") + if callback is not None: + client.set_callback(callback) + command, payload_type = registry.commands[message["type"]] - await command(state, client, payload_type(**(message.get("data") or {}))) + if payload_type is None: + await command(state, client) + + else: + try: + await command(state, client, payload_type(**(message.get("data") or {}))) + + except ValidationError as error: + await client.send("error", text = error) except orjson.JSONDecodeError: log.warn("ws", "Failed to decode JSON from client.") diff --git a/nightwatch/server/__main__.py b/nightwatch/server/__main__.py index 4d93bd1..7a67181 100644 --- a/nightwatch/server/__main__.py +++ b/nightwatch/server/__main__.py @@ -1,15 +1,21 @@ # Copyright (c) 2024 iiPython # Modules +import os import asyncio from websockets.server import serve from . import connection +from nightwatch import __version__ +from nightwatch.logging import log + # Entrypoint async def main() -> None: - async with serve(connection, "localhost", 8000): + host, port = os.getenv("HOST", "localhost"), int(os.getenv("PORT", 8000)) + log.info("ws", f"Nightwatch v{__version__} running on ws://{host}:{port}/") + async with serve(connection, host, port): await asyncio.Future() if __name__ == "__main__": diff --git a/nightwatch/server/utils/commands.py b/nightwatch/server/utils/commands.py index 4004073..6feb026 100644 --- a/nightwatch/server/utils/commands.py +++ b/nightwatch/server/utils/commands.py @@ -23,7 +23,10 @@ def __init__(self) -> None: def command(self, name: str) -> Callable: def callback(function: Callable) -> None: - self.commands[name] = (function, function.__annotations__["data"]) + self.commands[name] = ( + function, + function.__annotations__["data"] if "data" in function.__annotations__ else None + ) return callback @@ -45,17 +48,25 @@ async def command_identify(state, client: NightwatchClient, data: models.Identif client.identified = True await client.send("server", name = Constant.SERVER_NAME, online = len(state.clients)) - websockets.broadcast(state.clients.keys(), orjson.dumps({ + websockets.broadcast(state.clients, orjson.dumps({ "type": "message", "data": {"text": f"{data.name} joined the chatroom.", "user": Constant.SERVER_USER} - })) + }).decode()) @registry.command("message") async def command_message(state, client: NightwatchClient, data: models.MessageModel) -> None: if not client.identified: return await client.send("error", text = "You must identify before sending a message.") - websockets.broadcast(state.clients.keys(), orjson.dumps({ + websockets.broadcast(state.clients, orjson.dumps({ "type": "message", "data": {"text": data.text, "user": client.user_data} - })) + }).decode()) + +@registry.command("members") +async def command_members(state, client: NightwatchClient) -> None: + return await client.send("members", list = list(state.clients.values())) + +@registry.command("ping") +async def command_ping(state, client: NightwatchClient) -> None: + return await client.send("pong") diff --git a/nightwatch/server/utils/websocket.py b/nightwatch/server/utils/websocket.py index f349493..aa7f845 100644 --- a/nightwatch/server/utils/websocket.py +++ b/nightwatch/server/utils/websocket.py @@ -11,13 +11,21 @@ class NightwatchClient(): data serialization through orjson.""" def __init__(self, state, client: WebSocketCommonProtocol) -> None: self.client = client - self.identified = False + self.identified, self.callback = False, None self.state = state self.state.add_client(client) async def send(self, message_type: str, **message_data) -> None: - await self.client.send(orjson.dumps({"type": message_type, "data": message_data})) + payload = {"type": message_type, "data": message_data} + if self.callback is not None: + payload["callback"] = self.callback + self.callback = None + + await self.client.send(orjson.dumps(payload).decode()) + + def set_callback(self, callback: str) -> None: + self.callback = callback # Handle user data (ie. name and color) def set_user_data(self, data: dict[str, Any]) -> None: