Skip to content

Commit

Permalink
Milestone: Client can send hardware profile to the main hub
Browse files Browse the repository at this point in the history
  • Loading branch information
bghira committed Apr 2, 2023
1 parent dc63215 commit 5f3eca3
Show file tree
Hide file tree
Showing 6 changed files with 219 additions and 22 deletions.
31 changes: 13 additions & 18 deletions discord_tron_client/__main__.py
Original file line number Diff line number Diff line change
@@ -1,43 +1,38 @@
import asyncio
from .ws_client import websocket_client
import logging
logging.basicConfig(level=logging.INFO)
from discord_tron_client.classes.app_config import AppConfig
config = AppConfig()


async def main():
logging.info("Starting WebSocket client...")
await websocket_client(config)

if __name__ == '__main__':
try:
# Detect an expired token.
logging.info("Inspecting auth ticket...")
from discord_tron_client.classes.auth import Auth
current_ticket = config.get_auth_ticket()
auth = Auth(config, current_ticket["access_token"], current_ticket["refresh_token"], current_ticket["expires_in"], current_ticket["issued_at"])
try:
is_expired = auth.is_token_expired()
except Exception as e:
logging.error(f"Error checking token expiration: {e}")
is_expired = True
if is_expired:
logging.warning("Access token is expired. Attempting to refresh...")
new_ticket = auth.refresh_client_token(current_ticket["refresh_token"])
import json
print(f"New ticket: {json.dumps(new_ticket, indent=4)}")
auth.get()

# Start the WebSocket client in the background
asyncio.get_event_loop().run_until_complete(websocket_client(config))
startup_sequence = []
from discord_tron_client.classes.message import WebsocketMessage
# Add any startup sequence here
from discord_tron_client.classes.hardware import HardwareInfo
hardware_info = HardwareInfo()
machine_info = hardware_info.get_machine_info()
hardware_info_message = WebsocketMessage(message_type="hardware_info", module_name="system", module_command="update", data=machine_info)
startup_sequence.append(hardware_info_message.to_json())
asyncio.get_event_loop().run_until_complete(websocket_client(config, startup_sequence))

# Start the Flask server
from discord_tron_client.app_factory import create_app
app = create_app()
app.run()
asyncio.run(main())
except KeyboardInterrupt:
logging.info("Shutting down...")
exit(0)
except Exception as e:
logging.error(f"An error occurred: {e}")
import traceback
logging.error(f"Stack trace: {traceback.format_exc()}")
exit(1)
6 changes: 5 additions & 1 deletion discord_tron_client/classes/app_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,11 @@ def __init__(self):
# Retrieve the OAuth ticket information.
def get_auth_ticket(self):
with open(self.auth_ticket_path, "r") as auth_ticket:
return json.load(auth_ticket)
auth_data = json.load(auth_ticket)
return auth_data["access_token"]

def get_master_api_key(self):
return self.config.get("master_api_key", None)

def get_concurrent_slots(self):
return self.config.get("concurrent_slots", 1)
Expand Down
47 changes: 46 additions & 1 deletion discord_tron_client/classes/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,65 @@ def __init__(self, config: AppConfig, access_token: str, refresh_token: str, exp
self.token_received_at = token_received_at
self.base_url = config.get_master_url()

# When it's expired, we have to refresh the token.
def refresh_client_token(self, refresh_token):
url = self.base_url + "/refresh_token"
payload = {"refresh_token": refresh_token}
import requests
response = requests.post(url, json=payload)

if response.status_code == 200:
self.write_auth_ticket(response)
return response.json()
else:
raise Exception("Error refreshing token: {}".format(response.text))

# Before the token expires, we can get a new one normally.
def get_access_token(self):
url = self.base_url + "/authorize"
from discord_tron_client.classes.app_config import AppConfig
config = AppConfig()
api_key = config.get_master_api_key()
auth_ticket = config.get_auth_ticket()
payload = { "api_key": api_key, "client_id": auth_ticket["client_id"] }

import requests
response = requests.post(url, json=payload)

if response.status_code == 200:
self.write_auth_ticket(response)
new_ticket = response.json()['access_token']
self.access_token = new_ticket["access_token"]
self.expires_in = new_ticket["expires_in"]
self.token_received_at = new_ticket["issued_at"]
return response.json()
else:
raise Exception("Error refreshing token: {}".format(response.text))

def write_auth_ticket(self, response):
import json
from discord_tron_client.classes.app_config import AppConfig
config = AppConfig()
with open(config.auth_ticket_path, "w") as f:
f.write(json.dumps(response.json()))

def is_token_expired(self):
token_received_at = datetime.fromisoformat(self.token_received_at).timestamp()
expires_in = int(self.expires_in)
test = time.time() >= (token_received_at + expires_in)
logging.info(f"Token expired? {test}")
return test
return test

# Request an access token from the auth server, refreshing it if necessary.
def get(self):
try:
is_expired = self.is_token_expired()
except Exception as e:
logging.error(f"Error checking token expiration: {e}")
is_expired = True
if is_expired:
logging.warning("Access token is expired. Attempting to refresh...")
current_ticket = self.refresh_client_token(current_ticket["refresh_token"])
import json
print(f"New ticket: {json.dumps(current_ticket, indent=4)}")
self.get_access_token()
78 changes: 78 additions & 0 deletions discord_tron_client/classes/hardware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import subprocess
import json, logging

class HardwareInfo:
def __init__(self):
self.gpu_type = "Unknown type"
self.cpu_type = "Unknown type"
self.memory_amount = None
self.video_memory_amount = None
self.disk_space_total = None
self.disk_space_used = None

def get_gpu_info(self):
try:
output = subprocess.check_output(["nvidia-smi", "--query-gpu=name", "--format=csv,noheader"])
self.gpu_type = output.decode().strip()
except:
try:
output = subprocess.check_output(["rocm-smi", "--showproductname"])
self.gpu_type = output.decode().strip()
except:
self.gpu_type = "Unknown"

def get_cpu_info(self):
try:
with open('/proc/cpuinfo') as f:
for line in f:
if line.strip() and line.startswith('model name'):
self.cpu_type = line.strip().split(': ')[1]
break
except:
self.cpu_type = "Unknown"

def get_memory_info(self):
try:
with open('/proc/meminfo') as f:
for line in f:
if line.startswith('MemTotal:'):
self.memory_amount = int(int(line.split()[1]) / 1024 / 1024)
break
except:
self.memory_amount = "Unknown"

def get_video_memory_info(self):
try:
output = subprocess.check_output(["nvidia-smi", "--query-gpu=memory.total", "--format=csv,noheader,nounits"])
self.video_memory_amount = int(output.decode().strip()) / 1024
except:
self.video_memory_amount = "Unknown"

def get_disk_space_info(self):
try:
output = subprocess.check_output(["df", "-h"])
for line in output.decode().split('\n'):
if '/dev/' in line and not line.startswith('tmpfs'):
line = line.split()
self.disk_space_total = line[1]
self.disk_space_used = line[2]
break
except:
self.disk_space_total = "Unknown"
self.disk_space_used = "Unknown"

def get_machine_info(self):
self.get_gpu_info()
self.get_cpu_info()
self.get_memory_info()
self.get_video_memory_info()
self.get_disk_space_info()

return {
'gpu_type': self.gpu_type,
'cpu_type': self.cpu_type,
'memory_amount': self.memory_amount,
'video_memory_amount': self.video_memory_amount,
'disk_space_total': self.disk_space_total,
'disk_space_used': self.disk_space_used
}
67 changes: 67 additions & 0 deletions discord_tron_client/classes/message.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import time

class WebsocketMessage:
def __init__(self, message_type: str, module_name: str, module_command, data=None, arguments=None):
self._message_type = message_type
self._module_name = module_name
self._module_command = module_command
self._timestamp = time.time()
self._data = data or {}
self._arguments = arguments or {}

@property
def message_type(self):
return self._message_type

@message_type.setter
def message_type(self, value: str):
self._message_type = value

@property
def module_name(self):
return self._module_name

@module_name.setter
def module_name(self, value):
self._module_name = value

@property
def module_command(self):
return self._module_command

@module_command.setter
def module_command(self, value):
self._module_command = value

@property
def timestamp(self):
return self._timestamp

@property
def data(self):
return self._data

@data.setter
def data(self, value):
self._data = value

@property
def arguments(self):
return self._arguments

@arguments.setter
def arguments(self, value):
self._arguments = value

def to_dict(self):
return {
"message_type": self.message_type,
"module_name": self.module_name,
"module_command": self.module_command,
"timestamp": self.timestamp,
"data": self.data,
"arguments": self.arguments
}

def to_json(self):
return self.to_dict()
12 changes: 10 additions & 2 deletions discord_tron_client/ws_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import ssl, websockets
from discord_tron_client.classes.app_config import AppConfig

async def websocket_client(config: AppConfig):
async def websocket_client(config: AppConfig, startup_sequence:str = None):
websocket_config = config.get_websocket_config()
logging.info(f"Retrieved websocket config: {websocket_config}")
hub_url = str(websocket_config["protocol"]) + "://" + str(websocket_config["host"]) + ":" + str(websocket_config["port"])
Expand All @@ -19,5 +19,13 @@ async def websocket_client(config: AppConfig):
}
logging.info(f"Connecting to {hub_url}...")
async with websockets.connect(hub_url, ssl=ssl_context, extra_headers=headers) as websocket:
# Send the startup sequence
if startup_sequence:
logging.info(f"Sending startup sequence {startup_sequence}")
for message in startup_sequence:
logging.debug(f"Sending startup sequence message: {message}")
await websocket.send(json.dumps(message))
else:
logging.info("No startup sequence found.")
async for message in websocket:
print(f"Received message: {message}")
logging.debug(f"Received message: {message}")

0 comments on commit 5f3eca3

Please sign in to comment.