From 250971c853356ba779ff6731365610f1ad346f1c Mon Sep 17 00:00:00 2001 From: Fredrik Wrede Date: Mon, 11 Mar 2024 16:28:26 +0100 Subject: [PATCH 01/16] add jwt decorators and roles --- fedn/fedn/network/api/auth.py | 63 +++++++++++++++++++ fedn/fedn/network/api/client.py | 2 +- fedn/fedn/network/api/server.py | 34 ++++++++++ fedn/fedn/network/api/v1/client_routes.py | 6 ++ fedn/fedn/network/api/v1/combiner_routes.py | 6 ++ fedn/fedn/network/api/v1/model_routes.py | 8 +++ fedn/fedn/network/api/v1/package_routes.py | 7 +++ fedn/fedn/network/api/v1/round_routes.py | 6 ++ fedn/fedn/network/api/v1/session_routes.py | 6 ++ fedn/fedn/network/api/v1/status_routes.py | 6 ++ fedn/fedn/network/api/v1/validation_routes.py | 6 ++ fedn/fedn/network/combiner/connect.py | 7 ++- fedn/fedn/tests/__init__.py | 0 13 files changed, 155 insertions(+), 2 deletions(-) create mode 100644 fedn/fedn/network/api/auth.py delete mode 100644 fedn/fedn/tests/__init__.py diff --git a/fedn/fedn/network/api/auth.py b/fedn/fedn/network/api/auth.py new file mode 100644 index 000000000..dfbb109b2 --- /dev/null +++ b/fedn/fedn/network/api/auth.py @@ -0,0 +1,63 @@ +import os +from functools import wraps + +import jwt +from flask import jsonify, request + +# Define your secret key for JWT +SECRET_KEY = os.environ.get('FEDN_JWT_SECRET_KEY', False) +FEDN_JWT_CUSTOM_CLAIM_KEY = os.environ.get('FEDN_JWT_CUSTOM_CLAIM_KEY', False) +FEDN_JWT_CUSTOM_CLAIM_VALUE = os.environ.get('FEDN_JWT_CUSTOM_CLAIM_VALUE', False) +FEDN_AUTH_SCHEME = os.environ.get('FEDN_AUTH_SCHEME', 'Bearer') + +# Fuction to check additional claims in the token +def check_role_claims(payload, role): + if FEDN_JWT_CUSTOM_CLAIM_KEY and FEDN_JWT_CUSTOM_CLAIM_VALUE: + if payload[FEDN_JWT_CUSTOM_CLAIM_KEY] != FEDN_JWT_CUSTOM_CLAIM_VALUE: + return False + if 'role' not in payload: + return False + if payload['role'] != role: + return False + + return True + +# Fuction to check additional cliams in the token +def check_custom_claims(payload): + if FEDN_JWT_CUSTOM_CLAIM_KEY and FEDN_JWT_CUSTOM_CLAIM_VALUE: + if payload[FEDN_JWT_CUSTOM_CLAIM_KEY] != FEDN_JWT_CUSTOM_CLAIM_VALUE: + return False + return True + +# Define the authentication decorator, with role as an argument +def jwt_auth_required(role=None): + def actual_decorator(func): + if not SECRET_KEY: + return func + @wraps(func) + def decorated(*args, **kwargs): + token = request.headers.get('Authorization') + # Get token from the header Bearer + if token and token.startswith(FEDN_AUTH_SCHEME): + token = token.split(' ')[1] + + if not token: + return jsonify({'message': 'Missing token'}), 401 + + try: + payload = jwt.decode(token, SECRET_KEY, algorithms=['HS256']) + if not check_role_claims(payload, role): + return jsonify({'message': 'Invalid token'}), 401 + if not check_custom_claims(payload): + return jsonify({'message': 'Invalid token'}), 401 + + except jwt.ExpiredSignatureError: + return jsonify({'message': 'Token expired'}), 401 + + except jwt.InvalidTokenError: + return jsonify({'message': 'Invalid token'}), 401 + + return func(*args, **kwargs) + + return decorated + return actual_decorator \ No newline at end of file diff --git a/fedn/fedn/network/api/client.py b/fedn/fedn/network/api/client.py index 00532d2da..cf797b53a 100644 --- a/fedn/fedn/network/api/client.py +++ b/fedn/fedn/network/api/client.py @@ -28,7 +28,7 @@ def __init__(self, host, port=None, secure=False, verify=False, token=None, auth # Auth scheme passed as argument overrides environment variable. # "Token" is the default auth scheme. if not auth_scheme: - auth_scheme = os.environ.get("FEDN_AUTH_SCHEME", "Token") + auth_scheme = os.environ.get("FEDN_AUTH_SCHEME", "Bearer") # Override potential env variable if token is passed as argument. if not token: token = os.environ.get("FEDN_AUTH_TOKEN", False) diff --git a/fedn/fedn/network/api/server.py b/fedn/fedn/network/api/server.py index f988befca..f99f28aff 100644 --- a/fedn/fedn/network/api/server.py +++ b/fedn/fedn/network/api/server.py @@ -3,6 +3,7 @@ from fedn.common.config import (get_controller_config, get_modelstorage_config, get_network_config, get_statestore_config) +from fedn.network.api.auth import jwt_auth_required from fedn.network.api.interface import API from fedn.network.api.v1.client_routes import bp as client_bp from fedn.network.api.v1.combiner_routes import bp as combiner_bp @@ -45,6 +46,7 @@ @app.route("/get_model_trail", methods=["GET"]) +@jwt_auth_required(role="admin") def get_model_trail(): """Get the model trail for a given session. param: session: The session id to get the model trail for. @@ -56,6 +58,7 @@ def get_model_trail(): @app.route("/get_model_ancestors", methods=["GET"]) +@jwt_auth_required(role="admin") def get_model_ancestors(): """Get the ancestors of a model. param: model: The model id to get the ancestors for. @@ -72,6 +75,7 @@ def get_model_ancestors(): @app.route("/get_model_descendants", methods=["GET"]) +@jwt_auth_required(role="admin") def get_model_descendants(): """Get the ancestors of a model. param: model: The model id to get the child for. @@ -88,6 +92,7 @@ def get_model_descendants(): @app.route("/list_models", methods=["GET"]) +@jwt_auth_required(role="admin") def list_models(): """Get models from the statestore. param: @@ -109,6 +114,7 @@ def list_models(): @app.route("/get_model", methods=["GET"]) +@jwt_auth_required(role="admin") def get_model(): """Get a model from the statestore. param: model: The model id to get. @@ -124,6 +130,7 @@ def get_model(): @app.route("/delete_model_trail", methods=["GET", "POST"]) +@jwt_auth_required(role="admin") def delete_model_trail(): """Delete the model trail for a given session. param: session: The session id to delete the model trail for. @@ -135,6 +142,7 @@ def delete_model_trail(): @app.route("/list_clients", methods=["GET"]) +@jwt_auth_required(role="admin") def list_clients(): """Get all clients from the statestore. return: All clients as a json object. @@ -149,6 +157,7 @@ def list_clients(): @app.route("/get_active_clients", methods=["GET"]) +@jwt_auth_required(role="admin") def get_active_clients(): """Get all active clients from the statestore. param: combiner_id: The combiner id to get active clients for. @@ -166,6 +175,7 @@ def get_active_clients(): @app.route("/list_combiners", methods=["GET"]) +@jwt_auth_required(role="admin") def list_combiners(): """Get all combiners in the network. return: All combiners as a json object. @@ -179,6 +189,7 @@ def list_combiners(): @app.route("/get_combiner", methods=["GET"]) +@jwt_auth_required(role="admin") def get_combiner(): """Get a combiner from the statestore. param: combiner_id: The combiner id to get. @@ -196,6 +207,7 @@ def get_combiner(): @app.route("/list_rounds", methods=["GET"]) +@jwt_auth_required(role="admin") def list_rounds(): """Get all rounds from the statestore. return: All rounds as a json object. @@ -205,6 +217,7 @@ def list_rounds(): @app.route("/get_round", methods=["GET"]) +@jwt_auth_required(role="admin") def get_round(): """Get a round from the statestore. param: round_id: The round id to get. @@ -219,6 +232,7 @@ def get_round(): @app.route("/start_session", methods=["GET", "POST"]) +@jwt_auth_required(role="admin") def start_session(): """Start a new session. return: The response from control. @@ -229,6 +243,7 @@ def start_session(): @app.route("/list_sessions", methods=["GET"]) +@jwt_auth_required(role="admin") def list_sessions(): """Get all sessions from the statestore. return: All sessions as a json object. @@ -241,6 +256,7 @@ def list_sessions(): @app.route("/get_session", methods=["GET"]) +@jwt_auth_required(role="admin") def get_session(): """Get a session from the statestore. param: session_id: The session id to get. @@ -258,12 +274,14 @@ def get_session(): @app.route("/set_active_package", methods=["PUT"]) +@jwt_auth_required(role="admin") def set_active_package(): id = request.args.get("id", None) return api.set_active_compute_package(id) @app.route("/set_package", methods=["POST"]) +@jwt_auth_required(role="admin") def set_package(): """ Set the compute package in the statestore. Usage with curl: @@ -296,6 +314,7 @@ def set_package(): @app.route("/get_package", methods=["GET"]) +@jwt_auth_required(role="admin") def get_package(): """Get the compute package from the statestore. return: The compute package as a json object. @@ -305,6 +324,7 @@ def get_package(): @app.route("/list_compute_packages", methods=["GET"]) +@jwt_auth_required(role="admin") def list_compute_packages(): """Get the compute package from the statestore. return: The compute package as a json object. @@ -321,6 +341,7 @@ def list_compute_packages(): @app.route("/download_package", methods=["GET"]) +@jwt_auth_required(role="client") def download_package(): """Download the compute package. return: The compute package as a json object. @@ -331,12 +352,14 @@ def download_package(): @app.route("/get_package_checksum", methods=["GET"]) +@jwt_auth_required(role="admin") def get_package_checksum(): name = request.args.get("name", None) return api.get_checksum(name) @app.route("/get_latest_model", methods=["GET"]) +@jwt_auth_required(role="admin") def get_latest_model(): """Get the latest model from the statestore. return: The initial model as a json object. @@ -346,6 +369,7 @@ def get_latest_model(): @app.route("/set_current_model", methods=["PUT"]) +@jwt_auth_required(role="admin") def set_current_model(): """Set the initial model in the statestore and upload to model repository. Usage with curl: @@ -368,6 +392,7 @@ def set_current_model(): @app.route("/get_initial_model", methods=["GET"]) +@jwt_auth_required(role="admin") def get_initial_model(): """Get the initial model from the statestore. return: The initial model as a json object. @@ -377,6 +402,7 @@ def get_initial_model(): @app.route("/set_initial_model", methods=["POST"]) +@jwt_auth_required(role="admin") def set_initial_model(): """Set the initial model in the statestore and upload to model repository. Usage with curl: @@ -397,6 +423,7 @@ def set_initial_model(): @app.route("/get_controller_status", methods=["GET"]) +@jwt_auth_required(role="admin") def get_controller_status(): """Get the status of the controller. return: The status as a json object. @@ -406,6 +433,7 @@ def get_controller_status(): @app.route("/get_client_config", methods=["GET"]) +@jwt_auth_required(role="admin") def get_client_config(): """Get the client configuration. return: The client configuration as a json object. @@ -416,6 +444,7 @@ def get_client_config(): @app.route("/get_events", methods=["GET"]) +@jwt_auth_required(role="admin") def get_events(): """Get the events from the statestore. return: The events as a json object. @@ -428,6 +457,7 @@ def get_events(): @app.route("/list_validations", methods=["GET"]) +@jwt_auth_required(role="admin") def list_validations(): """Get all validations from the statestore. return: All validations as a json object. @@ -439,6 +469,7 @@ def list_validations(): @app.route("/add_combiner", methods=["POST"]) +@jwt_auth_required(role="combiner") def add_combiner(): """Add a combiner to the network. return: The response from the statestore. @@ -454,6 +485,7 @@ def add_combiner(): @app.route("/add_client", methods=["POST"]) +@jwt_auth_required(role="client") def add_client(): """Add a client to the network. return: The response from control. @@ -470,6 +502,7 @@ def add_client(): @app.route("/list_combiners_data", methods=["POST"]) +@jwt_auth_required(role="admin") def list_combiners_data(): """List data from combiners. return: The response from control. @@ -489,6 +522,7 @@ def list_combiners_data(): @app.route("/get_plot_data", methods=["GET"]) +@jwt_auth_required(role="admin") def get_plot_data(): """Get plot data from the statestore. rtype: json diff --git a/fedn/fedn/network/api/v1/client_routes.py b/fedn/fedn/network/api/v1/client_routes.py index c4215bd54..30322a9b7 100644 --- a/fedn/fedn/network/api/v1/client_routes.py +++ b/fedn/fedn/network/api/v1/client_routes.py @@ -1,5 +1,6 @@ from flask import Blueprint, jsonify, request +from fedn.network.api.auth import jwt_auth_required from fedn.network.api.v1.shared import (api_version, get_post_data_to_kwargs, get_typed_list_headers, mdb) from fedn.network.storage.statestore.stores.client_store import ClientStore @@ -11,6 +12,7 @@ @bp.route("/", methods=["GET"]) +@jwt_auth_required(role="admin") def get_clients(): """Get clients Retrieves a list of clients based on the provided parameters. @@ -127,6 +129,7 @@ def get_clients(): @bp.route("/list", methods=["POST"]) +@jwt_auth_required(role="admin") def list_clients(): """List clients Retrieves a list of clients based on the provided parameters. @@ -213,6 +216,7 @@ def list_clients(): @bp.route("/count", methods=["GET"]) +@jwt_auth_required(role="admin") def get_clients_count(): """Clients count Retrieves the total number of clients based on the provided parameters. @@ -273,6 +277,7 @@ def get_clients_count(): @bp.route("/count", methods=["POST"]) +@jwt_auth_required(role="admin") def clients_count(): """Clients count Retrieves the total number of clients based on the provided parameters. @@ -325,6 +330,7 @@ def clients_count(): @bp.route("/", methods=["GET"]) +@jwt_auth_required(role="admin") def get_client(id: str): """Get client Retrieves a client based on the provided id. diff --git a/fedn/fedn/network/api/v1/combiner_routes.py b/fedn/fedn/network/api/v1/combiner_routes.py index ba6bf5dbd..7d1761bee 100644 --- a/fedn/fedn/network/api/v1/combiner_routes.py +++ b/fedn/fedn/network/api/v1/combiner_routes.py @@ -1,5 +1,6 @@ from flask import Blueprint, jsonify, request +from fedn.network.api.auth import jwt_auth_required from fedn.network.api.v1.shared import (api_version, get_post_data_to_kwargs, get_typed_list_headers, mdb) from fedn.network.storage.statestore.stores.combiner_store import CombinerStore @@ -11,6 +12,7 @@ @bp.route("/", methods=["GET"]) +@jwt_auth_required(role="admin") def get_combiners(): """Get combiners Retrieves a list of combiners based on the provided parameters. @@ -119,6 +121,7 @@ def get_combiners(): @bp.route("/list", methods=["POST"]) +@jwt_auth_required(role="admin") def list_combiners(): """List combiners Retrieves a list of combiners based on the provided parameters. @@ -203,6 +206,7 @@ def list_combiners(): @bp.route("/count", methods=["GET"]) +@jwt_auth_required(role="admin") def get_combiners_count(): """Combiners count Retrieves the count of combiners based on the provided parameters. @@ -249,6 +253,7 @@ def get_combiners_count(): @bp.route("/count", methods=["POST"]) +@jwt_auth_required(role="admin") def combiners_count(): """Combiners count Retrieves the count of combiners based on the provided parameters. @@ -297,6 +302,7 @@ def combiners_count(): @bp.route("/", methods=["GET"]) +@jwt_auth_required(role="admin") def get_combiner(id: str): """Get combiner Retrieves a combiner based on the provided id. diff --git a/fedn/fedn/network/api/v1/model_routes.py b/fedn/fedn/network/api/v1/model_routes.py index e531bcb3a..7572cb01c 100644 --- a/fedn/fedn/network/api/v1/model_routes.py +++ b/fedn/fedn/network/api/v1/model_routes.py @@ -1,5 +1,6 @@ from flask import Blueprint, jsonify, request +from fedn.network.api.auth import jwt_auth_required from fedn.network.api.v1.shared import (api_version, get_limit, get_post_data_to_kwargs, get_typed_list_headers, mdb) @@ -12,6 +13,7 @@ @bp.route("/", methods=["GET"]) +@jwt_auth_required(role="admin") def get_models(): """Get models Retrieves a list of models based on the provided parameters. @@ -114,6 +116,7 @@ def get_models(): @bp.route("/list", methods=["POST"]) +@jwt_auth_required(role="admin") def list_models(): """List models Retrieves a list of models based on the provided parameters. @@ -200,6 +203,7 @@ def list_models(): @bp.route("/count", methods=["GET"]) +@jwt_auth_required(role="admin") def get_models_count(): """Models count Retrieves the count of models based on the provided parameters. @@ -247,6 +251,7 @@ def get_models_count(): @bp.route("/count", methods=["POST"]) +@jwt_auth_required(role="admin") def models_count(): """Models count Retrieves the count of models based on the provided parameters. @@ -298,6 +303,7 @@ def models_count(): @bp.route("/", methods=["GET"]) +@jwt_auth_required(role="admin") def get_model(id: str): """Get model Retrieves a model based on the provided id. @@ -343,6 +349,7 @@ def get_model(id: str): @bp.route("//descendants", methods=["GET"]) +@jwt_auth_required(role="admin") def get_descendants(id: str): """Get model descendants Retrieves a list of model descendants of the provided model id/model property. @@ -396,6 +403,7 @@ def get_descendants(id: str): @bp.route("//ancestors", methods=["GET"]) +@jwt_auth_required(role="admin") def get_ancestors(id: str): """Get model ancestors Retrieves a list of model ancestors of the provided model id/model property. diff --git a/fedn/fedn/network/api/v1/package_routes.py b/fedn/fedn/network/api/v1/package_routes.py index b62aa96a9..30ac4d51e 100644 --- a/fedn/fedn/network/api/v1/package_routes.py +++ b/fedn/fedn/network/api/v1/package_routes.py @@ -1,5 +1,6 @@ from flask import Blueprint, jsonify, request +from fedn.network.api.auth import jwt_auth_required from fedn.network.api.v1.shared import (api_version, get_post_data_to_kwargs, get_typed_list_headers, get_use_typing, mdb) @@ -12,6 +13,7 @@ @bp.route("/", methods=["GET"]) +@jwt_auth_required(role="admin") def get_packages(): """Get packages Retrieves a list of packages based on the provided parameters. @@ -132,6 +134,7 @@ def get_packages(): @bp.route("/list", methods=["POST"]) +@jwt_auth_required(role="admin") def list_packages(): """List packages Retrieves a list of packages based on the provided parameters. @@ -221,6 +224,7 @@ def list_packages(): @bp.route("/count", methods=["GET"]) +@jwt_auth_required(role="admin") def get_packages_count(): """Package count Retrieves the count of packages based on the provided parameters. @@ -281,6 +285,7 @@ def get_packages_count(): @bp.route("/count", methods=["POST"]) +@jwt_auth_required(role="admin") def packages_count(): """Package count Retrieves the count of packages based on the provided parameters. @@ -342,6 +347,7 @@ def packages_count(): @bp.route("/", methods=["GET"]) +@jwt_auth_required(role="admin") def get_package(id: str): """Get package Retrieves a package based on the provided id. @@ -388,6 +394,7 @@ def get_package(id: str): @bp.route("/active", methods=["GET"]) +@jwt_auth_required(role="admin") def get_active_package(): """Get active package Retrieves the active package diff --git a/fedn/fedn/network/api/v1/round_routes.py b/fedn/fedn/network/api/v1/round_routes.py index 317c767ee..8890c510a 100644 --- a/fedn/fedn/network/api/v1/round_routes.py +++ b/fedn/fedn/network/api/v1/round_routes.py @@ -1,5 +1,6 @@ from flask import Blueprint, jsonify, request +from fedn.network.api.auth import jwt_auth_required from fedn.network.api.v1.shared import (api_version, get_post_data_to_kwargs, get_typed_list_headers, mdb) from fedn.network.storage.statestore.stores.round_store import RoundStore @@ -11,6 +12,7 @@ @bp.route("/", methods=["GET"]) +@jwt_auth_required(role="admin") def get_rounds(): """Get rounds Retrieves a list of rounds based on the provided parameters. @@ -107,6 +109,7 @@ def get_rounds(): @bp.route("/list", methods=["POST"]) +@jwt_auth_required(role="admin") def list_rounds(): """List rounds Retrieves a list of rounds based on the provided parameters. @@ -187,6 +190,7 @@ def list_rounds(): @bp.route("/count", methods=["GET"]) +@jwt_auth_required(role="admin") def get_rounds_count(): """Rounds count Retrieves the count of rounds based on the provided parameters. @@ -227,6 +231,7 @@ def get_rounds_count(): @bp.route("/count", methods=["POST"]) +@jwt_auth_required(role="admin") def rounds_count(): """Rounds count Retrieves the count of rounds based on the provided parameters. @@ -271,6 +276,7 @@ def rounds_count(): @bp.route("/", methods=["GET"]) +@jwt_auth_required(role="admin") def get_round(id: str): """Get round Retrieves a round based on the provided id. diff --git a/fedn/fedn/network/api/v1/session_routes.py b/fedn/fedn/network/api/v1/session_routes.py index 4d3fe493f..99c52d8db 100644 --- a/fedn/fedn/network/api/v1/session_routes.py +++ b/fedn/fedn/network/api/v1/session_routes.py @@ -1,5 +1,6 @@ from flask import Blueprint, jsonify, request +from fedn.network.api.auth import jwt_auth_required from fedn.network.api.v1.shared import (api_version, get_post_data_to_kwargs, get_typed_list_headers, mdb) from fedn.network.storage.statestore.stores.session_store import SessionStore @@ -11,6 +12,7 @@ @bp.route("/", methods=["GET"]) +@jwt_auth_required(role="admin") def get_sessions(): """Get sessions Retrieves a list of sessions based on the provided parameters. @@ -99,6 +101,7 @@ def get_sessions(): @bp.route("/list", methods=["POST"]) +@jwt_auth_required(role="admin") def list_sessions(): """List sessions Retrieves a list of sessions based on the provided parameters. @@ -178,6 +181,7 @@ def list_sessions(): @bp.route("/count", methods=["GET"]) +@jwt_auth_required(role="admin") def get_sessions_count(): """Sessions count Retrieves the count of sessions based on the provided parameters. @@ -218,6 +222,7 @@ def get_sessions_count(): @bp.route("/count", methods=["POST"]) +@jwt_auth_required(role="admin") def sessions_count(): """Sessions count Retrieves the count of sessions based on the provided parameters. @@ -262,6 +267,7 @@ def sessions_count(): @bp.route("/", methods=["GET"]) +@jwt_auth_required(role="admin") def get_session(id: str): """Get session Retrieves a session based on the provided id. diff --git a/fedn/fedn/network/api/v1/status_routes.py b/fedn/fedn/network/api/v1/status_routes.py index 562b971db..e78c18533 100644 --- a/fedn/fedn/network/api/v1/status_routes.py +++ b/fedn/fedn/network/api/v1/status_routes.py @@ -1,5 +1,6 @@ from flask import Blueprint, jsonify, request +from fedn.network.api.auth import jwt_auth_required from fedn.network.api.v1.shared import (api_version, get_post_data_to_kwargs, get_typed_list_headers, get_use_typing, mdb) @@ -12,6 +13,7 @@ @bp.route("/", methods=["GET"]) +@jwt_auth_required(role="admin") def get_statuses(): """Get statuses Retrieves a list of statuses based on the provided parameters. @@ -144,6 +146,7 @@ def get_statuses(): @bp.route("/list", methods=["POST"]) +@jwt_auth_required(role="admin") def list_statuses(): """Get statuses Retrieves a list of statuses based on the provided parameters. @@ -246,6 +249,7 @@ def list_statuses(): @bp.route("/count", methods=["GET"]) +@jwt_auth_required(role="admin") def get_statuses_count(): """Statuses count Retrieves the count of statuses based on the provided parameters. @@ -307,6 +311,7 @@ def get_statuses_count(): @bp.route("/count", methods=["POST"]) +@jwt_auth_required(role="admin") def statuses_count(): """Statuses count Retrieves the count of statuses based on the provided parameters. @@ -368,6 +373,7 @@ def statuses_count(): @bp.route("/", methods=["GET"]) +@jwt_auth_required(role="admin") def get_status(id: str): """Get status Retrieves a status based on the provided id. diff --git a/fedn/fedn/network/api/v1/validation_routes.py b/fedn/fedn/network/api/v1/validation_routes.py index 874154dbc..96fbac55c 100644 --- a/fedn/fedn/network/api/v1/validation_routes.py +++ b/fedn/fedn/network/api/v1/validation_routes.py @@ -1,5 +1,6 @@ from flask import Blueprint, jsonify, request +from fedn.network.api.auth import jwt_auth_required from fedn.network.api.v1.shared import (api_version, get_post_data_to_kwargs, get_typed_list_headers, get_use_typing, mdb) @@ -13,6 +14,7 @@ @bp.route("/", methods=["GET"]) +@jwt_auth_required(role="admin") def get_validations(): """Get validations Retrieves a list of validations based on the provided parameters. @@ -152,6 +154,7 @@ def get_validations(): @bp.route("/list", methods=["POST"]) +@jwt_auth_required(role="admin") def list_validations(): """Get validations Retrieves a list of validations based on the provided parameters. @@ -257,6 +260,7 @@ def list_validations(): @bp.route("/count", methods=["GET"]) +@jwt_auth_required(role="admin") def get_validations_count(): """Validations count Retrieves the count of validations based on the provided parameters. @@ -322,6 +326,7 @@ def get_validations_count(): @bp.route("/count", methods=["POST"]) +@jwt_auth_required(role="admin") def validations_count(): """Validations count Retrieves the count of validations based on the provided parameters. @@ -386,6 +391,7 @@ def validations_count(): @bp.route("/", methods=["GET"]) +@jwt_auth_required(role="admin") def get_validation(id: str): """Get validation Retrieves a validation based on the provided id. diff --git a/fedn/fedn/network/combiner/connect.py b/fedn/fedn/network/combiner/connect.py index 4c1c94266..a0b2d1803 100644 --- a/fedn/fedn/network/combiner/connect.py +++ b/fedn/fedn/network/combiner/connect.py @@ -5,6 +5,7 @@ # # import enum +import os import requests @@ -72,10 +73,14 @@ def __init__(self, host, port, myhost, fqdn, myport, token, name, secure=False, self.myhost = myhost self.myport = myport self.token = token + self.token_scheme = os.environ.get('FEDN_AUTH_SCHEME', 'Bearer') self.name = name self.secure = secure self.verify = verify + if not self.token: + self.token = os.environ.get('FEDN_AUTH_TOKEN', None) + # for https we assume a an ingress handles permanent redirect (308) self.prefix = "http://" if port: @@ -104,7 +109,7 @@ def announce(self): try: retval = requests.post(self.connect_string + '/add_combiner', json=payload, verify=self.verify, - headers={'Authorization': 'Token {}'.format(self.token)}) + headers={'Authorization': f'{self.token_scheme} {self.token}'}) except Exception: return Status.Unassigned, {} diff --git a/fedn/fedn/tests/__init__.py b/fedn/fedn/tests/__init__.py deleted file mode 100644 index e69de29bb..000000000 From 296ecb4b252a75d30ecb8b5af7064968da244dbe Mon Sep 17 00:00:00 2001 From: Fredrik Wrede Date: Mon, 11 Mar 2024 15:41:05 +0000 Subject: [PATCH 02/16] fix --- fedn/fedn/network/api/auth.py | 11 ++++++----- fedn/fedn/network/api/server.py | 12 ++++++------ 2 files changed, 12 insertions(+), 11 deletions(-) diff --git a/fedn/fedn/network/api/auth.py b/fedn/fedn/network/api/auth.py index dfbb109b2..a5910c64a 100644 --- a/fedn/fedn/network/api/auth.py +++ b/fedn/fedn/network/api/auth.py @@ -10,7 +10,7 @@ FEDN_JWT_CUSTOM_CLAIM_VALUE = os.environ.get('FEDN_JWT_CUSTOM_CLAIM_VALUE', False) FEDN_AUTH_SCHEME = os.environ.get('FEDN_AUTH_SCHEME', 'Bearer') -# Fuction to check additional claims in the token + def check_role_claims(payload, role): if FEDN_JWT_CUSTOM_CLAIM_KEY and FEDN_JWT_CUSTOM_CLAIM_VALUE: if payload[FEDN_JWT_CUSTOM_CLAIM_KEY] != FEDN_JWT_CUSTOM_CLAIM_VALUE: @@ -19,21 +19,22 @@ def check_role_claims(payload, role): return False if payload['role'] != role: return False - + return True -# Fuction to check additional cliams in the token + def check_custom_claims(payload): if FEDN_JWT_CUSTOM_CLAIM_KEY and FEDN_JWT_CUSTOM_CLAIM_VALUE: if payload[FEDN_JWT_CUSTOM_CLAIM_KEY] != FEDN_JWT_CUSTOM_CLAIM_VALUE: return False return True -# Define the authentication decorator, with role as an argument + def jwt_auth_required(role=None): def actual_decorator(func): if not SECRET_KEY: return func + @wraps(func) def decorated(*args, **kwargs): token = request.headers.get('Authorization') @@ -60,4 +61,4 @@ def decorated(*args, **kwargs): return func(*args, **kwargs) return decorated - return actual_decorator \ No newline at end of file + return actual_decorator diff --git a/fedn/fedn/network/api/server.py b/fedn/fedn/network/api/server.py index f99f28aff..fb5010911 100644 --- a/fedn/fedn/network/api/server.py +++ b/fedn/fedn/network/api/server.py @@ -34,12 +34,12 @@ app.register_blueprint(round_bp) template = { - "swagger": "2.0", - "info": { - "title": "FEDn API", - "description": "API for the FEDn network.", - "version": "0.0.1" - } + "swagger": "2.0", + "info": { + "title": "FEDn API", + "description": "API for the FEDn network.", + "version": "0.0.1" + } } swagger = Swagger(app, template=template) From f4157d427cfc3d9b20988cfb0c3690e0644c9380 Mon Sep 17 00:00:00 2001 From: Fredrik Wrede Date: Tue, 12 Mar 2024 14:01:21 +0000 Subject: [PATCH 03/16] add whitelist prefix url --- fedn/fedn/network/api/auth.py | 8 ++++++++ fedn/fedn/network/api/server.py | 26 ++++++++++---------------- fedn/fedn/network/api/v1/__init__.py | 10 ++++++++++ 3 files changed, 28 insertions(+), 16 deletions(-) diff --git a/fedn/fedn/network/api/auth.py b/fedn/fedn/network/api/auth.py index a5910c64a..d65794417 100644 --- a/fedn/fedn/network/api/auth.py +++ b/fedn/fedn/network/api/auth.py @@ -9,6 +9,7 @@ FEDN_JWT_CUSTOM_CLAIM_KEY = os.environ.get('FEDN_JWT_CUSTOM_CLAIM_KEY', False) FEDN_JWT_CUSTOM_CLAIM_VALUE = os.environ.get('FEDN_JWT_CUSTOM_CLAIM_VALUE', False) FEDN_AUTH_SCHEME = os.environ.get('FEDN_AUTH_SCHEME', 'Bearer') +FEDN_AUTH_WHITELIST_URL_PREFIX = os.environ.get('FEDN_AUTH_WHITELIST_URL_PREFIX', False) def check_role_claims(payload, role): @@ -29,6 +30,11 @@ def check_custom_claims(payload): return False return True +def if_whitelisted_url_prefix(path): + if FEDN_AUTH_WHITELIST_URL_PREFIX and path.startswith(FEDN_AUTH_WHITELIST_URL_PREFIX): + return True + else: + return False def jwt_auth_required(role=None): def actual_decorator(func): @@ -37,6 +43,8 @@ def actual_decorator(func): @wraps(func) def decorated(*args, **kwargs): + if if_whitelisted_url_prefix(request.path): + return func(*args, **kwargs) token = request.headers.get('Authorization') # Get token from the header Bearer if token and token.startswith(FEDN_AUTH_SCHEME): diff --git a/fedn/fedn/network/api/server.py b/fedn/fedn/network/api/server.py index fb5010911..cb70717ba 100644 --- a/fedn/fedn/network/api/server.py +++ b/fedn/fedn/network/api/server.py @@ -1,18 +1,12 @@ from flasgger import Swagger from flask import Flask, jsonify, request +import os from fedn.common.config import (get_controller_config, get_modelstorage_config, get_network_config, get_statestore_config) from fedn.network.api.auth import jwt_auth_required from fedn.network.api.interface import API -from fedn.network.api.v1.client_routes import bp as client_bp -from fedn.network.api.v1.combiner_routes import bp as combiner_bp -from fedn.network.api.v1.model_routes import bp as model_bp -from fedn.network.api.v1.package_routes import bp as package_bp -from fedn.network.api.v1.round_routes import bp as round_bp -from fedn.network.api.v1.session_routes import bp as session_bp -from fedn.network.api.v1.status_routes import bp as status_bp -from fedn.network.api.v1.validation_routes import bp as validation_bp +from fedn.network.api.v1 import _routes from fedn.network.controller.control import Control from fedn.network.storage.statestore.mongostatestore import MongoStateStore @@ -22,16 +16,16 @@ statestore = MongoStateStore(network_id, statestore_config["mongo_config"]) statestore.set_storage_backend(modelstorage_config) control = Control(statestore=statestore) + +custom_url_prefix = os.environ.get("FEDN_CUSTOM_URL_PREFIX", False) api = API(statestore, control) app = Flask(__name__) -app.register_blueprint(client_bp) -app.register_blueprint(status_bp) -app.register_blueprint(model_bp) -app.register_blueprint(validation_bp) -app.register_blueprint(package_bp) -app.register_blueprint(session_bp) -app.register_blueprint(combiner_bp) -app.register_blueprint(round_bp) +for bp in _routes: + app.register_blueprint(bp) + if custom_url_prefix: + app.register_blueprint(bp, + name=f"{bp.name}_custom", + url_prefix=f"{custom_url_prefix}{bp.url_prefix}") template = { "swagger": "2.0", diff --git a/fedn/fedn/network/api/v1/__init__.py b/fedn/fedn/network/api/v1/__init__.py index e69de29bb..11698cb14 100644 --- a/fedn/fedn/network/api/v1/__init__.py +++ b/fedn/fedn/network/api/v1/__init__.py @@ -0,0 +1,10 @@ +from fedn.network.api.v1.client_routes import bp as client_bp +from fedn.network.api.v1.combiner_routes import bp as combiner_bp +from fedn.network.api.v1.model_routes import bp as model_bp +from fedn.network.api.v1.package_routes import bp as package_bp +from fedn.network.api.v1.round_routes import bp as round_bp +from fedn.network.api.v1.session_routes import bp as session_bp +from fedn.network.api.v1.status_routes import bp as status_bp +from fedn.network.api.v1.validation_routes import bp as validation_bp + +_routes = [client_bp, combiner_bp, model_bp, package_bp, round_bp, session_bp, status_bp, validation_bp] \ No newline at end of file From 80f239dcc2cc7a696d673be756055c5296b53060 Mon Sep 17 00:00:00 2001 From: Fredrik Wrede Date: Tue, 12 Mar 2024 14:25:40 +0000 Subject: [PATCH 04/16] add prefix to old api --- fedn/fedn/network/api/server.py | 69 ++++++++++++++++++++++++++++++++- 1 file changed, 68 insertions(+), 1 deletion(-) diff --git a/fedn/fedn/network/api/server.py b/fedn/fedn/network/api/server.py index cb70717ba..a5c90804d 100644 --- a/fedn/fedn/network/api/server.py +++ b/fedn/fedn/network/api/server.py @@ -50,6 +50,9 @@ def get_model_trail(): """ return api.get_model_trail() +if custom_url_prefix: + app.add_url_rule(f"{custom_url_prefix}/get_model_trail", view_func=get_model_trail, methods=["GET"]) + @app.route("/get_model_ancestors", methods=["GET"]) @jwt_auth_required(role="admin") @@ -67,6 +70,8 @@ def get_model_ancestors(): return api.get_model_ancestors(model, limit) +if custom_url_prefix: + app.add_url_rule(f"{custom_url_prefix}/get_model_ancestors", view_func=get_model_ancestors, methods=["GET"]) @app.route("/get_model_descendants", methods=["GET"]) @jwt_auth_required(role="admin") @@ -84,6 +89,8 @@ def get_model_descendants(): return api.get_model_descendants(model, limit) +if custom_url_prefix: + app.add_url_rule(f"{custom_url_prefix}/get_model_descendants", view_func=get_model_descendants, methods=["GET"]) @app.route("/list_models", methods=["GET"]) @jwt_auth_required(role="admin") @@ -106,6 +113,8 @@ def list_models(): return api.get_models(session_id, limit, skip, include_active) +if custom_url_prefix: + app.add_url_rule(f"{custom_url_prefix}/list_models", view_func=list_models, methods=["GET"]) @app.route("/get_model", methods=["GET"]) @jwt_auth_required(role="admin") @@ -122,6 +131,8 @@ def get_model(): return api.get_model(model) +if custom_url_prefix: + app.add_url_rule(f"{custom_url_prefix}/get_model", view_func=get_model, methods=["GET"]) @app.route("/delete_model_trail", methods=["GET", "POST"]) @jwt_auth_required(role="admin") @@ -134,6 +145,8 @@ def delete_model_trail(): """ return jsonify({"message": "Not implemented"}), 501 +if custom_url_prefix: + app.add_url_rule(f"{custom_url_prefix}/delete_model_trail", view_func=delete_model_trail, methods=["GET", "POST"]) @app.route("/list_clients", methods=["GET"]) @jwt_auth_required(role="admin") @@ -149,6 +162,8 @@ def list_clients(): return api.get_clients(limit, skip, status) +if custom_url_prefix: + app.add_url_rule(f"{custom_url_prefix}/list_clients", view_func=list_clients, methods=["GET"]) @app.route("/get_active_clients", methods=["GET"]) @jwt_auth_required(role="admin") @@ -167,6 +182,8 @@ def get_active_clients(): ) return api.get_active_clients(combiner_id) +if custom_url_prefix: + app.add_url_rule(f"{custom_url_prefix}/get_active_clients", view_func=get_active_clients, methods=["GET"]) @app.route("/list_combiners", methods=["GET"]) @jwt_auth_required(role="admin") @@ -181,6 +198,8 @@ def list_combiners(): return api.get_all_combiners(limit, skip) +if custom_url_prefix: + app.add_url_rule(f"{custom_url_prefix}/list_combiners", view_func=list_combiners, methods=["GET"]) @app.route("/get_combiner", methods=["GET"]) @jwt_auth_required(role="admin") @@ -199,6 +218,8 @@ def get_combiner(): ) return api.get_combiner(combiner_id) +if custom_url_prefix: + app.add_url_rule(f"{custom_url_prefix}/get_combiner", view_func=get_combiner, methods=["GET"]) @app.route("/list_rounds", methods=["GET"]) @jwt_auth_required(role="admin") @@ -209,6 +230,8 @@ def list_rounds(): """ return api.get_all_rounds() +if custom_url_prefix: + app.add_url_rule(f"{custom_url_prefix}/list_rounds", view_func=list_rounds, methods=["GET"]) @app.route("/get_round", methods=["GET"]) @jwt_auth_required(role="admin") @@ -224,6 +247,8 @@ def get_round(): return jsonify({"success": False, "message": "Missing round id."}), 400 return api.get_round(round_id) +if custom_url_prefix: + app.add_url_rule(f"{custom_url_prefix}/get_round", view_func=get_round, methods=["GET"]) @app.route("/start_session", methods=["GET", "POST"]) @jwt_auth_required(role="admin") @@ -235,6 +260,8 @@ def start_session(): json_data = request.get_json() return api.start_session(**json_data) +if custom_url_prefix: + app.add_url_rule(f"{custom_url_prefix}/start_session", view_func=start_session, methods=["GET", "POST"]) @app.route("/list_sessions", methods=["GET"]) @jwt_auth_required(role="admin") @@ -248,6 +275,8 @@ def list_sessions(): return api.get_all_sessions(limit, skip) +if custom_url_prefix: + app.add_url_rule(f"{custom_url_prefix}/list_sessions", view_func=list_sessions, methods=["GET"]) @app.route("/get_session", methods=["GET"]) @jwt_auth_required(role="admin") @@ -266,6 +295,8 @@ def get_session(): ) return api.get_session(session_id) +if custom_url_prefix: + app.add_url_rule(f"{custom_url_prefix}/get_session", view_func=get_session, methods=["GET"]) @app.route("/set_active_package", methods=["PUT"]) @jwt_auth_required(role="admin") @@ -273,6 +304,8 @@ def set_active_package(): id = request.args.get("id", None) return api.set_active_compute_package(id) +if custom_url_prefix: + app.add_url_rule(f"{custom_url_prefix}/set_active_package", view_func=set_active_package, methods=["PUT"]) @app.route("/set_package", methods=["POST"]) @jwt_auth_required(role="admin") @@ -306,6 +339,8 @@ def set_package(): file=file, helper_type=helper_type, name=name, description=description ) +if custom_url_prefix: + app.add_url_rule(f"{custom_url_prefix}/set_package", view_func=set_package, methods=["POST"]) @app.route("/get_package", methods=["GET"]) @jwt_auth_required(role="admin") @@ -316,6 +351,8 @@ def get_package(): """ return api.get_compute_package() +if custom_url_prefix: + app.add_url_rule(f"{custom_url_prefix}/get_package", view_func=get_package, methods=["GET"]) @app.route("/list_compute_packages", methods=["GET"]) @jwt_auth_required(role="admin") @@ -333,6 +370,8 @@ def list_compute_packages(): limit=limit, skip=skip, include_active=include_active ) +if custom_url_prefix: + app.add_url_rule(f"{custom_url_prefix}/list_compute_packages", view_func=list_compute_packages, methods=["GET"]) @app.route("/download_package", methods=["GET"]) @jwt_auth_required(role="client") @@ -344,6 +383,8 @@ def download_package(): name = request.args.get("name", None) return api.download_compute_package(name) +if custom_url_prefix: + app.add_url_rule(f"{custom_url_prefix}/download_package", view_func=download_package, methods=["GET"]) @app.route("/get_package_checksum", methods=["GET"]) @jwt_auth_required(role="admin") @@ -351,6 +392,8 @@ def get_package_checksum(): name = request.args.get("name", None) return api.get_checksum(name) +if custom_url_prefix: + app.add_url_rule(f"{custom_url_prefix}/get_package_checksum", view_func=get_package_checksum, methods=["GET"]) @app.route("/get_latest_model", methods=["GET"]) @jwt_auth_required(role="admin") @@ -361,6 +404,8 @@ def get_latest_model(): """ return api.get_latest_model() +if custom_url_prefix: + app.add_url_rule(f"{custom_url_prefix}/get_latest_model", view_func=get_latest_model, methods=["GET"]) @app.route("/set_current_model", methods=["PUT"]) @jwt_auth_required(role="admin") @@ -381,6 +426,8 @@ def set_current_model(): return jsonify({"success": False, "message": "Missing model id."}), 400 return api.set_current_model(id) +if custom_url_prefix: + app.add_url_rule(f"{custom_url_prefix}/set_current_model", view_func=set_current_model, methods=["PUT"]) # Get initial model endpoint @@ -394,6 +441,8 @@ def get_initial_model(): """ return api.get_initial_model() +if custom_url_prefix: + app.add_url_rule(f"{custom_url_prefix}/get_initial_model", view_func=get_initial_model, methods=["GET"]) @app.route("/set_initial_model", methods=["POST"]) @jwt_auth_required(role="admin") @@ -415,6 +464,8 @@ def set_initial_model(): return jsonify({"success": False, "message": "Missing file."}), 400 return api.set_initial_model(file) +if custom_url_prefix: + app.add_url_rule(f"{custom_url_prefix}/set_initial_model", view_func=set_initial_model, methods=["POST"]) @app.route("/get_controller_status", methods=["GET"]) @jwt_auth_required(role="admin") @@ -425,6 +476,8 @@ def get_controller_status(): """ return api.get_controller_status() +if custom_url_prefix: + app.add_url_rule(f"{custom_url_prefix}/get_controller_status", view_func=get_controller_status, methods=["GET"]) @app.route("/get_client_config", methods=["GET"]) @jwt_auth_required(role="admin") @@ -436,6 +489,8 @@ def get_client_config(): checksum = request.args.get("checksum", True) return api.get_client_config(checksum) +if custom_url_prefix: + app.add_url_rule(f"{custom_url_prefix}/get_client_config", view_func=get_client_config, methods=["GET"]) @app.route("/get_events", methods=["GET"]) @jwt_auth_required(role="admin") @@ -449,6 +504,8 @@ def get_events(): return api.get_events(**kwargs) +if custom_url_prefix: + app.add_url_rule(f"{custom_url_prefix}/get_events", view_func=get_client_config, methods=["GET"]) @app.route("/list_validations", methods=["GET"]) @jwt_auth_required(role="admin") @@ -461,6 +518,8 @@ def list_validations(): kwargs = request.args.to_dict() return api.get_all_validations(**kwargs) +if custom_url_prefix: + app.add_url_rule(f"{custom_url_prefix}/list_validations", view_func=list_validations, methods=["GET"]) @app.route("/add_combiner", methods=["POST"]) @jwt_auth_required(role="combiner") @@ -477,6 +536,8 @@ def add_combiner(): return jsonify({"success": False, "message": str(e)}), 400 return response +if custom_url_prefix: + app.add_url_rule(f"{custom_url_prefix}/add_combiner", view_func=add_combiner, methods=["POST"]) @app.route("/add_client", methods=["POST"]) @jwt_auth_required(role="client") @@ -494,6 +555,8 @@ def add_client(): return jsonify({"success": False, "message": str(e)}), 400 return response +if custom_url_prefix: + app.add_url_rule(f"{custom_url_prefix}/add_client", view_func=add_client, methods=["POST"]) @app.route("/list_combiners_data", methods=["POST"]) @jwt_auth_required(role="admin") @@ -514,6 +577,8 @@ def list_combiners_data(): return jsonify({"success": False, "message": str(e)}), 400 return response +if custom_url_prefix: + app.add_url_rule(f"{custom_url_prefix}/list_combiners_data", view_func=list_combiners_data, methods=["POST"]) @app.route("/get_plot_data", methods=["GET"]) @jwt_auth_required(role="admin") @@ -529,7 +594,9 @@ def get_plot_data(): return jsonify({"success": False, "message": str(e)}), 400 return response - +if custom_url_prefix: + app.add_url_rule(f"{custom_url_prefix}/get_plot_data", view_func=get_plot_data, methods=["GET"]) + if __name__ == "__main__": config = get_controller_config() port = config["port"] From e8490c9f2b7571a4c73ab70bb2f342b397adb7e9 Mon Sep 17 00:00:00 2001 From: Fredrik Wrede Date: Tue, 12 Mar 2024 14:27:59 +0000 Subject: [PATCH 05/16] fix --- fedn/fedn/network/api/server.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/fedn/fedn/network/api/server.py b/fedn/fedn/network/api/server.py index a5c90804d..6d193a351 100644 --- a/fedn/fedn/network/api/server.py +++ b/fedn/fedn/network/api/server.py @@ -1,6 +1,7 @@ +import os + from flasgger import Swagger from flask import Flask, jsonify, request -import os from fedn.common.config import (get_controller_config, get_modelstorage_config, get_network_config, get_statestore_config) @@ -596,7 +597,7 @@ def get_plot_data(): if custom_url_prefix: app.add_url_rule(f"{custom_url_prefix}/get_plot_data", view_func=get_plot_data, methods=["GET"]) - + if __name__ == "__main__": config = get_controller_config() port = config["port"] From d524da39a26146422702ca8325a884337934cfd1 Mon Sep 17 00:00:00 2001 From: Fredrik Wrede Date: Tue, 12 Mar 2024 16:44:55 +0000 Subject: [PATCH 06/16] jwt algo --- fedn/fedn/network/api/auth.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/fedn/fedn/network/api/auth.py b/fedn/fedn/network/api/auth.py index d65794417..8e6da18e8 100644 --- a/fedn/fedn/network/api/auth.py +++ b/fedn/fedn/network/api/auth.py @@ -10,6 +10,7 @@ FEDN_JWT_CUSTOM_CLAIM_VALUE = os.environ.get('FEDN_JWT_CUSTOM_CLAIM_VALUE', False) FEDN_AUTH_SCHEME = os.environ.get('FEDN_AUTH_SCHEME', 'Bearer') FEDN_AUTH_WHITELIST_URL_PREFIX = os.environ.get('FEDN_AUTH_WHITELIST_URL_PREFIX', False) +FEDN_JWT_ALGORITHM = os.environ.get('FEDN_JWT_ALGORITHM', 'HS256') def check_role_claims(payload, role): @@ -54,7 +55,7 @@ def decorated(*args, **kwargs): return jsonify({'message': 'Missing token'}), 401 try: - payload = jwt.decode(token, SECRET_KEY, algorithms=['HS256']) + payload = jwt.decode(token, SECRET_KEY, algorithms=[FEDN_JWT_ALGORITHM]) if not check_role_claims(payload, role): return jsonify({'message': 'Invalid token'}), 401 if not check_custom_claims(payload): From ca1d15cd9152fae5e4f13d95b19494aa47c741d7 Mon Sep 17 00:00:00 2001 From: Fredrik Wrede Date: Wed, 13 Mar 2024 13:37:42 +0100 Subject: [PATCH 07/16] add custom url prefix to rest clients --- fedn/fedn/network/clients/connect.py | 5 +++-- fedn/fedn/network/combiner/connect.py | 3 ++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/fedn/fedn/network/clients/connect.py b/fedn/fedn/network/clients/connect.py index 478844d26..1862e2ee5 100644 --- a/fedn/fedn/network/clients/connect.py +++ b/fedn/fedn/network/clients/connect.py @@ -5,6 +5,7 @@ # # import enum +import os import requests @@ -77,8 +78,8 @@ def assign(self): try: retval = None payload = {'client_id': self.name, 'preferred_combiner': self.preferred_combiner} - - retval = requests.post(self.connect_string + '/add_client', + url_prefix = os.environ.get('FEDN_CUSTOM_URL_PREFIX', '') + retval = requests.post(self.connect_string + url_prefix + '/add_client', json=payload, verify=self.verify, allow_redirects=True, diff --git a/fedn/fedn/network/combiner/connect.py b/fedn/fedn/network/combiner/connect.py index a0b2d1803..7dc388261 100644 --- a/fedn/fedn/network/combiner/connect.py +++ b/fedn/fedn/network/combiner/connect.py @@ -106,8 +106,9 @@ def announce(self): "port": self.myport, "secure_grpc": self.secure } + url_prefix = os.environ.get('FEDN_CUSTOM_URL_PREFIX', '') try: - retval = requests.post(self.connect_string + '/add_combiner', json=payload, + retval = requests.post(self.connect_string + url_prefix + '/add_combiner', json=payload, verify=self.verify, headers={'Authorization': f'{self.token_scheme} {self.token}'}) except Exception: From 8a29fc10318ed1273bb470a5dcc5ad065ff20de3 Mon Sep 17 00:00:00 2001 From: Fredrik Wrede Date: Wed, 13 Mar 2024 14:40:41 +0100 Subject: [PATCH 08/16] fix --- fedn/fedn/network/api/auth.py | 15 +++++++-------- fedn/fedn/network/clients/connect.py | 3 ++- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/fedn/fedn/network/api/auth.py b/fedn/fedn/network/api/auth.py index 8e6da18e8..925b7cef2 100644 --- a/fedn/fedn/network/api/auth.py +++ b/fedn/fedn/network/api/auth.py @@ -14,9 +14,6 @@ def check_role_claims(payload, role): - if FEDN_JWT_CUSTOM_CLAIM_KEY and FEDN_JWT_CUSTOM_CLAIM_VALUE: - if payload[FEDN_JWT_CUSTOM_CLAIM_KEY] != FEDN_JWT_CUSTOM_CLAIM_VALUE: - return False if 'role' not in payload: return False if payload['role'] != role: @@ -47,13 +44,15 @@ def decorated(*args, **kwargs): if if_whitelisted_url_prefix(request.path): return func(*args, **kwargs) token = request.headers.get('Authorization') - # Get token from the header Bearer - if token and token.startswith(FEDN_AUTH_SCHEME): - token = token.split(' ')[1] - if not token: return jsonify({'message': 'Missing token'}), 401 - + # Get token from the header Bearer + if token.startswith(FEDN_AUTH_SCHEME): + token = token.split(' ')[1] + else: + return jsonify({'message': + f'Invalid token scheme, expected {FEDN_AUTH_SCHEME}' + }), 401 try: payload = jwt.decode(token, SECRET_KEY, algorithms=[FEDN_JWT_ALGORITHM]) if not check_role_claims(payload, role): diff --git a/fedn/fedn/network/clients/connect.py b/fedn/fedn/network/clients/connect.py index 1862e2ee5..b559a27d3 100644 --- a/fedn/fedn/network/clients/connect.py +++ b/fedn/fedn/network/clients/connect.py @@ -79,11 +79,12 @@ def assign(self): retval = None payload = {'client_id': self.name, 'preferred_combiner': self.preferred_combiner} url_prefix = os.environ.get('FEDN_CUSTOM_URL_PREFIX', '') + auth_scheme = os.environ.get('FEDN_AUTH_SCHEME', 'Bearer') retval = requests.post(self.connect_string + url_prefix + '/add_client', json=payload, verify=self.verify, allow_redirects=True, - headers={'Authorization': 'Token {}'.format(self.token)}) + headers={'Authorization': f"{auth_scheme} {self.token}"}) except Exception as e: print('***** {}'.format(e), flush=True) return Status.Unassigned, {} From 2bdc49383940d467cdfe871b4b3c84672343c166 Mon Sep 17 00:00:00 2001 From: Fredrik Wrede Date: Wed, 13 Mar 2024 15:02:57 +0100 Subject: [PATCH 09/16] undo default auth scheme --- fedn/fedn/network/clients/connect.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fedn/fedn/network/clients/connect.py b/fedn/fedn/network/clients/connect.py index b559a27d3..21c84f586 100644 --- a/fedn/fedn/network/clients/connect.py +++ b/fedn/fedn/network/clients/connect.py @@ -79,7 +79,7 @@ def assign(self): retval = None payload = {'client_id': self.name, 'preferred_combiner': self.preferred_combiner} url_prefix = os.environ.get('FEDN_CUSTOM_URL_PREFIX', '') - auth_scheme = os.environ.get('FEDN_AUTH_SCHEME', 'Bearer') + auth_scheme = os.environ.get('FEDN_AUTH_SCHEME', 'Token') retval = requests.post(self.connect_string + url_prefix + '/add_client', json=payload, verify=self.verify, From c6329bfb59b737e43f9d9190b10b4e36136acd35 Mon Sep 17 00:00:00 2001 From: Fredrik Wrede Date: Wed, 13 Mar 2024 15:16:16 +0100 Subject: [PATCH 10/16] fix --- fedn/fedn/network/api/auth.py | 4 +- fedn/fedn/network/api/server.py | 65 +++++++++++++++++++++++++++- fedn/fedn/network/api/v1/__init__.py | 2 +- 3 files changed, 68 insertions(+), 3 deletions(-) diff --git a/fedn/fedn/network/api/auth.py b/fedn/fedn/network/api/auth.py index 925b7cef2..b6921c330 100644 --- a/fedn/fedn/network/api/auth.py +++ b/fedn/fedn/network/api/auth.py @@ -28,12 +28,14 @@ def check_custom_claims(payload): return False return True + def if_whitelisted_url_prefix(path): if FEDN_AUTH_WHITELIST_URL_PREFIX and path.startswith(FEDN_AUTH_WHITELIST_URL_PREFIX): return True else: return False + def jwt_auth_required(role=None): def actual_decorator(func): if not SECRET_KEY: @@ -50,7 +52,7 @@ def decorated(*args, **kwargs): if token.startswith(FEDN_AUTH_SCHEME): token = token.split(' ')[1] else: - return jsonify({'message': + return jsonify({'message': f'Invalid token scheme, expected {FEDN_AUTH_SCHEME}' }), 401 try: diff --git a/fedn/fedn/network/api/server.py b/fedn/fedn/network/api/server.py index 6d193a351..be3e8ee02 100644 --- a/fedn/fedn/network/api/server.py +++ b/fedn/fedn/network/api/server.py @@ -24,7 +24,7 @@ for bp in _routes: app.register_blueprint(bp) if custom_url_prefix: - app.register_blueprint(bp, + app.register_blueprint(bp, name=f"{bp.name}_custom", url_prefix=f"{custom_url_prefix}{bp.url_prefix}") @@ -51,6 +51,7 @@ def get_model_trail(): """ return api.get_model_trail() + if custom_url_prefix: app.add_url_rule(f"{custom_url_prefix}/get_model_trail", view_func=get_model_trail, methods=["GET"]) @@ -71,9 +72,11 @@ def get_model_ancestors(): return api.get_model_ancestors(model, limit) + if custom_url_prefix: app.add_url_rule(f"{custom_url_prefix}/get_model_ancestors", view_func=get_model_ancestors, methods=["GET"]) + @app.route("/get_model_descendants", methods=["GET"]) @jwt_auth_required(role="admin") def get_model_descendants(): @@ -90,9 +93,11 @@ def get_model_descendants(): return api.get_model_descendants(model, limit) + if custom_url_prefix: app.add_url_rule(f"{custom_url_prefix}/get_model_descendants", view_func=get_model_descendants, methods=["GET"]) + @app.route("/list_models", methods=["GET"]) @jwt_auth_required(role="admin") def list_models(): @@ -114,9 +119,11 @@ def list_models(): return api.get_models(session_id, limit, skip, include_active) + if custom_url_prefix: app.add_url_rule(f"{custom_url_prefix}/list_models", view_func=list_models, methods=["GET"]) + @app.route("/get_model", methods=["GET"]) @jwt_auth_required(role="admin") def get_model(): @@ -132,9 +139,11 @@ def get_model(): return api.get_model(model) + if custom_url_prefix: app.add_url_rule(f"{custom_url_prefix}/get_model", view_func=get_model, methods=["GET"]) + @app.route("/delete_model_trail", methods=["GET", "POST"]) @jwt_auth_required(role="admin") def delete_model_trail(): @@ -146,9 +155,11 @@ def delete_model_trail(): """ return jsonify({"message": "Not implemented"}), 501 + if custom_url_prefix: app.add_url_rule(f"{custom_url_prefix}/delete_model_trail", view_func=delete_model_trail, methods=["GET", "POST"]) + @app.route("/list_clients", methods=["GET"]) @jwt_auth_required(role="admin") def list_clients(): @@ -163,9 +174,11 @@ def list_clients(): return api.get_clients(limit, skip, status) + if custom_url_prefix: app.add_url_rule(f"{custom_url_prefix}/list_clients", view_func=list_clients, methods=["GET"]) + @app.route("/get_active_clients", methods=["GET"]) @jwt_auth_required(role="admin") def get_active_clients(): @@ -183,9 +196,11 @@ def get_active_clients(): ) return api.get_active_clients(combiner_id) + if custom_url_prefix: app.add_url_rule(f"{custom_url_prefix}/get_active_clients", view_func=get_active_clients, methods=["GET"]) + @app.route("/list_combiners", methods=["GET"]) @jwt_auth_required(role="admin") def list_combiners(): @@ -199,9 +214,11 @@ def list_combiners(): return api.get_all_combiners(limit, skip) + if custom_url_prefix: app.add_url_rule(f"{custom_url_prefix}/list_combiners", view_func=list_combiners, methods=["GET"]) + @app.route("/get_combiner", methods=["GET"]) @jwt_auth_required(role="admin") def get_combiner(): @@ -219,9 +236,11 @@ def get_combiner(): ) return api.get_combiner(combiner_id) + if custom_url_prefix: app.add_url_rule(f"{custom_url_prefix}/get_combiner", view_func=get_combiner, methods=["GET"]) + @app.route("/list_rounds", methods=["GET"]) @jwt_auth_required(role="admin") def list_rounds(): @@ -231,9 +250,11 @@ def list_rounds(): """ return api.get_all_rounds() + if custom_url_prefix: app.add_url_rule(f"{custom_url_prefix}/list_rounds", view_func=list_rounds, methods=["GET"]) + @app.route("/get_round", methods=["GET"]) @jwt_auth_required(role="admin") def get_round(): @@ -248,9 +269,11 @@ def get_round(): return jsonify({"success": False, "message": "Missing round id."}), 400 return api.get_round(round_id) + if custom_url_prefix: app.add_url_rule(f"{custom_url_prefix}/get_round", view_func=get_round, methods=["GET"]) + @app.route("/start_session", methods=["GET", "POST"]) @jwt_auth_required(role="admin") def start_session(): @@ -261,9 +284,11 @@ def start_session(): json_data = request.get_json() return api.start_session(**json_data) + if custom_url_prefix: app.add_url_rule(f"{custom_url_prefix}/start_session", view_func=start_session, methods=["GET", "POST"]) + @app.route("/list_sessions", methods=["GET"]) @jwt_auth_required(role="admin") def list_sessions(): @@ -276,9 +301,11 @@ def list_sessions(): return api.get_all_sessions(limit, skip) + if custom_url_prefix: app.add_url_rule(f"{custom_url_prefix}/list_sessions", view_func=list_sessions, methods=["GET"]) + @app.route("/get_session", methods=["GET"]) @jwt_auth_required(role="admin") def get_session(): @@ -296,18 +323,22 @@ def get_session(): ) return api.get_session(session_id) + if custom_url_prefix: app.add_url_rule(f"{custom_url_prefix}/get_session", view_func=get_session, methods=["GET"]) + @app.route("/set_active_package", methods=["PUT"]) @jwt_auth_required(role="admin") def set_active_package(): id = request.args.get("id", None) return api.set_active_compute_package(id) + if custom_url_prefix: app.add_url_rule(f"{custom_url_prefix}/set_active_package", view_func=set_active_package, methods=["PUT"]) + @app.route("/set_package", methods=["POST"]) @jwt_auth_required(role="admin") def set_package(): @@ -340,9 +371,11 @@ def set_package(): file=file, helper_type=helper_type, name=name, description=description ) + if custom_url_prefix: app.add_url_rule(f"{custom_url_prefix}/set_package", view_func=set_package, methods=["POST"]) + @app.route("/get_package", methods=["GET"]) @jwt_auth_required(role="admin") def get_package(): @@ -352,9 +385,11 @@ def get_package(): """ return api.get_compute_package() + if custom_url_prefix: app.add_url_rule(f"{custom_url_prefix}/get_package", view_func=get_package, methods=["GET"]) + @app.route("/list_compute_packages", methods=["GET"]) @jwt_auth_required(role="admin") def list_compute_packages(): @@ -371,9 +406,11 @@ def list_compute_packages(): limit=limit, skip=skip, include_active=include_active ) + if custom_url_prefix: app.add_url_rule(f"{custom_url_prefix}/list_compute_packages", view_func=list_compute_packages, methods=["GET"]) + @app.route("/download_package", methods=["GET"]) @jwt_auth_required(role="client") def download_package(): @@ -384,18 +421,22 @@ def download_package(): name = request.args.get("name", None) return api.download_compute_package(name) + if custom_url_prefix: app.add_url_rule(f"{custom_url_prefix}/download_package", view_func=download_package, methods=["GET"]) + @app.route("/get_package_checksum", methods=["GET"]) @jwt_auth_required(role="admin") def get_package_checksum(): name = request.args.get("name", None) return api.get_checksum(name) + if custom_url_prefix: app.add_url_rule(f"{custom_url_prefix}/get_package_checksum", view_func=get_package_checksum, methods=["GET"]) + @app.route("/get_latest_model", methods=["GET"]) @jwt_auth_required(role="admin") def get_latest_model(): @@ -405,9 +446,11 @@ def get_latest_model(): """ return api.get_latest_model() + if custom_url_prefix: app.add_url_rule(f"{custom_url_prefix}/get_latest_model", view_func=get_latest_model, methods=["GET"]) + @app.route("/set_current_model", methods=["PUT"]) @jwt_auth_required(role="admin") def set_current_model(): @@ -427,6 +470,7 @@ def set_current_model(): return jsonify({"success": False, "message": "Missing model id."}), 400 return api.set_current_model(id) + if custom_url_prefix: app.add_url_rule(f"{custom_url_prefix}/set_current_model", view_func=set_current_model, methods=["PUT"]) @@ -442,9 +486,11 @@ def get_initial_model(): """ return api.get_initial_model() + if custom_url_prefix: app.add_url_rule(f"{custom_url_prefix}/get_initial_model", view_func=get_initial_model, methods=["GET"]) + @app.route("/set_initial_model", methods=["POST"]) @jwt_auth_required(role="admin") def set_initial_model(): @@ -465,9 +511,11 @@ def set_initial_model(): return jsonify({"success": False, "message": "Missing file."}), 400 return api.set_initial_model(file) + if custom_url_prefix: app.add_url_rule(f"{custom_url_prefix}/set_initial_model", view_func=set_initial_model, methods=["POST"]) + @app.route("/get_controller_status", methods=["GET"]) @jwt_auth_required(role="admin") def get_controller_status(): @@ -477,9 +525,11 @@ def get_controller_status(): """ return api.get_controller_status() + if custom_url_prefix: app.add_url_rule(f"{custom_url_prefix}/get_controller_status", view_func=get_controller_status, methods=["GET"]) + @app.route("/get_client_config", methods=["GET"]) @jwt_auth_required(role="admin") def get_client_config(): @@ -490,9 +540,11 @@ def get_client_config(): checksum = request.args.get("checksum", True) return api.get_client_config(checksum) + if custom_url_prefix: app.add_url_rule(f"{custom_url_prefix}/get_client_config", view_func=get_client_config, methods=["GET"]) + @app.route("/get_events", methods=["GET"]) @jwt_auth_required(role="admin") def get_events(): @@ -505,9 +557,11 @@ def get_events(): return api.get_events(**kwargs) + if custom_url_prefix: app.add_url_rule(f"{custom_url_prefix}/get_events", view_func=get_client_config, methods=["GET"]) + @app.route("/list_validations", methods=["GET"]) @jwt_auth_required(role="admin") def list_validations(): @@ -519,9 +573,11 @@ def list_validations(): kwargs = request.args.to_dict() return api.get_all_validations(**kwargs) + if custom_url_prefix: app.add_url_rule(f"{custom_url_prefix}/list_validations", view_func=list_validations, methods=["GET"]) + @app.route("/add_combiner", methods=["POST"]) @jwt_auth_required(role="combiner") def add_combiner(): @@ -537,9 +593,11 @@ def add_combiner(): return jsonify({"success": False, "message": str(e)}), 400 return response + if custom_url_prefix: app.add_url_rule(f"{custom_url_prefix}/add_combiner", view_func=add_combiner, methods=["POST"]) + @app.route("/add_client", methods=["POST"]) @jwt_auth_required(role="client") def add_client(): @@ -556,9 +614,11 @@ def add_client(): return jsonify({"success": False, "message": str(e)}), 400 return response + if custom_url_prefix: app.add_url_rule(f"{custom_url_prefix}/add_client", view_func=add_client, methods=["POST"]) + @app.route("/list_combiners_data", methods=["POST"]) @jwt_auth_required(role="admin") def list_combiners_data(): @@ -578,9 +638,11 @@ def list_combiners_data(): return jsonify({"success": False, "message": str(e)}), 400 return response + if custom_url_prefix: app.add_url_rule(f"{custom_url_prefix}/list_combiners_data", view_func=list_combiners_data, methods=["POST"]) + @app.route("/get_plot_data", methods=["GET"]) @jwt_auth_required(role="admin") def get_plot_data(): @@ -595,6 +657,7 @@ def get_plot_data(): return jsonify({"success": False, "message": str(e)}), 400 return response + if custom_url_prefix: app.add_url_rule(f"{custom_url_prefix}/get_plot_data", view_func=get_plot_data, methods=["GET"]) diff --git a/fedn/fedn/network/api/v1/__init__.py b/fedn/fedn/network/api/v1/__init__.py index 11698cb14..bb8d8d33c 100644 --- a/fedn/fedn/network/api/v1/__init__.py +++ b/fedn/fedn/network/api/v1/__init__.py @@ -7,4 +7,4 @@ from fedn.network.api.v1.status_routes import bp as status_bp from fedn.network.api.v1.validation_routes import bp as validation_bp -_routes = [client_bp, combiner_bp, model_bp, package_bp, round_bp, session_bp, status_bp, validation_bp] \ No newline at end of file +_routes = [client_bp, combiner_bp, model_bp, package_bp, round_bp, session_bp, status_bp, validation_bp] From fbc00d1653b30fbc67fd266f2742e1eeedc5bf4e Mon Sep 17 00:00:00 2001 From: Fredrik Wrede Date: Thu, 14 Mar 2024 21:50:17 +0100 Subject: [PATCH 11/16] init interceptor --- fedn/fedn/common/config.py | 3 + fedn/fedn/network/clients/client.py | 3 + fedn/fedn/network/clients/package.py | 13 ++-- fedn/fedn/network/grpc/auth.py | 97 ++++++++++++++++++++++++++++ fedn/fedn/network/grpc/server.py | 3 +- 5 files changed, 112 insertions(+), 7 deletions(-) create mode 100644 fedn/fedn/network/grpc/auth.py diff --git a/fedn/fedn/common/config.py b/fedn/fedn/common/config.py index f6c827d0d..1cd68fba4 100644 --- a/fedn/fedn/common/config.py +++ b/fedn/fedn/common/config.py @@ -5,6 +5,9 @@ global STATESTORE_CONFIG global MODELSTORAGE_CONFIG +FEDN_AUTH_SCHEME = os.environ.get('FEDN_AUTH_SCHEME', 'Token') +FEDN_CUSTOM_URL_PREFIX = os.environ.get('FEDN_CUSTOM_URL_PREFIX', '') + def get_environment_config(): """ Get the configuration from environment variables. diff --git a/fedn/fedn/network/clients/client.py b/fedn/fedn/network/clients/client.py index a6408d66c..8e1cdfa7e 100644 --- a/fedn/fedn/network/clients/client.py +++ b/fedn/fedn/network/clients/client.py @@ -178,6 +178,9 @@ def _connect(self, client_config): host = client_config['host'] # Add host to gRPC metadata self._add_grpc_metadata('grpc-server', host) + auth_scheme = os.environ.get('FEDN_AUTH_SCHEME', 'Token') + if self.config['token']: + self._add_grpc_metadata('authorization', f"{auth_scheme} {self.config['token']}") logger.info("Client using metadata: {}.".format(self.metadata)) port = client_config['port'] secure = False diff --git a/fedn/fedn/network/clients/package.py b/fedn/fedn/network/clients/package.py index d56296de8..cc98eae94 100644 --- a/fedn/fedn/network/clients/package.py +++ b/fedn/fedn/network/clients/package.py @@ -9,6 +9,7 @@ import requests import yaml +from fedn.common.config import FEDN_AUTH_SCHEME, FEDN_CUSTOM_URL_PREFIX from fedn.common.log_config import logger from fedn.utils.checksum import sha from fedn.utils.dispatcher import Dispatcher @@ -52,13 +53,13 @@ def download(self, host, port, token, force_ssl=False, secure=False, name=None): else: scheme = "http" if port: - path = f"{scheme}://{host}:{port}/download_package" + path = f"{scheme}://{host}:{port}{FEDN_CUSTOM_URL_PREFIX}/download_package" else: - path = f"{scheme}://{host}/download_package" + path = f"{scheme}://{host}{FEDN_CUSTOM_URL_PREFIX}/download_package" if name: path = path + "?name={}".format(name) - with requests.get(path, stream=True, verify=False, headers={'Authorization': 'Token {}'.format(token)}) as r: + with requests.get(path, stream=True, verify=False, headers={'Authorization': f'{FEDN_AUTH_SCHEME} {token}'}) as r: if 200 <= r.status_code < 204: params = cgi.parse_header( @@ -73,13 +74,13 @@ def download(self, host, port, token, force_ssl=False, secure=False, name=None): for chunk in r.iter_content(chunk_size=8192): f.write(chunk) if port: - path = f"{scheme}://{host}:{port}/get_package_checksum" + path = f"{scheme}://{host}:{port}{FEDN_CUSTOM_URL_PREFIX}/get_package_checksum" else: - path = f"{scheme}://{host}/get_package_checksum" + path = f"{scheme}://{host}{FEDN_CUSTOM_URL_PREFIX}/get_package_checksum" if name: path = path + "?name={}".format(name) - with requests.get(path, verify=False, headers={'Authorization': 'Token {}'.format(token)}) as r: + with requests.get(path, verify=False, headers={'Authorization': f'{FEDN_AUTH_SCHEME} {token}'}) as r: if 200 <= r.status_code < 204: data = r.json() diff --git a/fedn/fedn/network/grpc/auth.py b/fedn/fedn/network/grpc/auth.py new file mode 100644 index 000000000..208dcff46 --- /dev/null +++ b/fedn/fedn/network/grpc/auth.py @@ -0,0 +1,97 @@ +import os + +import grpc +import jwt + +from fedn.common.log_config import logger +from fedn.network.api.auth import check_custom_claims + +SECRET_KEY = os.environ.get('FEDN_JWT_SECRET_KEY', False) +FEDN_AUTH_SCHEME = os.environ.get('FEDN_AUTH_SCHEME', 'Bearer') +FEDN_JWT_ALGORITHM = os.environ.get('FEDN_JWT_ALGORITHM', 'HS256') + +ENDPOINT_ROLES_MAPPING = { + '/fedn.Combiner/TaskStream': ['client'], + '/fedn.Combiner/SendModelUpdate': ['client'], + '/fedn.Combiner/SendModelValidation': ['client'], + '/fedn.Connector/SendHeartbeat': ['client'], + '/fedn.Connector/SendStatus': ['client'], + '/fedn.ModelService/Download': ['client'], + '/fedn.ModelService/Upload': ['client'], + '/fedn.Control/Start': ['controller'], + '/fedn.Control/Stop': ['controller'], + '/fedn.Control/FlushAggregationQueue': ['controller'], + '/fedn.Control/SetAggregator': ['controller'], +} + +ENDPOINT_WHITELIST = [ + '/fedn.Connector/AcceptingClients', + '/fedn.Connector/ListActiveClients', + '/fedn.Control/Start', + '/fedn.Control/Stop', + '/fedn.Control/FlushAggregationQueue', + '/fedn.Control/SetAggregator', +] + +USER_AGENT_WHITELIST = [ + 'grpc_health_probe' +] + +def check_role_claims(payload, endpoint): + user_role = payload.get('role', '') + + # Perform endpoint-specific RBAC check + allowed_roles = ENDPOINT_ROLES_MAPPING.get(endpoint) + if allowed_roles and not user_role in allowed_roles: + return False + return True + +def _unary_unary_rpc_terminator(code, details): + def terminate(ignored_request, context): + context.abort(code, details) + + return grpc.unary_unary_rpc_method_handler(terminate) + +# Define the gRPC interceptor class +class JWTInterceptor(grpc.ServerInterceptor): + def __init__(self): + pass + + def intercept_service(self, continuation, handler_call_details): + # Pass if no secret key is set + if not SECRET_KEY: + return continuation(handler_call_details) + metadata = dict(handler_call_details.invocation_metadata) + # Pass whitelisted methods + if handler_call_details.method in ENDPOINT_WHITELIST: + return continuation(handler_call_details) + # Pass if the request comes from whitelisted user agents + user_agent = metadata.get('user-agent').split(' ')[0] + if user_agent in USER_AGENT_WHITELIST: + return continuation(handler_call_details) + + token = metadata.get('authorization') + if token is None: + return _unary_unary_rpc_terminator(grpc.StatusCode.UNAUTHENTICATED, 'Token is missing') + + if not token.startswith(FEDN_AUTH_SCHEME): + return _unary_unary_rpc_terminator(grpc.StatusCode.UNAUTHENTICATED, f'Invalid token scheme, expected {FEDN_AUTH_SCHEME}') + + token = token.split(' ')[1] + print(f"HANDLER: {handler_call_details}", flush=True) + + try: + payload = jwt.decode(token, SECRET_KEY, algorithms=[FEDN_JWT_ALGORITHM]) + + if not check_role_claims(payload, handler_call_details.method): + return _unary_unary_rpc_terminator(grpc.StatusCode.PERMISSION_DENIED, 'Insufficient permissions') + + if not check_custom_claims(payload): + return _unary_unary_rpc_terminator(grpc.StatusCode.PERMISSION_DENIED, 'Insufficient permissions') + + return continuation(handler_call_details) + except jwt.InvalidTokenError: + return _unary_unary_rpc_terminator(grpc.StatusCode.UNAUTHENTICATED, 'Invalid token') + except Exception as e: + logger.error(str(e)) + return _unary_unary_rpc_terminator(grpc.StatusCode.UNKNOWN, str(e)) \ No newline at end of file diff --git a/fedn/fedn/network/grpc/server.py b/fedn/fedn/network/grpc/server.py index 59ed6b1ba..916b4756e 100644 --- a/fedn/fedn/network/grpc/server.py +++ b/fedn/fedn/network/grpc/server.py @@ -6,6 +6,7 @@ import fedn.network.grpc.fedn_pb2_grpc as rpc from fedn.common.log_config import (logger, set_log_level_from_string, set_log_stream) +from fedn.network.grpc.auth import JWTInterceptor class Server: @@ -16,7 +17,7 @@ def __init__(self, servicer, modelservicer, config): set_log_level_from_string(config.get('verbosity', "INFO")) set_log_stream(config.get('logfile', None)) - self.server = grpc.server(futures.ThreadPoolExecutor(max_workers=350)) + self.server = grpc.server(futures.ThreadPoolExecutor(max_workers=350), interceptors=[JWTInterceptor()]) self.certificate = None self.health_servicer = health.HealthServicer() From 6a40493487b807df08e3d0c93ac660c71a0f4d10 Mon Sep 17 00:00:00 2001 From: Fredrik Wrede Date: Fri, 15 Mar 2024 11:33:00 +0100 Subject: [PATCH 12/16] move Envs to config --- fedn/fedn/common/config.py | 6 ++++++ fedn/fedn/network/api/auth.py | 11 ++++------- fedn/fedn/network/clients/connect.py | 7 +++---- fedn/fedn/network/grpc/auth.py | 7 +------ 4 files changed, 14 insertions(+), 17 deletions(-) diff --git a/fedn/fedn/common/config.py b/fedn/fedn/common/config.py index 1cd68fba4..0ee195046 100644 --- a/fedn/fedn/common/config.py +++ b/fedn/fedn/common/config.py @@ -5,6 +5,12 @@ global STATESTORE_CONFIG global MODELSTORAGE_CONFIG +SECRET_KEY = os.environ.get('FEDN_JWT_SECRET_KEY', False) +FEDN_JWT_CUSTOM_CLAIM_KEY = os.environ.get('FEDN_JWT_CUSTOM_CLAIM_KEY', False) +FEDN_JWT_CUSTOM_CLAIM_VALUE = os.environ.get('FEDN_JWT_CUSTOM_CLAIM_VALUE', False) + +FEDN_AUTH_WHITELIST_URL_PREFIX = os.environ.get('FEDN_AUTH_WHITELIST_URL_PREFIX', False) +FEDN_JWT_ALGORITHM = os.environ.get('FEDN_JWT_ALGORITHM', 'HS256') FEDN_AUTH_SCHEME = os.environ.get('FEDN_AUTH_SCHEME', 'Token') FEDN_CUSTOM_URL_PREFIX = os.environ.get('FEDN_CUSTOM_URL_PREFIX', '') diff --git a/fedn/fedn/network/api/auth.py b/fedn/fedn/network/api/auth.py index b6921c330..59bda3ec5 100644 --- a/fedn/fedn/network/api/auth.py +++ b/fedn/fedn/network/api/auth.py @@ -4,13 +4,10 @@ import jwt from flask import jsonify, request -# Define your secret key for JWT -SECRET_KEY = os.environ.get('FEDN_JWT_SECRET_KEY', False) -FEDN_JWT_CUSTOM_CLAIM_KEY = os.environ.get('FEDN_JWT_CUSTOM_CLAIM_KEY', False) -FEDN_JWT_CUSTOM_CLAIM_VALUE = os.environ.get('FEDN_JWT_CUSTOM_CLAIM_VALUE', False) -FEDN_AUTH_SCHEME = os.environ.get('FEDN_AUTH_SCHEME', 'Bearer') -FEDN_AUTH_WHITELIST_URL_PREFIX = os.environ.get('FEDN_AUTH_WHITELIST_URL_PREFIX', False) -FEDN_JWT_ALGORITHM = os.environ.get('FEDN_JWT_ALGORITHM', 'HS256') +from fedn.common.config import (FEDN_AUTH_SCHEME, + FEDN_AUTH_WHITELIST_URL_PREFIX, + FEDN_JWT_ALGORITHM, FEDN_JWT_CUSTOM_CLAIM_KEY, + FEDN_JWT_CUSTOM_CLAIM_VALUE, SECRET_KEY) def check_role_claims(payload, role): diff --git a/fedn/fedn/network/clients/connect.py b/fedn/fedn/network/clients/connect.py index 21c84f586..9c0928215 100644 --- a/fedn/fedn/network/clients/connect.py +++ b/fedn/fedn/network/clients/connect.py @@ -9,6 +9,7 @@ import requests +from fedn.common.config import FEDN_AUTH_SCHEME, FEDN_CUSTOM_URL_PREFIX from fedn.common.log_config import logger @@ -78,13 +79,11 @@ def assign(self): try: retval = None payload = {'client_id': self.name, 'preferred_combiner': self.preferred_combiner} - url_prefix = os.environ.get('FEDN_CUSTOM_URL_PREFIX', '') - auth_scheme = os.environ.get('FEDN_AUTH_SCHEME', 'Token') - retval = requests.post(self.connect_string + url_prefix + '/add_client', + retval = requests.post(self.connect_string + FEDN_CUSTOM_URL_PREFIX + '/add_client', json=payload, verify=self.verify, allow_redirects=True, - headers={'Authorization': f"{auth_scheme} {self.token}"}) + headers={'Authorization': f"{FEDN_AUTH_SCHEME} {self.token}"}) except Exception as e: print('***** {}'.format(e), flush=True) return Status.Unassigned, {} diff --git a/fedn/fedn/network/grpc/auth.py b/fedn/fedn/network/grpc/auth.py index 208dcff46..cf5fc0fab 100644 --- a/fedn/fedn/network/grpc/auth.py +++ b/fedn/fedn/network/grpc/auth.py @@ -3,13 +3,10 @@ import grpc import jwt +from fedn.common.config import FEDN_AUTH_SCHEME, FEDN_JWT_ALGORITHM, SECRET_KEY from fedn.common.log_config import logger from fedn.network.api.auth import check_custom_claims -SECRET_KEY = os.environ.get('FEDN_JWT_SECRET_KEY', False) -FEDN_AUTH_SCHEME = os.environ.get('FEDN_AUTH_SCHEME', 'Bearer') -FEDN_JWT_ALGORITHM = os.environ.get('FEDN_JWT_ALGORITHM', 'HS256') - ENDPOINT_ROLES_MAPPING = { '/fedn.Combiner/TaskStream': ['client'], '/fedn.Combiner/SendModelUpdate': ['client'], @@ -51,8 +48,6 @@ def terminate(ignored_request, context): context.abort(code, details) return grpc.unary_unary_rpc_method_handler(terminate) - -# Define the gRPC interceptor class class JWTInterceptor(grpc.ServerInterceptor): def __init__(self): pass From fe9029dab517c101fac302297eb2c3425da6b313 Mon Sep 17 00:00:00 2001 From: Fredrik Wrede Date: Tue, 19 Mar 2024 11:47:30 +0100 Subject: [PATCH 13/16] handle refresh token --- fedn/fedn/common/config.py | 2 + fedn/fedn/network/clients/client.py | 84 ++++++++++++++++++++-------- fedn/fedn/network/clients/connect.py | 32 ++++++++++- fedn/fedn/network/grpc/auth.py | 3 +- 4 files changed, 95 insertions(+), 26 deletions(-) diff --git a/fedn/fedn/common/config.py b/fedn/fedn/common/config.py index 0ee195046..0a261c4af 100644 --- a/fedn/fedn/common/config.py +++ b/fedn/fedn/common/config.py @@ -12,6 +12,8 @@ FEDN_AUTH_WHITELIST_URL_PREFIX = os.environ.get('FEDN_AUTH_WHITELIST_URL_PREFIX', False) FEDN_JWT_ALGORITHM = os.environ.get('FEDN_JWT_ALGORITHM', 'HS256') FEDN_AUTH_SCHEME = os.environ.get('FEDN_AUTH_SCHEME', 'Token') +FEDN_AUTH_REFRESH_TOKEN_URI = os.environ.get('FEDN_AUTH_REFRESH_TOKEN_URI', False) +FEDN_AUTH_REFRESH_TOKEN = os.environ.get('FEDN_AUTH_REFRESH_TOKEN', False) FEDN_CUSTOM_URL_PREFIX = os.environ.get('FEDN_CUSTOM_URL_PREFIX', '') diff --git a/fedn/fedn/network/clients/client.py b/fedn/fedn/network/clients/client.py index 8e1cdfa7e..f75d907b8 100644 --- a/fedn/fedn/network/clients/client.py +++ b/fedn/fedn/network/clients/client.py @@ -21,6 +21,7 @@ import fedn.network.grpc.fedn_pb2 as fedn import fedn.network.grpc.fedn_pb2_grpc as rpc +from fedn.common.config import FEDN_AUTH_SCHEME from fedn.common.log_config import (logger, set_log_level_from_string, set_log_stream) from fedn.network.clients.connect import ConnectorClient, Status @@ -112,7 +113,7 @@ def _assign(self): while True: status, response = self.connector.assign() if status == Status.TryAgain: - logger.info(response) + logger.info("Assignment request failed. Retrying in 5 seconds.") time.sleep(5) continue if status == Status.Assigned: @@ -125,7 +126,10 @@ def _assign(self): logger.critical(response) sys.exit("Exiting: UnMatchedConfig") time.sleep(5) - + # If token was refreshed, update the config + if self.config['token'] != self.connector.token: + self.config['token'] = self.connector.token + self._add_grpc_metadata('authorization', f"{FEDN_AUTH_SCHEME} {self.config['token']}") logger.info("Assignment successfully received.") logger.info("Received combiner configuration: {}".format(client_config)) return client_config @@ -178,10 +182,9 @@ def _connect(self, client_config): host = client_config['host'] # Add host to gRPC metadata self._add_grpc_metadata('grpc-server', host) - auth_scheme = os.environ.get('FEDN_AUTH_SCHEME', 'Token') if self.config['token']: - self._add_grpc_metadata('authorization', f"{auth_scheme} {self.config['token']}") - logger.info("Client using metadata: {}.".format(self.metadata)) + self._add_grpc_metadata('authorization', f"{FEDN_AUTH_SCHEME} {self.config['token']}") + logger.debug("Client using metadata: {}.".format(self.metadata)) port = client_config['port'] secure = False if client_config['fqdn'] is not None: @@ -373,21 +376,24 @@ def get_model_from_combiner(self, id, timeout=20): request.sender.name = self.name request.sender.role = fedn.WORKER - for part in self.modelStub.Download(request, metadata=self.metadata): - - if part.status == fedn.ModelStatus.IN_PROGRESS: - data.write(part.data) + try: + for part in self.modelStub.Download(request, metadata=self.metadata): - if part.status == fedn.ModelStatus.OK: - return data + if part.status == fedn.ModelStatus.IN_PROGRESS: + data.write(part.data) - if part.status == fedn.ModelStatus.FAILED: - return None + if part.status == fedn.ModelStatus.OK: + return data - if part.status == fedn.ModelStatus.UNKNOWN: - if time.time() - time_start >= timeout: + if part.status == fedn.ModelStatus.FAILED: return None - continue + + if part.status == fedn.ModelStatus.UNKNOWN: + if time.time() - time_start >= timeout: + return None + continue + except grpc.RpcError as e: + logger.critical(f"GRPC: An error occurred during model download: {e}") return data @@ -412,7 +418,10 @@ def send_model_to_combiner(self, model, id): bt.seek(0, 0) - result = self.modelStub.Upload(upload_request_generator(bt, id), metadata=self.metadata) + try: + result = self.modelStub.Upload(upload_request_generator(bt, id), metadata=self.metadata) + except grpc.RpcError as e: + logger.critical(f"GRPC: An error occurred during model upload: {e}") return result @@ -451,16 +460,26 @@ def _listen_to_task_stream(self): # Handle gRPC errors status_code = e.code() if status_code == grpc.StatusCode.UNAVAILABLE: - logger.warning("GRPC server unavailable during model update request stream. Retrying.") + logger.warning("GRPC TaskStream: server unavailable during model update request stream. Retrying.") # Retry after a delay time.sleep(5) + if status_code == grpc.StatusCode.UNAUTHENTICATED: + details = e.details() + if details == 'Token expired': + logger.warning("GRPC TaskStream: Token expired. Reconnecting.") + self.detach() + + if status_code == grpc.StatusCode.CANCELLED: + # Expected if the client is detached + logger.critical("GRPC TaskStream: Client detached from combiner. Atempting to reconnect.") + else: # Log the error and continue - logger.error(f"An error occurred during model update request stream: {e}") + logger.error(f"GRPC TaskStream: An error occurred during model update request stream: {e}") except Exception as ex: # Handle other exceptions - logger.error(f"An error occurred during model update request stream: {ex}") + logger.error(f"GRPC TaskStream: An error occurred during model update request stream: {ex}") # Detach if not attached if not self._attached: @@ -657,12 +676,15 @@ def process_request(self): self.inbox.task_done() except queue.Empty: pass + except grpc.RpcError as e: + status_code = e.code() + logger.critical(f"GRPC process_request: An error occurred during process request: {e}") def _handle_combiner_failure(self): """ Register failed combiner connection.""" self._missed_heartbeat += 1 if self._missed_heartbeat > self.config['reconnect_after_missed_heartbeat']: - self.detach()() + self.detach() def _send_heartbeat(self, update_frequency=2.0): """Send a heartbeat to the combiner. @@ -680,8 +702,13 @@ def _send_heartbeat(self, update_frequency=2.0): self._missed_heartbeat = 0 except grpc.RpcError as e: status_code = e.code() - logger.warning("Client heartbeat: GRPC error, {}. Retrying.".format( - status_code.name)) + if status_code == grpc.StatusCode.UNAVAILABLE: + logger.warning("GRPC hearbeat: server unavailable during send heartbeat. Retrying.") + if status_code == grpc.StatusCode.UNAUTHENTICATED: + details = e.details() + if details == 'Token expired': + logger.warning("GRPC hearbeat: Token expired. Reconnecting.") + self.detach() logger.debug(e) self._handle_combiner_failure() @@ -717,7 +744,16 @@ def _send_status(self, msg, log_level=fedn.Status.INFO, type=None, request=None, self.logs.append( "{} {} LOG LEVEL {} MESSAGE {}".format(str(datetime.now()), status.sender.name, status.log_level, status.status)) - _ = self.connectorStub.SendStatus(status, metadata=self.metadata) + try: + _ = self.connectorStub.SendStatus(status, metadata=self.metadata) + except grpc.RpcError as e: + status_code = e.code() + if status_code == grpc.StatusCode.UNAVAILABLE: + logger.warning("GRPC SendStatus: server unavailable during send status.") + if status_code == grpc.StatusCode.UNAUTHENTICATED: + details = e.details() + if details == 'Token expired': + logger.warning("GRPC SendStatus: Token expired.") def run(self): """ Run the client. """ diff --git a/fedn/fedn/network/clients/connect.py b/fedn/fedn/network/clients/connect.py index 9c0928215..7c5d5524e 100644 --- a/fedn/fedn/network/clients/connect.py +++ b/fedn/fedn/network/clients/connect.py @@ -9,7 +9,9 @@ import requests -from fedn.common.config import FEDN_AUTH_SCHEME, FEDN_CUSTOM_URL_PREFIX +from fedn.common.config import (FEDN_AUTH_REFRESH_TOKEN, + FEDN_AUTH_REFRESH_TOKEN_URI, FEDN_AUTH_SCHEME, + FEDN_CUSTOM_URL_PREFIX) from fedn.common.log_config import logger @@ -94,6 +96,16 @@ def assign(self): return Status.UnMatchedConfig, reason if retval.status_code == 401: + if 'message' in retval.json(): + reason = retval.json()['message'] + logger.warning(reason) + if reason == 'Token expired': + status_code = self.refresh_token() + if status_code >= 200 and status_code < 204: + logger.info("Token refreshed.") + return Status.TryAgain, reason + else: + return Status.UnAuthorized, "Could not refresh token" reason = "Unauthorized connection to reducer, make sure the correct token is set" return Status.UnAuthorized, reason @@ -116,3 +128,21 @@ def assign(self): return Status.Assigned, retval.json() return Status.Unassigned, None + + def refresh_token(self): + """ + Refresh client token. + + :return: Tuple with assingment status, combiner connection information if sucessful, else None. + :rtype: tuple(:class:`fedn.network.clients.connect.Status`, str) + """ + if not FEDN_AUTH_REFRESH_TOKEN_URI or not FEDN_AUTH_REFRESH_TOKEN: + logger.error("No refresh token URI/Token set, cannot refresh token.") + return 401 + + payload = requests.post(FEDN_AUTH_REFRESH_TOKEN_URI, + verify=self.verify, + allow_redirects=True, + json={'refresh': FEDN_AUTH_REFRESH_TOKEN}) + self.token = payload.json()['access'] + return payload.status_code \ No newline at end of file diff --git a/fedn/fedn/network/grpc/auth.py b/fedn/fedn/network/grpc/auth.py index cf5fc0fab..18c3a5875 100644 --- a/fedn/fedn/network/grpc/auth.py +++ b/fedn/fedn/network/grpc/auth.py @@ -73,7 +73,6 @@ def intercept_service(self, continuation, handler_call_details): return _unary_unary_rpc_terminator(grpc.StatusCode.UNAUTHENTICATED, f'Invalid token scheme, expected {FEDN_AUTH_SCHEME}') token = token.split(' ')[1] - print(f"HANDLER: {handler_call_details}", flush=True) try: payload = jwt.decode(token, SECRET_KEY, algorithms=[FEDN_JWT_ALGORITHM]) @@ -87,6 +86,8 @@ def intercept_service(self, continuation, handler_call_details): return continuation(handler_call_details) except jwt.InvalidTokenError: return _unary_unary_rpc_terminator(grpc.StatusCode.UNAUTHENTICATED, 'Invalid token') + except jwt.ExpiredSignatureError: + return _unary_unary_rpc_terminator(grpc.StatusCode.UNAUTHENTICATED, 'Token expired') except Exception as e: logger.error(str(e)) return _unary_unary_rpc_terminator(grpc.StatusCode.UNKNOWN, str(e)) \ No newline at end of file From f6e92ab53c7b3fbe224f8430e9e682fceb4c5856 Mon Sep 17 00:00:00 2001 From: Fredrik Wrede Date: Tue, 19 Mar 2024 11:47:27 +0000 Subject: [PATCH 14/16] add docs --- docs/auth.rst | 90 ++++++++++++++++++++++++++++++++++++++++++++++++++ docs/index.rst | 1 + 2 files changed, 91 insertions(+) create mode 100644 docs/auth.rst diff --git a/docs/auth.rst b/docs/auth.rst new file mode 100644 index 000000000..1866bd4f1 --- /dev/null +++ b/docs/auth.rst @@ -0,0 +1,90 @@ +.. _auth-label: + +Authentication and Authorization (RBAC) +============================================= +.. warning:: The FEDn RBAC system is an experimental feature and may change in the future. + +FEDn supports Role-Based Access Control (RBAC) for controlling access to the FEDn API and gRPC endpoints. The RBAC system is based on JSON Web Tokens (JWT) and is implemented using the `jwt` package. The JWT tokens are used to authenticate users and to control access to the FEDn API. +There are two types of JWT tokens used in the FEDn RBAC system: +- Access tokens: Used to authenticate users and to control access to the FEDn API. +- Refresh tokens: Used to obtain new access tokens when the old ones expire. + +.. note:: Please note that the FEDn RBAC system is not enabled by default and does not issue JWT tokens. It is used to integrate with external authentication and authorization systems such as FEDn Studio. + +FEDn RBAC system is by default configured with four types of roles: +- `admin`: Has full access to the FEDn API. This role is used to manage the FEDn network using the API client or the FEDn CLI. +- `combiner`: Has access to the /add_combiner endpoint in the API. +- `client`: Has access to the /add_client endpoint in the API and various gRPC endpoint to participate in federated learning sessions. + +A full list of the "roles to endpoint" mappings for gRPC can be found in the `fedn/network/grpc/auth.py`. For the API, the mappings are defined using custom decorators defined in `fedn/network/api/auth.py`. + +.. note:: The roles are handled by a custom claim in the JWT token called `role`. The claim is used to control access to the FEDn API and gRPC endpoints. + +To enable the FEDn RBAC system, you need to set the following environment variables in the controller and combiner: + +.. envvar:: FEDN_JWT_SECRET_KEY + :type: str + :required: yes + :default: None + :description: The secret key used for JWT token encryption. + +.. envvar:: FEDN_JWT_ALGORITHM + :type: str + :required: no + :default: "HS256" + :description: The algorithm used for JWT token encryption. + +.. envvar:: FEDN_AUTH_SCHEME + :type: str + :required: no + :default: "Token" + :description: The authentication scheme used in the FEDn API and gRPC interceptors. + +For further fexibility, you can also set the following environment variables: + +.. envvar:: FEDN_CUSTOM_URL_PREFIX + :type: str + :required: no + :default: None + :description: Add a custom URL prefix used in the FEDn API, such as /internal or /v1. + +.. envvar:: FEDN_AUTH_WHITELIST_URL + :type: str + :required: no + :default: None + :description: A URL patterns to the API that should be excluded from the FEDn RBAC system. For example /internal (to enable internal API calls). + +.. envvar:: FEDN_JWT_CUSTOM_CLAIM_KEY + :type: str + :required: no + :default: None + :description: The custom claim key used in the JWT token. + +.. envvar:: FEDN_JWT_CUSTOM_CLAIM_VALUE + :type: str + :required: no + :default: None + :description: The custom claim value used in the JWT token. + + +For the client you need to set the following environment variables: + +.. envvar:: FEDN_JWT_ACCESS_TOKEN + :type: str + :required: yes + :default: None + :description: The access token used to authenticate the client to the FEDn API. + +.. envvar:: FEDN_JWT_REFRESH_TOKEN + :type: str + :required: no + :default: None + :description: The refresh token used to obtain new access tokens when the old ones expire. + +.. envvar:: FEDN_AUTH_SCHEME + :type: str + :required: no + :default: "Token" + :description: The authentication scheme used in the FEDn API and gRPC interceptors. + +You can also use `--token` flags in the FEDn CLI to set the access token. \ No newline at end of file diff --git a/docs/index.rst b/docs/index.rst index 832be1334..77a265340 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -9,6 +9,7 @@ architecture aggregators helpers + auth faq modules From d3a242cddc2fe3c9b0d65e448f7027735619f99b Mon Sep 17 00:00:00 2001 From: Fredrik Wrede Date: Tue, 19 Mar 2024 12:30:55 +0000 Subject: [PATCH 15/16] fix --- fedn/fedn/network/clients/client.py | 5 ++--- fedn/fedn/network/clients/connect.py | 15 +++++++-------- fedn/fedn/network/grpc/auth.py | 28 +++++++++++++++------------- 3 files changed, 24 insertions(+), 24 deletions(-) diff --git a/fedn/fedn/network/clients/client.py b/fedn/fedn/network/clients/client.py index f75d907b8..7fc0af1b2 100644 --- a/fedn/fedn/network/clients/client.py +++ b/fedn/fedn/network/clients/client.py @@ -468,7 +468,7 @@ def _listen_to_task_stream(self): if details == 'Token expired': logger.warning("GRPC TaskStream: Token expired. Reconnecting.") self.detach() - + if status_code == grpc.StatusCode.CANCELLED: # Expected if the client is detached logger.critical("GRPC TaskStream: Client detached from combiner. Atempting to reconnect.") @@ -677,7 +677,6 @@ def process_request(self): except queue.Empty: pass except grpc.RpcError as e: - status_code = e.code() logger.critical(f"GRPC process_request: An error occurred during process request: {e}") def _handle_combiner_failure(self): @@ -753,7 +752,7 @@ def _send_status(self, msg, log_level=fedn.Status.INFO, type=None, request=None, if status_code == grpc.StatusCode.UNAUTHENTICATED: details = e.details() if details == 'Token expired': - logger.warning("GRPC SendStatus: Token expired.") + logger.warning("GRPC SendStatus: Token expired.") def run(self): """ Run the client. """ diff --git a/fedn/fedn/network/clients/connect.py b/fedn/fedn/network/clients/connect.py index 7c5d5524e..2e9345ebb 100644 --- a/fedn/fedn/network/clients/connect.py +++ b/fedn/fedn/network/clients/connect.py @@ -5,7 +5,6 @@ # # import enum -import os import requests @@ -128,7 +127,7 @@ def assign(self): return Status.Assigned, retval.json() return Status.Unassigned, None - + def refresh_token(self): """ Refresh client token. @@ -139,10 +138,10 @@ def refresh_token(self): if not FEDN_AUTH_REFRESH_TOKEN_URI or not FEDN_AUTH_REFRESH_TOKEN: logger.error("No refresh token URI/Token set, cannot refresh token.") return 401 - - payload = requests.post(FEDN_AUTH_REFRESH_TOKEN_URI, - verify=self.verify, - allow_redirects=True, - json={'refresh': FEDN_AUTH_REFRESH_TOKEN}) + + payload = requests.post(FEDN_AUTH_REFRESH_TOKEN_URI, + verify=self.verify, + allow_redirects=True, + json={'refresh': FEDN_AUTH_REFRESH_TOKEN}) self.token = payload.json()['access'] - return payload.status_code \ No newline at end of file + return payload.status_code diff --git a/fedn/fedn/network/grpc/auth.py b/fedn/fedn/network/grpc/auth.py index 18c3a5875..d879cd812 100644 --- a/fedn/fedn/network/grpc/auth.py +++ b/fedn/fedn/network/grpc/auth.py @@ -1,5 +1,3 @@ -import os - import grpc import jwt @@ -34,20 +32,24 @@ 'grpc_health_probe' ] + def check_role_claims(payload, endpoint): - user_role = payload.get('role', '') - + user_role = payload.get('role', '') + # Perform endpoint-specific RBAC check allowed_roles = ENDPOINT_ROLES_MAPPING.get(endpoint) - if allowed_roles and not user_role in allowed_roles: + if allowed_roles and user_role not in allowed_roles: return False return True + def _unary_unary_rpc_terminator(code, details): def terminate(ignored_request, context): context.abort(code, details) return grpc.unary_unary_rpc_method_handler(terminate) + + class JWTInterceptor(grpc.ServerInterceptor): def __init__(self): pass @@ -57,32 +59,32 @@ def intercept_service(self, continuation, handler_call_details): if not SECRET_KEY: return continuation(handler_call_details) metadata = dict(handler_call_details.invocation_metadata) - # Pass whitelisted methods + # Pass whitelisted methods if handler_call_details.method in ENDPOINT_WHITELIST: return continuation(handler_call_details) # Pass if the request comes from whitelisted user agents user_agent = metadata.get('user-agent').split(' ')[0] if user_agent in USER_AGENT_WHITELIST: return continuation(handler_call_details) - + token = metadata.get('authorization') if token is None: return _unary_unary_rpc_terminator(grpc.StatusCode.UNAUTHENTICATED, 'Token is missing') - + if not token.startswith(FEDN_AUTH_SCHEME): return _unary_unary_rpc_terminator(grpc.StatusCode.UNAUTHENTICATED, f'Invalid token scheme, expected {FEDN_AUTH_SCHEME}') - + token = token.split(' ')[1] try: payload = jwt.decode(token, SECRET_KEY, algorithms=[FEDN_JWT_ALGORITHM]) - + if not check_role_claims(payload, handler_call_details.method): return _unary_unary_rpc_terminator(grpc.StatusCode.PERMISSION_DENIED, 'Insufficient permissions') - + if not check_custom_claims(payload): return _unary_unary_rpc_terminator(grpc.StatusCode.PERMISSION_DENIED, 'Insufficient permissions') - + return continuation(handler_call_details) except jwt.InvalidTokenError: return _unary_unary_rpc_terminator(grpc.StatusCode.UNAUTHENTICATED, 'Invalid token') @@ -90,4 +92,4 @@ def intercept_service(self, continuation, handler_call_details): return _unary_unary_rpc_terminator(grpc.StatusCode.UNAUTHENTICATED, 'Token expired') except Exception as e: logger.error(str(e)) - return _unary_unary_rpc_terminator(grpc.StatusCode.UNKNOWN, str(e)) \ No newline at end of file + return _unary_unary_rpc_terminator(grpc.StatusCode.UNKNOWN, str(e)) From 399c0459740b87fd5b196133413c1db11e3e6f28 Mon Sep 17 00:00:00 2001 From: Fredrik Wrede Date: Tue, 19 Mar 2024 12:32:39 +0000 Subject: [PATCH 16/16] fix --- fedn/fedn/network/api/auth.py | 1 - 1 file changed, 1 deletion(-) diff --git a/fedn/fedn/network/api/auth.py b/fedn/fedn/network/api/auth.py index 59bda3ec5..bf43c2f69 100644 --- a/fedn/fedn/network/api/auth.py +++ b/fedn/fedn/network/api/auth.py @@ -1,4 +1,3 @@ -import os from functools import wraps import jwt