From 5f3eca33e8ee2e59835629ca9805238db545c556 Mon Sep 17 00:00:00 2001 From: bghira Date: Sat, 1 Apr 2023 21:12:24 -0700 Subject: [PATCH] Milestone: Client can send hardware profile to the main hub --- discord_tron_client/__main__.py | 31 ++++----- discord_tron_client/classes/app_config.py | 6 +- discord_tron_client/classes/auth.py | 47 +++++++++++++- discord_tron_client/classes/hardware.py | 78 +++++++++++++++++++++++ discord_tron_client/classes/message.py | 67 +++++++++++++++++++ discord_tron_client/ws_client/client.py | 12 +++- 6 files changed, 219 insertions(+), 22 deletions(-) create mode 100644 discord_tron_client/classes/hardware.py create mode 100644 discord_tron_client/classes/message.py diff --git a/discord_tron_client/__main__.py b/discord_tron_client/__main__.py index 541cef8b..2f8afe7b 100644 --- a/discord_tron_client/__main__.py +++ b/discord_tron_client/__main__.py @@ -1,14 +1,10 @@ import asyncio from .ws_client import websocket_client import logging +logging.basicConfig(level=logging.INFO) from discord_tron_client.classes.app_config import AppConfig config = AppConfig() - -async def main(): - logging.info("Starting WebSocket client...") - await websocket_client(config) - if __name__ == '__main__': try: # Detect an expired token. @@ -16,28 +12,27 @@ async def main(): from discord_tron_client.classes.auth import Auth current_ticket = config.get_auth_ticket() auth = Auth(config, current_ticket["access_token"], current_ticket["refresh_token"], current_ticket["expires_in"], current_ticket["issued_at"]) - try: - is_expired = auth.is_token_expired() - except Exception as e: - logging.error(f"Error checking token expiration: {e}") - is_expired = True - if is_expired: - logging.warning("Access token is expired. Attempting to refresh...") - new_ticket = auth.refresh_client_token(current_ticket["refresh_token"]) - import json - print(f"New ticket: {json.dumps(new_ticket, indent=4)}") + auth.get() # Start the WebSocket client in the background - asyncio.get_event_loop().run_until_complete(websocket_client(config)) + startup_sequence = [] + from discord_tron_client.classes.message import WebsocketMessage + # Add any startup sequence here + from discord_tron_client.classes.hardware import HardwareInfo + hardware_info = HardwareInfo() + machine_info = hardware_info.get_machine_info() + hardware_info_message = WebsocketMessage(message_type="hardware_info", module_name="system", module_command="update", data=machine_info) + startup_sequence.append(hardware_info_message.to_json()) + asyncio.get_event_loop().run_until_complete(websocket_client(config, startup_sequence)) # Start the Flask server from discord_tron_client.app_factory import create_app app = create_app() app.run() - asyncio.run(main()) except KeyboardInterrupt: logging.info("Shutting down...") exit(0) except Exception as e: - logging.error(f"An error occurred: {e}") + import traceback + logging.error(f"Stack trace: {traceback.format_exc()}") exit(1) \ No newline at end of file diff --git a/discord_tron_client/classes/app_config.py b/discord_tron_client/classes/app_config.py index 913a6b72..6d112053 100644 --- a/discord_tron_client/classes/app_config.py +++ b/discord_tron_client/classes/app_config.py @@ -26,7 +26,11 @@ def __init__(self): # Retrieve the OAuth ticket information. def get_auth_ticket(self): with open(self.auth_ticket_path, "r") as auth_ticket: - return json.load(auth_ticket) + auth_data = json.load(auth_ticket) + return auth_data["access_token"] + + def get_master_api_key(self): + return self.config.get("master_api_key", None) def get_concurrent_slots(self): return self.config.get("concurrent_slots", 1) diff --git a/discord_tron_client/classes/auth.py b/discord_tron_client/classes/auth.py index 41b9d1be..8a6fb089 100644 --- a/discord_tron_client/classes/auth.py +++ b/discord_tron_client/classes/auth.py @@ -12,6 +12,7 @@ def __init__(self, config: AppConfig, access_token: str, refresh_token: str, exp self.token_received_at = token_received_at self.base_url = config.get_master_url() + # When it's expired, we have to refresh the token. def refresh_client_token(self, refresh_token): url = self.base_url + "/refresh_token" payload = {"refresh_token": refresh_token} @@ -19,13 +20,57 @@ def refresh_client_token(self, refresh_token): response = requests.post(url, json=payload) if response.status_code == 200: + self.write_auth_ticket(response) return response.json() else: raise Exception("Error refreshing token: {}".format(response.text)) + # Before the token expires, we can get a new one normally. + def get_access_token(self): + url = self.base_url + "/authorize" + from discord_tron_client.classes.app_config import AppConfig + config = AppConfig() + api_key = config.get_master_api_key() + auth_ticket = config.get_auth_ticket() + payload = { "api_key": api_key, "client_id": auth_ticket["client_id"] } + + import requests + response = requests.post(url, json=payload) + + if response.status_code == 200: + self.write_auth_ticket(response) + new_ticket = response.json()['access_token'] + self.access_token = new_ticket["access_token"] + self.expires_in = new_ticket["expires_in"] + self.token_received_at = new_ticket["issued_at"] + return response.json() + else: + raise Exception("Error refreshing token: {}".format(response.text)) + + def write_auth_ticket(self, response): + import json + from discord_tron_client.classes.app_config import AppConfig + config = AppConfig() + with open(config.auth_ticket_path, "w") as f: + f.write(json.dumps(response.json())) + def is_token_expired(self): token_received_at = datetime.fromisoformat(self.token_received_at).timestamp() expires_in = int(self.expires_in) test = time.time() >= (token_received_at + expires_in) logging.info(f"Token expired? {test}") - return test \ No newline at end of file + return test + + # Request an access token from the auth server, refreshing it if necessary. + def get(self): + try: + is_expired = self.is_token_expired() + except Exception as e: + logging.error(f"Error checking token expiration: {e}") + is_expired = True + if is_expired: + logging.warning("Access token is expired. Attempting to refresh...") + current_ticket = self.refresh_client_token(current_ticket["refresh_token"]) + import json + print(f"New ticket: {json.dumps(current_ticket, indent=4)}") + self.get_access_token() \ No newline at end of file diff --git a/discord_tron_client/classes/hardware.py b/discord_tron_client/classes/hardware.py new file mode 100644 index 00000000..dc353021 --- /dev/null +++ b/discord_tron_client/classes/hardware.py @@ -0,0 +1,78 @@ +import subprocess +import json, logging + +class HardwareInfo: + def __init__(self): + self.gpu_type = "Unknown type" + self.cpu_type = "Unknown type" + self.memory_amount = None + self.video_memory_amount = None + self.disk_space_total = None + self.disk_space_used = None + + def get_gpu_info(self): + try: + output = subprocess.check_output(["nvidia-smi", "--query-gpu=name", "--format=csv,noheader"]) + self.gpu_type = output.decode().strip() + except: + try: + output = subprocess.check_output(["rocm-smi", "--showproductname"]) + self.gpu_type = output.decode().strip() + except: + self.gpu_type = "Unknown" + + def get_cpu_info(self): + try: + with open('/proc/cpuinfo') as f: + for line in f: + if line.strip() and line.startswith('model name'): + self.cpu_type = line.strip().split(': ')[1] + break + except: + self.cpu_type = "Unknown" + + def get_memory_info(self): + try: + with open('/proc/meminfo') as f: + for line in f: + if line.startswith('MemTotal:'): + self.memory_amount = int(int(line.split()[1]) / 1024 / 1024) + break + except: + self.memory_amount = "Unknown" + + def get_video_memory_info(self): + try: + output = subprocess.check_output(["nvidia-smi", "--query-gpu=memory.total", "--format=csv,noheader,nounits"]) + self.video_memory_amount = int(output.decode().strip()) / 1024 + except: + self.video_memory_amount = "Unknown" + + def get_disk_space_info(self): + try: + output = subprocess.check_output(["df", "-h"]) + for line in output.decode().split('\n'): + if '/dev/' in line and not line.startswith('tmpfs'): + line = line.split() + self.disk_space_total = line[1] + self.disk_space_used = line[2] + break + except: + self.disk_space_total = "Unknown" + self.disk_space_used = "Unknown" + + def get_machine_info(self): + self.get_gpu_info() + self.get_cpu_info() + self.get_memory_info() + self.get_video_memory_info() + self.get_disk_space_info() + + return { + 'gpu_type': self.gpu_type, + 'cpu_type': self.cpu_type, + 'memory_amount': self.memory_amount, + 'video_memory_amount': self.video_memory_amount, + 'disk_space_total': self.disk_space_total, + 'disk_space_used': self.disk_space_used + } \ No newline at end of file diff --git a/discord_tron_client/classes/message.py b/discord_tron_client/classes/message.py new file mode 100644 index 00000000..f2fb9677 --- /dev/null +++ b/discord_tron_client/classes/message.py @@ -0,0 +1,67 @@ +import time + +class WebsocketMessage: + def __init__(self, message_type: str, module_name: str, module_command, data=None, arguments=None): + self._message_type = message_type + self._module_name = module_name + self._module_command = module_command + self._timestamp = time.time() + self._data = data or {} + self._arguments = arguments or {} + + @property + def message_type(self): + return self._message_type + + @message_type.setter + def message_type(self, value: str): + self._message_type = value + + @property + def module_name(self): + return self._module_name + + @module_name.setter + def module_name(self, value): + self._module_name = value + + @property + def module_command(self): + return self._module_command + + @module_command.setter + def module_command(self, value): + self._module_command = value + + @property + def timestamp(self): + return self._timestamp + + @property + def data(self): + return self._data + + @data.setter + def data(self, value): + self._data = value + + @property + def arguments(self): + return self._arguments + + @arguments.setter + def arguments(self, value): + self._arguments = value + + def to_dict(self): + return { + "message_type": self.message_type, + "module_name": self.module_name, + "module_command": self.module_command, + "timestamp": self.timestamp, + "data": self.data, + "arguments": self.arguments + } + + def to_json(self): + return self.to_dict() \ No newline at end of file diff --git a/discord_tron_client/ws_client/client.py b/discord_tron_client/ws_client/client.py index 4ddfa0b4..aa3acb0f 100644 --- a/discord_tron_client/ws_client/client.py +++ b/discord_tron_client/ws_client/client.py @@ -2,7 +2,7 @@ import ssl, websockets from discord_tron_client.classes.app_config import AppConfig -async def websocket_client(config: AppConfig): +async def websocket_client(config: AppConfig, startup_sequence:str = None): websocket_config = config.get_websocket_config() logging.info(f"Retrieved websocket config: {websocket_config}") hub_url = str(websocket_config["protocol"]) + "://" + str(websocket_config["host"]) + ":" + str(websocket_config["port"]) @@ -19,5 +19,13 @@ async def websocket_client(config: AppConfig): } logging.info(f"Connecting to {hub_url}...") async with websockets.connect(hub_url, ssl=ssl_context, extra_headers=headers) as websocket: + # Send the startup sequence + if startup_sequence: + logging.info(f"Sending startup sequence {startup_sequence}") + for message in startup_sequence: + logging.debug(f"Sending startup sequence message: {message}") + await websocket.send(json.dumps(message)) + else: + logging.info("No startup sequence found.") async for message in websocket: - print(f"Received message: {message}") + logging.debug(f"Received message: {message}")