Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/SK-697 | Adds token auth to the APIClient #533

Merged
merged 8 commits into from
Mar 1, 2024
Merged
64 changes: 39 additions & 25 deletions fedn/fedn/network/api/client.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
import os

import requests

Expand All @@ -18,26 +19,39 @@ class APIClient:
:type verify: bool
"""

def __init__(self, host, port, secure=False, verify=False):
def __init__(self, host, port=None, secure=False, verify=False, token=None, auth_scheme=None):
self.host = host
self.port = port
self.secure = secure
self.verify = verify
self.header = {}
# Auth scheme passed as argument overrides environment variable.
# "Token" is the default auth scheme.
if not auth_scheme:
auth_scheme = os.environ.get("FEDN_AUTH_SCHEME", "Token")
# Override potential env variable if token is passed as argument.
if not token:
token = os.environ.get("FEDN_AUTH_TOKEN", False)

if token:
self.header = {"Authorization": f"{auth_scheme} {token}"}

def _get_url(self, endpoint):
if self.secure:
protocol = 'https'
else:
protocol = 'http'
return f'{protocol}://{self.host}:{self.port}/{endpoint}'
if self.port:
return f'{protocol}://{self.host}:{self.port}/{endpoint}'
return f'{protocol}://{self.host}/{endpoint}'

def get_model_trail(self):
""" Get the model trail.

:return: The model trail as dict including commit timestamp.
:rtype: dict
"""
response = requests.get(self._get_url('get_model_trail'), verify=self.verify)
response = requests.get(self._get_url('get_model_trail'), verify=self.verify, headers=self.header)
return response.json()

def list_models(self, session_id=None):
Expand All @@ -46,7 +60,7 @@ def list_models(self, session_id=None):
:return: All models.
:rtype: dict
"""
response = requests.get(self._get_url('list_models'), params={'session_id': session_id}, verify=self.verify)
response = requests.get(self._get_url('list_models'), params={'session_id': session_id}, verify=self.verify, headers=self.header)
return response.json()

def list_clients(self):
Expand All @@ -55,7 +69,7 @@ def list_clients(self):
return: All clients.
rtype: dict
"""
response = requests.get(self._get_url('list_clients'))
response = requests.get(self._get_url('list_clients'), verify=self.verify, headers=self.header)
return response.json()

def get_active_clients(self, combiner_id):
Expand All @@ -66,7 +80,7 @@ def get_active_clients(self, combiner_id):
:return: All active clients.
:rtype: dict
"""
response = requests.get(self._get_url('get_active_clients'), params={'combiner': combiner_id}, verify=self.verify)
response = requests.get(self._get_url('get_active_clients'), params={'combiner': combiner_id}, verify=self.verify, headers=self.header)
return response.json()

def get_client_config(self, checksum=True):
Expand All @@ -78,7 +92,7 @@ def get_client_config(self, checksum=True):
:return: The client configuration.
:rtype: dict
"""
response = requests.get(self._get_url('get_client_config'), params={'checksum': checksum}, verify=self.verify)
response = requests.get(self._get_url('get_client_config'), params={'checksum': checksum}, verify=self.verify, headers=self.header)
return response.json()

def list_combiners(self):
Expand All @@ -87,7 +101,7 @@ def list_combiners(self):
:return: All combiners with info.
:rtype: dict
"""
response = requests.get(self._get_url('list_combiners'))
response = requests.get(self._get_url('list_combiners'), verify=self.verify, headers=self.header)
return response.json()

def get_combiner(self, combiner_id):
Expand All @@ -98,7 +112,7 @@ def get_combiner(self, combiner_id):
:return: The combiner info.
:rtype: dict
"""
response = requests.get(self._get_url(f'get_combiner?combiner={combiner_id}'), verify=self.verify)
response = requests.get(self._get_url(f'get_combiner?combiner={combiner_id}'), verify=self.verify, headers=self.header)
return response.json()

def list_rounds(self):
Expand All @@ -107,7 +121,7 @@ def list_rounds(self):
:return: All rounds with config and metrics.
:rtype: dict
"""
response = requests.get(self._get_url('list_rounds'))
response = requests.get(self._get_url('list_rounds'), verify=self.verify, headers=self.header)
return response.json()

def get_round(self, round_id):
Expand All @@ -118,7 +132,7 @@ def get_round(self, round_id):
:return: The round config and metrics.
:rtype: dict
"""
response = requests.get(self._get_url(f'get_round?round_id={round_id}'), verify=self.verify)
response = requests.get(self._get_url(f'get_round?round_id={round_id}'), verify=self.verify, headers=self.header)
return response.json()

def start_session(self, session_id=None, aggregator='fedavg', model_id=None, round_timeout=180, rounds=5, round_buffer_size=-1, delete_models=True,
Expand Down Expand Up @@ -162,7 +176,7 @@ def start_session(self, session_id=None, aggregator='fedavg', model_id=None, rou
'helper': helper,
'min_clients': min_clients,
'requested_clients': requested_clients
}, verify=self.verify
}, verify=self.verify, headers=self.header
)
return response.json()

Expand All @@ -172,7 +186,7 @@ def list_sessions(self):
:return: All sessions in dict.
:rtype: dict
"""
response = requests.get(self._get_url('list_sessions'), verify=self.verify)
response = requests.get(self._get_url('list_sessions'), verify=self.verify, headers=self.header)
return response.json()

def get_session(self, session_id):
Expand All @@ -183,7 +197,7 @@ def get_session(self, session_id):
:return: The session as a json object.
:rtype: dict
"""
response = requests.get(self._get_url(f'get_session?session_id={session_id}'), self.verify)
response = requests.get(self._get_url(f'get_session?session_id={session_id}'), self.verify, headers=self.header)
return response.json()

def session_is_finished(self, session_id):
Expand Down Expand Up @@ -218,7 +232,7 @@ def set_package(self, path: str, helper: str, name: str = None, description: str
"""
with open(path, 'rb') as file:
response = requests.post(self._get_url('set_package'), files={'file': file}, data={
'helper': helper, 'name': name, 'description': description}, verify=self.verify)
'helper': helper, 'name': name, 'description': description}, verify=self.verify, headers=self.header)
return response.json()

def get_package(self):
Expand All @@ -227,7 +241,7 @@ def get_package(self):
:return: The compute package with info.
:rtype: dict
"""
response = requests.get(self._get_url('get_package'), verify=self.verify)
response = requests.get(self._get_url('get_package'), verify=self.verify, headers=self.header)
return response.json()

def list_compute_packages(self):
Expand All @@ -236,7 +250,7 @@ def list_compute_packages(self):
:return: All compute packages with info.
:rtype: dict
"""
response = requests.get(self._get_url('list_compute_packages'), verify=self.verify)
response = requests.get(self._get_url('list_compute_packages'), verify=self.verify, headers=self.header)
return response.json()

def download_package(self, path):
Expand All @@ -247,7 +261,7 @@ def download_package(self, path):
:return: Message with success or failure.
:rtype: dict
"""
response = requests.get(self._get_url('download_package'), verify=self.verify)
response = requests.get(self._get_url('download_package'), verify=self.verify, headers=self.header)
if response.status_code == 200:
with open(path, 'wb') as file:
file.write(response.content)
Expand All @@ -261,7 +275,7 @@ def get_package_checksum(self):
:return: The checksum.
:rtype: dict
"""
response = requests.get(self._get_url('get_package_checksum'), verify=self.verify)
response = requests.get(self._get_url('get_package_checksum'), verify=self.verify, headers=self.header)
return response.json()

def get_latest_model(self):
Expand All @@ -270,7 +284,7 @@ def get_latest_model(self):
:return: The latest model id.
:rtype: dict
"""
response = requests.get(self._get_url('get_latest_model'), verify=self.verify)
response = requests.get(self._get_url('get_latest_model'), verify=self.verify, headers=self.header)
return response.json()

def get_initial_model(self):
Expand All @@ -279,7 +293,7 @@ def get_initial_model(self):
:return: The initial model id.
:rtype: dict
"""
response = requests.get(self._get_url('get_initial_model'), verify=self.verify)
response = requests.get(self._get_url('get_initial_model'), verify=self.verify, headers=self.header)
return response.json()

def set_initial_model(self, path):
Expand All @@ -291,7 +305,7 @@ def set_initial_model(self, path):
:rtype: dict
"""
with open(path, 'rb') as file:
response = requests.post(self._get_url('set_initial_model'), files={'file': file}, verify=self.verify)
response = requests.post(self._get_url('set_initial_model'), files={'file': file}, verify=self.verify, headers=self.header)
return response.json()

def get_controller_status(self):
Expand All @@ -300,7 +314,7 @@ def get_controller_status(self):
:return: The status of the controller.
:rtype: dict
"""
response = requests.get(self._get_url('get_controller_status'), verify=self.verify)
response = requests.get(self._get_url('get_controller_status'), verify=self.verify, headers=self.header)
return response.json()

def get_events(self, **kwargs):
Expand All @@ -309,7 +323,7 @@ def get_events(self, **kwargs):
:return: The events in dict
:rtype: dict
"""
response = requests.get(self._get_url('get_events'), params=kwargs, verify=self.verify)
response = requests.get(self._get_url('get_events'), params=kwargs, verify=self.verify, headers=self.header)
return response.json()

def list_validations(self, **kwargs):
Expand All @@ -318,5 +332,5 @@ def list_validations(self, **kwargs):
:return: All validations in dict.
:rtype: dict
"""
response = requests.get(self._get_url('list_validations'), params=kwargs, verify=self.verify)
response = requests.get(self._get_url('list_validations'), params=kwargs, verify=self.verify, headers=self.header)
return response.json()
Loading