diff --git a/docs/auth.rst b/docs/auth.rst new file mode 100644 index 000000000..1866bd4f1 --- /dev/null +++ b/docs/auth.rst @@ -0,0 +1,90 @@ +.. _auth-label: + +Authentication and Authorization (RBAC) +============================================= +.. warning:: The FEDn RBAC system is an experimental feature and may change in the future. + +FEDn supports Role-Based Access Control (RBAC) for controlling access to the FEDn API and gRPC endpoints. The RBAC system is based on JSON Web Tokens (JWT) and is implemented using the `jwt` package. The JWT tokens are used to authenticate users and to control access to the FEDn API. +There are two types of JWT tokens used in the FEDn RBAC system: +- Access tokens: Used to authenticate users and to control access to the FEDn API. +- Refresh tokens: Used to obtain new access tokens when the old ones expire. + +.. note:: Please note that the FEDn RBAC system is not enabled by default and does not issue JWT tokens. It is used to integrate with external authentication and authorization systems such as FEDn Studio. + +FEDn RBAC system is by default configured with four types of roles: +- `admin`: Has full access to the FEDn API. This role is used to manage the FEDn network using the API client or the FEDn CLI. +- `combiner`: Has access to the /add_combiner endpoint in the API. +- `client`: Has access to the /add_client endpoint in the API and various gRPC endpoint to participate in federated learning sessions. + +A full list of the "roles to endpoint" mappings for gRPC can be found in the `fedn/network/grpc/auth.py`. For the API, the mappings are defined using custom decorators defined in `fedn/network/api/auth.py`. + +.. note:: The roles are handled by a custom claim in the JWT token called `role`. The claim is used to control access to the FEDn API and gRPC endpoints. + +To enable the FEDn RBAC system, you need to set the following environment variables in the controller and combiner: + +.. envvar:: FEDN_JWT_SECRET_KEY + :type: str + :required: yes + :default: None + :description: The secret key used for JWT token encryption. + +.. envvar:: FEDN_JWT_ALGORITHM + :type: str + :required: no + :default: "HS256" + :description: The algorithm used for JWT token encryption. + +.. envvar:: FEDN_AUTH_SCHEME + :type: str + :required: no + :default: "Token" + :description: The authentication scheme used in the FEDn API and gRPC interceptors. + +For further fexibility, you can also set the following environment variables: + +.. envvar:: FEDN_CUSTOM_URL_PREFIX + :type: str + :required: no + :default: None + :description: Add a custom URL prefix used in the FEDn API, such as /internal or /v1. + +.. envvar:: FEDN_AUTH_WHITELIST_URL + :type: str + :required: no + :default: None + :description: A URL patterns to the API that should be excluded from the FEDn RBAC system. For example /internal (to enable internal API calls). + +.. envvar:: FEDN_JWT_CUSTOM_CLAIM_KEY + :type: str + :required: no + :default: None + :description: The custom claim key used in the JWT token. + +.. envvar:: FEDN_JWT_CUSTOM_CLAIM_VALUE + :type: str + :required: no + :default: None + :description: The custom claim value used in the JWT token. + + +For the client you need to set the following environment variables: + +.. envvar:: FEDN_JWT_ACCESS_TOKEN + :type: str + :required: yes + :default: None + :description: The access token used to authenticate the client to the FEDn API. + +.. envvar:: FEDN_JWT_REFRESH_TOKEN + :type: str + :required: no + :default: None + :description: The refresh token used to obtain new access tokens when the old ones expire. + +.. envvar:: FEDN_AUTH_SCHEME + :type: str + :required: no + :default: "Token" + :description: The authentication scheme used in the FEDn API and gRPC interceptors. + +You can also use `--token` flags in the FEDn CLI to set the access token. \ No newline at end of file diff --git a/docs/index.rst b/docs/index.rst index 832be1334..77a265340 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -9,6 +9,7 @@ architecture aggregators helpers + auth faq modules diff --git a/fedn/fedn/common/config.py b/fedn/fedn/common/config.py index f6c827d0d..0a261c4af 100644 --- a/fedn/fedn/common/config.py +++ b/fedn/fedn/common/config.py @@ -5,6 +5,17 @@ global STATESTORE_CONFIG global MODELSTORAGE_CONFIG +SECRET_KEY = os.environ.get('FEDN_JWT_SECRET_KEY', False) +FEDN_JWT_CUSTOM_CLAIM_KEY = os.environ.get('FEDN_JWT_CUSTOM_CLAIM_KEY', False) +FEDN_JWT_CUSTOM_CLAIM_VALUE = os.environ.get('FEDN_JWT_CUSTOM_CLAIM_VALUE', False) + +FEDN_AUTH_WHITELIST_URL_PREFIX = os.environ.get('FEDN_AUTH_WHITELIST_URL_PREFIX', False) +FEDN_JWT_ALGORITHM = os.environ.get('FEDN_JWT_ALGORITHM', 'HS256') +FEDN_AUTH_SCHEME = os.environ.get('FEDN_AUTH_SCHEME', 'Token') +FEDN_AUTH_REFRESH_TOKEN_URI = os.environ.get('FEDN_AUTH_REFRESH_TOKEN_URI', False) +FEDN_AUTH_REFRESH_TOKEN = os.environ.get('FEDN_AUTH_REFRESH_TOKEN', False) +FEDN_CUSTOM_URL_PREFIX = os.environ.get('FEDN_CUSTOM_URL_PREFIX', '') + def get_environment_config(): """ Get the configuration from environment variables. diff --git a/fedn/fedn/network/api/auth.py b/fedn/fedn/network/api/auth.py new file mode 100644 index 000000000..bf43c2f69 --- /dev/null +++ b/fedn/fedn/network/api/auth.py @@ -0,0 +1,70 @@ +from functools import wraps + +import jwt +from flask import jsonify, request + +from fedn.common.config import (FEDN_AUTH_SCHEME, + FEDN_AUTH_WHITELIST_URL_PREFIX, + FEDN_JWT_ALGORITHM, FEDN_JWT_CUSTOM_CLAIM_KEY, + FEDN_JWT_CUSTOM_CLAIM_VALUE, SECRET_KEY) + + +def check_role_claims(payload, role): + if 'role' not in payload: + return False + if payload['role'] != role: + return False + + return True + + +def check_custom_claims(payload): + if FEDN_JWT_CUSTOM_CLAIM_KEY and FEDN_JWT_CUSTOM_CLAIM_VALUE: + if payload[FEDN_JWT_CUSTOM_CLAIM_KEY] != FEDN_JWT_CUSTOM_CLAIM_VALUE: + return False + return True + + +def if_whitelisted_url_prefix(path): + if FEDN_AUTH_WHITELIST_URL_PREFIX and path.startswith(FEDN_AUTH_WHITELIST_URL_PREFIX): + return True + else: + return False + + +def jwt_auth_required(role=None): + def actual_decorator(func): + if not SECRET_KEY: + return func + + @wraps(func) + def decorated(*args, **kwargs): + if if_whitelisted_url_prefix(request.path): + return func(*args, **kwargs) + token = request.headers.get('Authorization') + if not token: + return jsonify({'message': 'Missing token'}), 401 + # Get token from the header Bearer + if token.startswith(FEDN_AUTH_SCHEME): + token = token.split(' ')[1] + else: + return jsonify({'message': + f'Invalid token scheme, expected {FEDN_AUTH_SCHEME}' + }), 401 + try: + payload = jwt.decode(token, SECRET_KEY, algorithms=[FEDN_JWT_ALGORITHM]) + if not check_role_claims(payload, role): + return jsonify({'message': 'Invalid token'}), 401 + if not check_custom_claims(payload): + return jsonify({'message': 'Invalid token'}), 401 + + except jwt.ExpiredSignatureError: + return jsonify({'message': 'Token expired'}), 401 + + except jwt.InvalidTokenError: + return jsonify({'message': 'Invalid token'}), 401 + + return func(*args, **kwargs) + + return decorated + return actual_decorator diff --git a/fedn/fedn/network/api/client.py b/fedn/fedn/network/api/client.py index b33c38626..58337137f 100644 --- a/fedn/fedn/network/api/client.py +++ b/fedn/fedn/network/api/client.py @@ -27,7 +27,7 @@ def __init__(self, host, port=None, secure=False, verify=False, token=None, auth # Auth scheme passed as argument overrides environment variable. # "Token" is the default auth scheme. if not auth_scheme: - auth_scheme = os.environ.get("FEDN_AUTH_SCHEME", "Token") + auth_scheme = os.environ.get("FEDN_AUTH_SCHEME", "Bearer") # Override potential env variable if token is passed as argument. if not token: token = os.environ.get("FEDN_AUTH_TOKEN", False) diff --git a/fedn/fedn/network/api/server.py b/fedn/fedn/network/api/server.py index 597feca20..10030899a 100644 --- a/fedn/fedn/network/api/server.py +++ b/fedn/fedn/network/api/server.py @@ -1,17 +1,13 @@ +import os + from flasgger import Swagger from flask import Flask, jsonify, request from fedn.common.config import (get_controller_config, get_modelstorage_config, get_network_config, get_statestore_config) +from fedn.network.api.auth import jwt_auth_required from fedn.network.api.interface import API -from fedn.network.api.v1.client_routes import bp as client_bp -from fedn.network.api.v1.combiner_routes import bp as combiner_bp -from fedn.network.api.v1.model_routes import bp as model_bp -from fedn.network.api.v1.package_routes import bp as package_bp -from fedn.network.api.v1.round_routes import bp as round_bp -from fedn.network.api.v1.session_routes import bp as session_bp -from fedn.network.api.v1.status_routes import bp as status_bp -from fedn.network.api.v1.validation_routes import bp as validation_bp +from fedn.network.api.v1 import _routes from fedn.network.controller.control import Control from fedn.network.storage.statestore.mongostatestore import MongoStateStore @@ -21,30 +17,31 @@ statestore = MongoStateStore(network_id, statestore_config["mongo_config"]) statestore.set_storage_backend(modelstorage_config) control = Control(statestore=statestore) + +custom_url_prefix = os.environ.get("FEDN_CUSTOM_URL_PREFIX", False) api = API(statestore, control) app = Flask(__name__) -app.register_blueprint(client_bp) -app.register_blueprint(status_bp) -app.register_blueprint(model_bp) -app.register_blueprint(validation_bp) -app.register_blueprint(package_bp) -app.register_blueprint(session_bp) -app.register_blueprint(combiner_bp) -app.register_blueprint(round_bp) +for bp in _routes: + app.register_blueprint(bp) + if custom_url_prefix: + app.register_blueprint(bp, + name=f"{bp.name}_custom", + url_prefix=f"{custom_url_prefix}{bp.url_prefix}") template = { - "swagger": "2.0", - "info": { - "title": "FEDn API", - "description": "API for the FEDn network.", - "version": "0.0.1" - } + "swagger": "2.0", + "info": { + "title": "FEDn API", + "description": "API for the FEDn network.", + "version": "0.0.1" + } } swagger = Swagger(app, template=template) @app.route("/get_model_trail", methods=["GET"]) +@jwt_auth_required(role="admin") def get_model_trail(): """Get the model trail for a given session. param: session: The session id to get the model trail for. @@ -55,7 +52,12 @@ def get_model_trail(): return api.get_model_trail() +if custom_url_prefix: + app.add_url_rule(f"{custom_url_prefix}/get_model_trail", view_func=get_model_trail, methods=["GET"]) + + @app.route("/get_model_ancestors", methods=["GET"]) +@jwt_auth_required(role="admin") def get_model_ancestors(): """Get the ancestors of a model. param: model: The model id to get the ancestors for. @@ -71,7 +73,12 @@ def get_model_ancestors(): return api.get_model_ancestors(model, limit) +if custom_url_prefix: + app.add_url_rule(f"{custom_url_prefix}/get_model_ancestors", view_func=get_model_ancestors, methods=["GET"]) + + @app.route("/get_model_descendants", methods=["GET"]) +@jwt_auth_required(role="admin") def get_model_descendants(): """Get the ancestors of a model. param: model: The model id to get the child for. @@ -87,7 +94,12 @@ def get_model_descendants(): return api.get_model_descendants(model, limit) +if custom_url_prefix: + app.add_url_rule(f"{custom_url_prefix}/get_model_descendants", view_func=get_model_descendants, methods=["GET"]) + + @app.route("/list_models", methods=["GET"]) +@jwt_auth_required(role="admin") def list_models(): """Get models from the statestore. param: @@ -108,7 +120,12 @@ def list_models(): return api.get_models(session_id, limit, skip, include_active) +if custom_url_prefix: + app.add_url_rule(f"{custom_url_prefix}/list_models", view_func=list_models, methods=["GET"]) + + @app.route("/get_model", methods=["GET"]) +@jwt_auth_required(role="admin") def get_model(): """Get a model from the statestore. param: model: The model id to get. @@ -123,7 +140,12 @@ def get_model(): return api.get_model(model) +if custom_url_prefix: + app.add_url_rule(f"{custom_url_prefix}/get_model", view_func=get_model, methods=["GET"]) + + @app.route("/delete_model_trail", methods=["GET", "POST"]) +@jwt_auth_required(role="admin") def delete_model_trail(): """Delete the model trail for a given session. param: session: The session id to delete the model trail for. @@ -134,7 +156,12 @@ def delete_model_trail(): return jsonify({"message": "Not implemented"}), 501 +if custom_url_prefix: + app.add_url_rule(f"{custom_url_prefix}/delete_model_trail", view_func=delete_model_trail, methods=["GET", "POST"]) + + @app.route("/list_clients", methods=["GET"]) +@jwt_auth_required(role="admin") def list_clients(): """Get all clients from the statestore. return: All clients as a json object. @@ -148,7 +175,12 @@ def list_clients(): return api.get_clients(limit, skip, status) +if custom_url_prefix: + app.add_url_rule(f"{custom_url_prefix}/list_clients", view_func=list_clients, methods=["GET"]) + + @app.route("/get_active_clients", methods=["GET"]) +@jwt_auth_required(role="admin") def get_active_clients(): """Get all active clients from the statestore. param: combiner_id: The combiner id to get active clients for. @@ -165,7 +197,12 @@ def get_active_clients(): return api.get_active_clients(combiner_id) +if custom_url_prefix: + app.add_url_rule(f"{custom_url_prefix}/get_active_clients", view_func=get_active_clients, methods=["GET"]) + + @app.route("/list_combiners", methods=["GET"]) +@jwt_auth_required(role="admin") def list_combiners(): """Get all combiners in the network. return: All combiners as a json object. @@ -178,7 +215,12 @@ def list_combiners(): return api.get_all_combiners(limit, skip) +if custom_url_prefix: + app.add_url_rule(f"{custom_url_prefix}/list_combiners", view_func=list_combiners, methods=["GET"]) + + @app.route("/get_combiner", methods=["GET"]) +@jwt_auth_required(role="admin") def get_combiner(): """Get a combiner from the statestore. param: combiner_id: The combiner id to get. @@ -195,7 +237,12 @@ def get_combiner(): return api.get_combiner(combiner_id) +if custom_url_prefix: + app.add_url_rule(f"{custom_url_prefix}/get_combiner", view_func=get_combiner, methods=["GET"]) + + @app.route("/list_rounds", methods=["GET"]) +@jwt_auth_required(role="admin") def list_rounds(): """Get all rounds from the statestore. return: All rounds as a json object. @@ -204,7 +251,12 @@ def list_rounds(): return api.get_all_rounds() +if custom_url_prefix: + app.add_url_rule(f"{custom_url_prefix}/list_rounds", view_func=list_rounds, methods=["GET"]) + + @app.route("/get_round", methods=["GET"]) +@jwt_auth_required(role="admin") def get_round(): """Get a round from the statestore. param: round_id: The round id to get. @@ -218,7 +270,12 @@ def get_round(): return api.get_round(round_id) +if custom_url_prefix: + app.add_url_rule(f"{custom_url_prefix}/get_round", view_func=get_round, methods=["GET"]) + + @app.route("/start_session", methods=["GET", "POST"]) +@jwt_auth_required(role="admin") def start_session(): """Start a new session. return: The response from control. @@ -228,7 +285,12 @@ def start_session(): return api.start_session(**json_data) +if custom_url_prefix: + app.add_url_rule(f"{custom_url_prefix}/start_session", view_func=start_session, methods=["GET", "POST"]) + + @app.route("/list_sessions", methods=["GET"]) +@jwt_auth_required(role="admin") def list_sessions(): """Get all sessions from the statestore. return: All sessions as a json object. @@ -240,7 +302,12 @@ def list_sessions(): return api.get_all_sessions(limit, skip) +if custom_url_prefix: + app.add_url_rule(f"{custom_url_prefix}/list_sessions", view_func=list_sessions, methods=["GET"]) + + @app.route("/get_session", methods=["GET"]) +@jwt_auth_required(role="admin") def get_session(): """Get a session from the statestore. param: session_id: The session id to get. @@ -257,13 +324,23 @@ def get_session(): return api.get_session(session_id) +if custom_url_prefix: + app.add_url_rule(f"{custom_url_prefix}/get_session", view_func=get_session, methods=["GET"]) + + @app.route("/set_active_package", methods=["PUT"]) +@jwt_auth_required(role="admin") def set_active_package(): id = request.args.get("id", None) return api.set_active_compute_package(id) +if custom_url_prefix: + app.add_url_rule(f"{custom_url_prefix}/set_active_package", view_func=set_active_package, methods=["PUT"]) + + @app.route("/set_package", methods=["POST"]) +@jwt_auth_required(role="admin") def set_package(): """ Set the compute package in the statestore. Usage with curl: @@ -295,7 +372,12 @@ def set_package(): ) +if custom_url_prefix: + app.add_url_rule(f"{custom_url_prefix}/set_package", view_func=set_package, methods=["POST"]) + + @app.route("/get_package", methods=["GET"]) +@jwt_auth_required(role="admin") def get_package(): """Get the compute package from the statestore. return: The compute package as a json object. @@ -304,7 +386,12 @@ def get_package(): return api.get_compute_package() +if custom_url_prefix: + app.add_url_rule(f"{custom_url_prefix}/get_package", view_func=get_package, methods=["GET"]) + + @app.route("/list_compute_packages", methods=["GET"]) +@jwt_auth_required(role="admin") def list_compute_packages(): """Get the compute package from the statestore. return: The compute package as a json object. @@ -320,7 +407,12 @@ def list_compute_packages(): ) +if custom_url_prefix: + app.add_url_rule(f"{custom_url_prefix}/list_compute_packages", view_func=list_compute_packages, methods=["GET"]) + + @app.route("/download_package", methods=["GET"]) +@jwt_auth_required(role="client") def download_package(): """Download the compute package. return: The compute package as a json object. @@ -330,13 +422,23 @@ def download_package(): return api.download_compute_package(name) +if custom_url_prefix: + app.add_url_rule(f"{custom_url_prefix}/download_package", view_func=download_package, methods=["GET"]) + + @app.route("/get_package_checksum", methods=["GET"]) +@jwt_auth_required(role="admin") def get_package_checksum(): name = request.args.get("name", None) return api.get_checksum(name) +if custom_url_prefix: + app.add_url_rule(f"{custom_url_prefix}/get_package_checksum", view_func=get_package_checksum, methods=["GET"]) + + @app.route("/get_latest_model", methods=["GET"]) +@jwt_auth_required(role="admin") def get_latest_model(): """Get the latest model from the statestore. return: The initial model as a json object. @@ -345,7 +447,12 @@ def get_latest_model(): return api.get_latest_model() +if custom_url_prefix: + app.add_url_rule(f"{custom_url_prefix}/get_latest_model", view_func=get_latest_model, methods=["GET"]) + + @app.route("/set_current_model", methods=["PUT"]) +@jwt_auth_required(role="admin") def set_current_model(): """Set the initial model in the statestore and upload to model repository. Usage with curl: @@ -364,10 +471,14 @@ def set_current_model(): return api.set_current_model(id) +if custom_url_prefix: + app.add_url_rule(f"{custom_url_prefix}/set_current_model", view_func=set_current_model, methods=["PUT"]) + # Get initial model endpoint @app.route("/get_initial_model", methods=["GET"]) +@jwt_auth_required(role="admin") def get_initial_model(): """Get the initial model from the statestore. return: The initial model as a json object. @@ -376,7 +487,12 @@ def get_initial_model(): return api.get_initial_model() +if custom_url_prefix: + app.add_url_rule(f"{custom_url_prefix}/get_initial_model", view_func=get_initial_model, methods=["GET"]) + + @app.route("/set_initial_model", methods=["POST"]) +@jwt_auth_required(role="admin") def set_initial_model(): """Set the initial model in the statestore and upload to model repository. Usage with curl: @@ -396,7 +512,12 @@ def set_initial_model(): return api.set_initial_model(file) +if custom_url_prefix: + app.add_url_rule(f"{custom_url_prefix}/set_initial_model", view_func=set_initial_model, methods=["POST"]) + + @app.route("/get_controller_status", methods=["GET"]) +@jwt_auth_required(role="admin") def get_controller_status(): """Get the status of the controller. return: The status as a json object. @@ -405,7 +526,12 @@ def get_controller_status(): return api.get_controller_status() +if custom_url_prefix: + app.add_url_rule(f"{custom_url_prefix}/get_controller_status", view_func=get_controller_status, methods=["GET"]) + + @app.route("/get_client_config", methods=["GET"]) +@jwt_auth_required(role="admin") def get_client_config(): """Get the client configuration. return: The client configuration as a json object. @@ -416,7 +542,12 @@ def get_client_config(): return api.get_client_config(checksum) +if custom_url_prefix: + app.add_url_rule(f"{custom_url_prefix}/get_client_config", view_func=get_client_config, methods=["GET"]) + + @app.route("/get_events", methods=["GET"]) +@jwt_auth_required(role="admin") def get_events(): """Get the events from the statestore. return: The events as a json object. @@ -428,7 +559,12 @@ def get_events(): return api.get_events(**kwargs) +if custom_url_prefix: + app.add_url_rule(f"{custom_url_prefix}/get_events", view_func=get_client_config, methods=["GET"]) + + @app.route("/list_validations", methods=["GET"]) +@jwt_auth_required(role="admin") def list_validations(): """Get all validations from the statestore. return: All validations as a json object. @@ -439,7 +575,12 @@ def list_validations(): return api.get_all_validations(**kwargs) +if custom_url_prefix: + app.add_url_rule(f"{custom_url_prefix}/list_validations", view_func=list_validations, methods=["GET"]) + + @app.route("/add_combiner", methods=["POST"]) +@jwt_auth_required(role="combiner") def add_combiner(): """Add a combiner to the network. return: The response from the statestore. @@ -454,7 +595,12 @@ def add_combiner(): return response +if custom_url_prefix: + app.add_url_rule(f"{custom_url_prefix}/add_combiner", view_func=add_combiner, methods=["POST"]) + + @app.route("/add_client", methods=["POST"]) +@jwt_auth_required(role="client") def add_client(): """Add a client to the network. return: The response from control. @@ -470,7 +616,12 @@ def add_client(): return response +if custom_url_prefix: + app.add_url_rule(f"{custom_url_prefix}/add_client", view_func=add_client, methods=["POST"]) + + @app.route("/list_combiners_data", methods=["POST"]) +@jwt_auth_required(role="admin") def list_combiners_data(): """List data from combiners. return: The response from control. @@ -489,7 +640,12 @@ def list_combiners_data(): return response +if custom_url_prefix: + app.add_url_rule(f"{custom_url_prefix}/list_combiners_data", view_func=list_combiners_data, methods=["POST"]) + + @app.route("/get_plot_data", methods=["GET"]) +@jwt_auth_required(role="admin") def get_plot_data(): """Get plot data from the statestore. rtype: json @@ -503,6 +659,9 @@ def get_plot_data(): return response +if custom_url_prefix: + app.add_url_rule(f"{custom_url_prefix}/get_plot_data", view_func=get_plot_data, methods=["GET"]) + if __name__ == "__main__": config = get_controller_config() port = config["port"] diff --git a/fedn/fedn/network/api/v1/__init__.py b/fedn/fedn/network/api/v1/__init__.py index e69de29bb..bb8d8d33c 100644 --- a/fedn/fedn/network/api/v1/__init__.py +++ b/fedn/fedn/network/api/v1/__init__.py @@ -0,0 +1,10 @@ +from fedn.network.api.v1.client_routes import bp as client_bp +from fedn.network.api.v1.combiner_routes import bp as combiner_bp +from fedn.network.api.v1.model_routes import bp as model_bp +from fedn.network.api.v1.package_routes import bp as package_bp +from fedn.network.api.v1.round_routes import bp as round_bp +from fedn.network.api.v1.session_routes import bp as session_bp +from fedn.network.api.v1.status_routes import bp as status_bp +from fedn.network.api.v1.validation_routes import bp as validation_bp + +_routes = [client_bp, combiner_bp, model_bp, package_bp, round_bp, session_bp, status_bp, validation_bp] diff --git a/fedn/fedn/network/api/v1/client_routes.py b/fedn/fedn/network/api/v1/client_routes.py index c4215bd54..30322a9b7 100644 --- a/fedn/fedn/network/api/v1/client_routes.py +++ b/fedn/fedn/network/api/v1/client_routes.py @@ -1,5 +1,6 @@ from flask import Blueprint, jsonify, request +from fedn.network.api.auth import jwt_auth_required from fedn.network.api.v1.shared import (api_version, get_post_data_to_kwargs, get_typed_list_headers, mdb) from fedn.network.storage.statestore.stores.client_store import ClientStore @@ -11,6 +12,7 @@ @bp.route("/", methods=["GET"]) +@jwt_auth_required(role="admin") def get_clients(): """Get clients Retrieves a list of clients based on the provided parameters. @@ -127,6 +129,7 @@ def get_clients(): @bp.route("/list", methods=["POST"]) +@jwt_auth_required(role="admin") def list_clients(): """List clients Retrieves a list of clients based on the provided parameters. @@ -213,6 +216,7 @@ def list_clients(): @bp.route("/count", methods=["GET"]) +@jwt_auth_required(role="admin") def get_clients_count(): """Clients count Retrieves the total number of clients based on the provided parameters. @@ -273,6 +277,7 @@ def get_clients_count(): @bp.route("/count", methods=["POST"]) +@jwt_auth_required(role="admin") def clients_count(): """Clients count Retrieves the total number of clients based on the provided parameters. @@ -325,6 +330,7 @@ def clients_count(): @bp.route("/", methods=["GET"]) +@jwt_auth_required(role="admin") def get_client(id: str): """Get client Retrieves a client based on the provided id. diff --git a/fedn/fedn/network/api/v1/combiner_routes.py b/fedn/fedn/network/api/v1/combiner_routes.py index ba6bf5dbd..7d1761bee 100644 --- a/fedn/fedn/network/api/v1/combiner_routes.py +++ b/fedn/fedn/network/api/v1/combiner_routes.py @@ -1,5 +1,6 @@ from flask import Blueprint, jsonify, request +from fedn.network.api.auth import jwt_auth_required from fedn.network.api.v1.shared import (api_version, get_post_data_to_kwargs, get_typed_list_headers, mdb) from fedn.network.storage.statestore.stores.combiner_store import CombinerStore @@ -11,6 +12,7 @@ @bp.route("/", methods=["GET"]) +@jwt_auth_required(role="admin") def get_combiners(): """Get combiners Retrieves a list of combiners based on the provided parameters. @@ -119,6 +121,7 @@ def get_combiners(): @bp.route("/list", methods=["POST"]) +@jwt_auth_required(role="admin") def list_combiners(): """List combiners Retrieves a list of combiners based on the provided parameters. @@ -203,6 +206,7 @@ def list_combiners(): @bp.route("/count", methods=["GET"]) +@jwt_auth_required(role="admin") def get_combiners_count(): """Combiners count Retrieves the count of combiners based on the provided parameters. @@ -249,6 +253,7 @@ def get_combiners_count(): @bp.route("/count", methods=["POST"]) +@jwt_auth_required(role="admin") def combiners_count(): """Combiners count Retrieves the count of combiners based on the provided parameters. @@ -297,6 +302,7 @@ def combiners_count(): @bp.route("/", methods=["GET"]) +@jwt_auth_required(role="admin") def get_combiner(id: str): """Get combiner Retrieves a combiner based on the provided id. diff --git a/fedn/fedn/network/api/v1/model_routes.py b/fedn/fedn/network/api/v1/model_routes.py index 2db99083c..8e9308408 100644 --- a/fedn/fedn/network/api/v1/model_routes.py +++ b/fedn/fedn/network/api/v1/model_routes.py @@ -3,6 +3,7 @@ import numpy as np from flask import Blueprint, jsonify, request, send_file +from fedn.network.api.auth import jwt_auth_required from fedn.network.api.v1.shared import (api_version, get_limit, get_post_data_to_kwargs, get_reverse, get_typed_list_headers, mdb, @@ -22,6 +23,7 @@ @bp.route("/", methods=["GET"]) +@jwt_auth_required(role="admin") def get_models(): """Get models Retrieves a list of models based on the provided parameters. @@ -124,6 +126,7 @@ def get_models(): @bp.route("/list", methods=["POST"]) +@jwt_auth_required(role="admin") def list_models(): """List models Retrieves a list of models based on the provided parameters. @@ -210,6 +213,7 @@ def list_models(): @bp.route("/count", methods=["GET"]) +@jwt_auth_required(role="admin") def get_models_count(): """Models count Retrieves the count of models based on the provided parameters. @@ -257,6 +261,7 @@ def get_models_count(): @bp.route("/count", methods=["POST"]) +@jwt_auth_required(role="admin") def models_count(): """Models count Retrieves the count of models based on the provided parameters. @@ -308,6 +313,7 @@ def models_count(): @bp.route("/", methods=["GET"]) +@jwt_auth_required(role="admin") def get_model(id: str): """Get model Retrieves a model based on the provided id. @@ -353,6 +359,7 @@ def get_model(id: str): @bp.route("//descendants", methods=["GET"]) +@jwt_auth_required(role="admin") def get_descendants(id: str): """Get model descendants Retrieves a list of model descendants of the provided model id/model property. @@ -406,6 +413,7 @@ def get_descendants(id: str): @bp.route("//ancestors", methods=["GET"]) +@jwt_auth_required(role="admin") def get_ancestors(id: str): """Get model ancestors Retrieves a list of model ancestors of the provided model id/model property. diff --git a/fedn/fedn/network/api/v1/package_routes.py b/fedn/fedn/network/api/v1/package_routes.py index b62aa96a9..30ac4d51e 100644 --- a/fedn/fedn/network/api/v1/package_routes.py +++ b/fedn/fedn/network/api/v1/package_routes.py @@ -1,5 +1,6 @@ from flask import Blueprint, jsonify, request +from fedn.network.api.auth import jwt_auth_required from fedn.network.api.v1.shared import (api_version, get_post_data_to_kwargs, get_typed_list_headers, get_use_typing, mdb) @@ -12,6 +13,7 @@ @bp.route("/", methods=["GET"]) +@jwt_auth_required(role="admin") def get_packages(): """Get packages Retrieves a list of packages based on the provided parameters. @@ -132,6 +134,7 @@ def get_packages(): @bp.route("/list", methods=["POST"]) +@jwt_auth_required(role="admin") def list_packages(): """List packages Retrieves a list of packages based on the provided parameters. @@ -221,6 +224,7 @@ def list_packages(): @bp.route("/count", methods=["GET"]) +@jwt_auth_required(role="admin") def get_packages_count(): """Package count Retrieves the count of packages based on the provided parameters. @@ -281,6 +285,7 @@ def get_packages_count(): @bp.route("/count", methods=["POST"]) +@jwt_auth_required(role="admin") def packages_count(): """Package count Retrieves the count of packages based on the provided parameters. @@ -342,6 +347,7 @@ def packages_count(): @bp.route("/", methods=["GET"]) +@jwt_auth_required(role="admin") def get_package(id: str): """Get package Retrieves a package based on the provided id. @@ -388,6 +394,7 @@ def get_package(id: str): @bp.route("/active", methods=["GET"]) +@jwt_auth_required(role="admin") def get_active_package(): """Get active package Retrieves the active package diff --git a/fedn/fedn/network/api/v1/round_routes.py b/fedn/fedn/network/api/v1/round_routes.py index 317c767ee..8890c510a 100644 --- a/fedn/fedn/network/api/v1/round_routes.py +++ b/fedn/fedn/network/api/v1/round_routes.py @@ -1,5 +1,6 @@ from flask import Blueprint, jsonify, request +from fedn.network.api.auth import jwt_auth_required from fedn.network.api.v1.shared import (api_version, get_post_data_to_kwargs, get_typed_list_headers, mdb) from fedn.network.storage.statestore.stores.round_store import RoundStore @@ -11,6 +12,7 @@ @bp.route("/", methods=["GET"]) +@jwt_auth_required(role="admin") def get_rounds(): """Get rounds Retrieves a list of rounds based on the provided parameters. @@ -107,6 +109,7 @@ def get_rounds(): @bp.route("/list", methods=["POST"]) +@jwt_auth_required(role="admin") def list_rounds(): """List rounds Retrieves a list of rounds based on the provided parameters. @@ -187,6 +190,7 @@ def list_rounds(): @bp.route("/count", methods=["GET"]) +@jwt_auth_required(role="admin") def get_rounds_count(): """Rounds count Retrieves the count of rounds based on the provided parameters. @@ -227,6 +231,7 @@ def get_rounds_count(): @bp.route("/count", methods=["POST"]) +@jwt_auth_required(role="admin") def rounds_count(): """Rounds count Retrieves the count of rounds based on the provided parameters. @@ -271,6 +276,7 @@ def rounds_count(): @bp.route("/", methods=["GET"]) +@jwt_auth_required(role="admin") def get_round(id: str): """Get round Retrieves a round based on the provided id. diff --git a/fedn/fedn/network/api/v1/session_routes.py b/fedn/fedn/network/api/v1/session_routes.py index 4d3fe493f..99c52d8db 100644 --- a/fedn/fedn/network/api/v1/session_routes.py +++ b/fedn/fedn/network/api/v1/session_routes.py @@ -1,5 +1,6 @@ from flask import Blueprint, jsonify, request +from fedn.network.api.auth import jwt_auth_required from fedn.network.api.v1.shared import (api_version, get_post_data_to_kwargs, get_typed_list_headers, mdb) from fedn.network.storage.statestore.stores.session_store import SessionStore @@ -11,6 +12,7 @@ @bp.route("/", methods=["GET"]) +@jwt_auth_required(role="admin") def get_sessions(): """Get sessions Retrieves a list of sessions based on the provided parameters. @@ -99,6 +101,7 @@ def get_sessions(): @bp.route("/list", methods=["POST"]) +@jwt_auth_required(role="admin") def list_sessions(): """List sessions Retrieves a list of sessions based on the provided parameters. @@ -178,6 +181,7 @@ def list_sessions(): @bp.route("/count", methods=["GET"]) +@jwt_auth_required(role="admin") def get_sessions_count(): """Sessions count Retrieves the count of sessions based on the provided parameters. @@ -218,6 +222,7 @@ def get_sessions_count(): @bp.route("/count", methods=["POST"]) +@jwt_auth_required(role="admin") def sessions_count(): """Sessions count Retrieves the count of sessions based on the provided parameters. @@ -262,6 +267,7 @@ def sessions_count(): @bp.route("/", methods=["GET"]) +@jwt_auth_required(role="admin") def get_session(id: str): """Get session Retrieves a session based on the provided id. diff --git a/fedn/fedn/network/api/v1/status_routes.py b/fedn/fedn/network/api/v1/status_routes.py index 562b971db..e78c18533 100644 --- a/fedn/fedn/network/api/v1/status_routes.py +++ b/fedn/fedn/network/api/v1/status_routes.py @@ -1,5 +1,6 @@ from flask import Blueprint, jsonify, request +from fedn.network.api.auth import jwt_auth_required from fedn.network.api.v1.shared import (api_version, get_post_data_to_kwargs, get_typed_list_headers, get_use_typing, mdb) @@ -12,6 +13,7 @@ @bp.route("/", methods=["GET"]) +@jwt_auth_required(role="admin") def get_statuses(): """Get statuses Retrieves a list of statuses based on the provided parameters. @@ -144,6 +146,7 @@ def get_statuses(): @bp.route("/list", methods=["POST"]) +@jwt_auth_required(role="admin") def list_statuses(): """Get statuses Retrieves a list of statuses based on the provided parameters. @@ -246,6 +249,7 @@ def list_statuses(): @bp.route("/count", methods=["GET"]) +@jwt_auth_required(role="admin") def get_statuses_count(): """Statuses count Retrieves the count of statuses based on the provided parameters. @@ -307,6 +311,7 @@ def get_statuses_count(): @bp.route("/count", methods=["POST"]) +@jwt_auth_required(role="admin") def statuses_count(): """Statuses count Retrieves the count of statuses based on the provided parameters. @@ -368,6 +373,7 @@ def statuses_count(): @bp.route("/", methods=["GET"]) +@jwt_auth_required(role="admin") def get_status(id: str): """Get status Retrieves a status based on the provided id. diff --git a/fedn/fedn/network/api/v1/validation_routes.py b/fedn/fedn/network/api/v1/validation_routes.py index 874154dbc..96fbac55c 100644 --- a/fedn/fedn/network/api/v1/validation_routes.py +++ b/fedn/fedn/network/api/v1/validation_routes.py @@ -1,5 +1,6 @@ from flask import Blueprint, jsonify, request +from fedn.network.api.auth import jwt_auth_required from fedn.network.api.v1.shared import (api_version, get_post_data_to_kwargs, get_typed_list_headers, get_use_typing, mdb) @@ -13,6 +14,7 @@ @bp.route("/", methods=["GET"]) +@jwt_auth_required(role="admin") def get_validations(): """Get validations Retrieves a list of validations based on the provided parameters. @@ -152,6 +154,7 @@ def get_validations(): @bp.route("/list", methods=["POST"]) +@jwt_auth_required(role="admin") def list_validations(): """Get validations Retrieves a list of validations based on the provided parameters. @@ -257,6 +260,7 @@ def list_validations(): @bp.route("/count", methods=["GET"]) +@jwt_auth_required(role="admin") def get_validations_count(): """Validations count Retrieves the count of validations based on the provided parameters. @@ -322,6 +326,7 @@ def get_validations_count(): @bp.route("/count", methods=["POST"]) +@jwt_auth_required(role="admin") def validations_count(): """Validations count Retrieves the count of validations based on the provided parameters. @@ -386,6 +391,7 @@ def validations_count(): @bp.route("/", methods=["GET"]) +@jwt_auth_required(role="admin") def get_validation(id: str): """Get validation Retrieves a validation based on the provided id. diff --git a/fedn/fedn/network/clients/client.py b/fedn/fedn/network/clients/client.py index a6408d66c..7fc0af1b2 100644 --- a/fedn/fedn/network/clients/client.py +++ b/fedn/fedn/network/clients/client.py @@ -21,6 +21,7 @@ import fedn.network.grpc.fedn_pb2 as fedn import fedn.network.grpc.fedn_pb2_grpc as rpc +from fedn.common.config import FEDN_AUTH_SCHEME from fedn.common.log_config import (logger, set_log_level_from_string, set_log_stream) from fedn.network.clients.connect import ConnectorClient, Status @@ -112,7 +113,7 @@ def _assign(self): while True: status, response = self.connector.assign() if status == Status.TryAgain: - logger.info(response) + logger.info("Assignment request failed. Retrying in 5 seconds.") time.sleep(5) continue if status == Status.Assigned: @@ -125,7 +126,10 @@ def _assign(self): logger.critical(response) sys.exit("Exiting: UnMatchedConfig") time.sleep(5) - + # If token was refreshed, update the config + if self.config['token'] != self.connector.token: + self.config['token'] = self.connector.token + self._add_grpc_metadata('authorization', f"{FEDN_AUTH_SCHEME} {self.config['token']}") logger.info("Assignment successfully received.") logger.info("Received combiner configuration: {}".format(client_config)) return client_config @@ -178,7 +182,9 @@ def _connect(self, client_config): host = client_config['host'] # Add host to gRPC metadata self._add_grpc_metadata('grpc-server', host) - logger.info("Client using metadata: {}.".format(self.metadata)) + if self.config['token']: + self._add_grpc_metadata('authorization', f"{FEDN_AUTH_SCHEME} {self.config['token']}") + logger.debug("Client using metadata: {}.".format(self.metadata)) port = client_config['port'] secure = False if client_config['fqdn'] is not None: @@ -370,21 +376,24 @@ def get_model_from_combiner(self, id, timeout=20): request.sender.name = self.name request.sender.role = fedn.WORKER - for part in self.modelStub.Download(request, metadata=self.metadata): - - if part.status == fedn.ModelStatus.IN_PROGRESS: - data.write(part.data) + try: + for part in self.modelStub.Download(request, metadata=self.metadata): - if part.status == fedn.ModelStatus.OK: - return data + if part.status == fedn.ModelStatus.IN_PROGRESS: + data.write(part.data) - if part.status == fedn.ModelStatus.FAILED: - return None + if part.status == fedn.ModelStatus.OK: + return data - if part.status == fedn.ModelStatus.UNKNOWN: - if time.time() - time_start >= timeout: + if part.status == fedn.ModelStatus.FAILED: return None - continue + + if part.status == fedn.ModelStatus.UNKNOWN: + if time.time() - time_start >= timeout: + return None + continue + except grpc.RpcError as e: + logger.critical(f"GRPC: An error occurred during model download: {e}") return data @@ -409,7 +418,10 @@ def send_model_to_combiner(self, model, id): bt.seek(0, 0) - result = self.modelStub.Upload(upload_request_generator(bt, id), metadata=self.metadata) + try: + result = self.modelStub.Upload(upload_request_generator(bt, id), metadata=self.metadata) + except grpc.RpcError as e: + logger.critical(f"GRPC: An error occurred during model upload: {e}") return result @@ -448,16 +460,26 @@ def _listen_to_task_stream(self): # Handle gRPC errors status_code = e.code() if status_code == grpc.StatusCode.UNAVAILABLE: - logger.warning("GRPC server unavailable during model update request stream. Retrying.") + logger.warning("GRPC TaskStream: server unavailable during model update request stream. Retrying.") # Retry after a delay time.sleep(5) + if status_code == grpc.StatusCode.UNAUTHENTICATED: + details = e.details() + if details == 'Token expired': + logger.warning("GRPC TaskStream: Token expired. Reconnecting.") + self.detach() + + if status_code == grpc.StatusCode.CANCELLED: + # Expected if the client is detached + logger.critical("GRPC TaskStream: Client detached from combiner. Atempting to reconnect.") + else: # Log the error and continue - logger.error(f"An error occurred during model update request stream: {e}") + logger.error(f"GRPC TaskStream: An error occurred during model update request stream: {e}") except Exception as ex: # Handle other exceptions - logger.error(f"An error occurred during model update request stream: {ex}") + logger.error(f"GRPC TaskStream: An error occurred during model update request stream: {ex}") # Detach if not attached if not self._attached: @@ -654,12 +676,14 @@ def process_request(self): self.inbox.task_done() except queue.Empty: pass + except grpc.RpcError as e: + logger.critical(f"GRPC process_request: An error occurred during process request: {e}") def _handle_combiner_failure(self): """ Register failed combiner connection.""" self._missed_heartbeat += 1 if self._missed_heartbeat > self.config['reconnect_after_missed_heartbeat']: - self.detach()() + self.detach() def _send_heartbeat(self, update_frequency=2.0): """Send a heartbeat to the combiner. @@ -677,8 +701,13 @@ def _send_heartbeat(self, update_frequency=2.0): self._missed_heartbeat = 0 except grpc.RpcError as e: status_code = e.code() - logger.warning("Client heartbeat: GRPC error, {}. Retrying.".format( - status_code.name)) + if status_code == grpc.StatusCode.UNAVAILABLE: + logger.warning("GRPC hearbeat: server unavailable during send heartbeat. Retrying.") + if status_code == grpc.StatusCode.UNAUTHENTICATED: + details = e.details() + if details == 'Token expired': + logger.warning("GRPC hearbeat: Token expired. Reconnecting.") + self.detach() logger.debug(e) self._handle_combiner_failure() @@ -714,7 +743,16 @@ def _send_status(self, msg, log_level=fedn.Status.INFO, type=None, request=None, self.logs.append( "{} {} LOG LEVEL {} MESSAGE {}".format(str(datetime.now()), status.sender.name, status.log_level, status.status)) - _ = self.connectorStub.SendStatus(status, metadata=self.metadata) + try: + _ = self.connectorStub.SendStatus(status, metadata=self.metadata) + except grpc.RpcError as e: + status_code = e.code() + if status_code == grpc.StatusCode.UNAVAILABLE: + logger.warning("GRPC SendStatus: server unavailable during send status.") + if status_code == grpc.StatusCode.UNAUTHENTICATED: + details = e.details() + if details == 'Token expired': + logger.warning("GRPC SendStatus: Token expired.") def run(self): """ Run the client. """ diff --git a/fedn/fedn/network/clients/connect.py b/fedn/fedn/network/clients/connect.py index 478844d26..2e9345ebb 100644 --- a/fedn/fedn/network/clients/connect.py +++ b/fedn/fedn/network/clients/connect.py @@ -8,6 +8,9 @@ import requests +from fedn.common.config import (FEDN_AUTH_REFRESH_TOKEN, + FEDN_AUTH_REFRESH_TOKEN_URI, FEDN_AUTH_SCHEME, + FEDN_CUSTOM_URL_PREFIX) from fedn.common.log_config import logger @@ -77,12 +80,11 @@ def assign(self): try: retval = None payload = {'client_id': self.name, 'preferred_combiner': self.preferred_combiner} - - retval = requests.post(self.connect_string + '/add_client', + retval = requests.post(self.connect_string + FEDN_CUSTOM_URL_PREFIX + '/add_client', json=payload, verify=self.verify, allow_redirects=True, - headers={'Authorization': 'Token {}'.format(self.token)}) + headers={'Authorization': f"{FEDN_AUTH_SCHEME} {self.token}"}) except Exception as e: print('***** {}'.format(e), flush=True) return Status.Unassigned, {} @@ -93,6 +95,16 @@ def assign(self): return Status.UnMatchedConfig, reason if retval.status_code == 401: + if 'message' in retval.json(): + reason = retval.json()['message'] + logger.warning(reason) + if reason == 'Token expired': + status_code = self.refresh_token() + if status_code >= 200 and status_code < 204: + logger.info("Token refreshed.") + return Status.TryAgain, reason + else: + return Status.UnAuthorized, "Could not refresh token" reason = "Unauthorized connection to reducer, make sure the correct token is set" return Status.UnAuthorized, reason @@ -115,3 +127,21 @@ def assign(self): return Status.Assigned, retval.json() return Status.Unassigned, None + + def refresh_token(self): + """ + Refresh client token. + + :return: Tuple with assingment status, combiner connection information if sucessful, else None. + :rtype: tuple(:class:`fedn.network.clients.connect.Status`, str) + """ + if not FEDN_AUTH_REFRESH_TOKEN_URI or not FEDN_AUTH_REFRESH_TOKEN: + logger.error("No refresh token URI/Token set, cannot refresh token.") + return 401 + + payload = requests.post(FEDN_AUTH_REFRESH_TOKEN_URI, + verify=self.verify, + allow_redirects=True, + json={'refresh': FEDN_AUTH_REFRESH_TOKEN}) + self.token = payload.json()['access'] + return payload.status_code diff --git a/fedn/fedn/network/clients/package.py b/fedn/fedn/network/clients/package.py index d56296de8..cc98eae94 100644 --- a/fedn/fedn/network/clients/package.py +++ b/fedn/fedn/network/clients/package.py @@ -9,6 +9,7 @@ import requests import yaml +from fedn.common.config import FEDN_AUTH_SCHEME, FEDN_CUSTOM_URL_PREFIX from fedn.common.log_config import logger from fedn.utils.checksum import sha from fedn.utils.dispatcher import Dispatcher @@ -52,13 +53,13 @@ def download(self, host, port, token, force_ssl=False, secure=False, name=None): else: scheme = "http" if port: - path = f"{scheme}://{host}:{port}/download_package" + path = f"{scheme}://{host}:{port}{FEDN_CUSTOM_URL_PREFIX}/download_package" else: - path = f"{scheme}://{host}/download_package" + path = f"{scheme}://{host}{FEDN_CUSTOM_URL_PREFIX}/download_package" if name: path = path + "?name={}".format(name) - with requests.get(path, stream=True, verify=False, headers={'Authorization': 'Token {}'.format(token)}) as r: + with requests.get(path, stream=True, verify=False, headers={'Authorization': f'{FEDN_AUTH_SCHEME} {token}'}) as r: if 200 <= r.status_code < 204: params = cgi.parse_header( @@ -73,13 +74,13 @@ def download(self, host, port, token, force_ssl=False, secure=False, name=None): for chunk in r.iter_content(chunk_size=8192): f.write(chunk) if port: - path = f"{scheme}://{host}:{port}/get_package_checksum" + path = f"{scheme}://{host}:{port}{FEDN_CUSTOM_URL_PREFIX}/get_package_checksum" else: - path = f"{scheme}://{host}/get_package_checksum" + path = f"{scheme}://{host}{FEDN_CUSTOM_URL_PREFIX}/get_package_checksum" if name: path = path + "?name={}".format(name) - with requests.get(path, verify=False, headers={'Authorization': 'Token {}'.format(token)}) as r: + with requests.get(path, verify=False, headers={'Authorization': f'{FEDN_AUTH_SCHEME} {token}'}) as r: if 200 <= r.status_code < 204: data = r.json() diff --git a/fedn/fedn/network/combiner/connect.py b/fedn/fedn/network/combiner/connect.py index 4c1c94266..7dc388261 100644 --- a/fedn/fedn/network/combiner/connect.py +++ b/fedn/fedn/network/combiner/connect.py @@ -5,6 +5,7 @@ # # import enum +import os import requests @@ -72,10 +73,14 @@ def __init__(self, host, port, myhost, fqdn, myport, token, name, secure=False, self.myhost = myhost self.myport = myport self.token = token + self.token_scheme = os.environ.get('FEDN_AUTH_SCHEME', 'Bearer') self.name = name self.secure = secure self.verify = verify + if not self.token: + self.token = os.environ.get('FEDN_AUTH_TOKEN', None) + # for https we assume a an ingress handles permanent redirect (308) self.prefix = "http://" if port: @@ -101,10 +106,11 @@ def announce(self): "port": self.myport, "secure_grpc": self.secure } + url_prefix = os.environ.get('FEDN_CUSTOM_URL_PREFIX', '') try: - retval = requests.post(self.connect_string + '/add_combiner', json=payload, + retval = requests.post(self.connect_string + url_prefix + '/add_combiner', json=payload, verify=self.verify, - headers={'Authorization': 'Token {}'.format(self.token)}) + headers={'Authorization': f'{self.token_scheme} {self.token}'}) except Exception: return Status.Unassigned, {} diff --git a/fedn/fedn/network/grpc/auth.py b/fedn/fedn/network/grpc/auth.py new file mode 100644 index 000000000..d879cd812 --- /dev/null +++ b/fedn/fedn/network/grpc/auth.py @@ -0,0 +1,95 @@ +import grpc +import jwt + +from fedn.common.config import FEDN_AUTH_SCHEME, FEDN_JWT_ALGORITHM, SECRET_KEY +from fedn.common.log_config import logger +from fedn.network.api.auth import check_custom_claims + +ENDPOINT_ROLES_MAPPING = { + '/fedn.Combiner/TaskStream': ['client'], + '/fedn.Combiner/SendModelUpdate': ['client'], + '/fedn.Combiner/SendModelValidation': ['client'], + '/fedn.Connector/SendHeartbeat': ['client'], + '/fedn.Connector/SendStatus': ['client'], + '/fedn.ModelService/Download': ['client'], + '/fedn.ModelService/Upload': ['client'], + '/fedn.Control/Start': ['controller'], + '/fedn.Control/Stop': ['controller'], + '/fedn.Control/FlushAggregationQueue': ['controller'], + '/fedn.Control/SetAggregator': ['controller'], +} + +ENDPOINT_WHITELIST = [ + '/fedn.Connector/AcceptingClients', + '/fedn.Connector/ListActiveClients', + '/fedn.Control/Start', + '/fedn.Control/Stop', + '/fedn.Control/FlushAggregationQueue', + '/fedn.Control/SetAggregator', +] + +USER_AGENT_WHITELIST = [ + 'grpc_health_probe' +] + + +def check_role_claims(payload, endpoint): + user_role = payload.get('role', '') + + # Perform endpoint-specific RBAC check + allowed_roles = ENDPOINT_ROLES_MAPPING.get(endpoint) + if allowed_roles and user_role not in allowed_roles: + return False + return True + + +def _unary_unary_rpc_terminator(code, details): + def terminate(ignored_request, context): + context.abort(code, details) + + return grpc.unary_unary_rpc_method_handler(terminate) + + +class JWTInterceptor(grpc.ServerInterceptor): + def __init__(self): + pass + + def intercept_service(self, continuation, handler_call_details): + # Pass if no secret key is set + if not SECRET_KEY: + return continuation(handler_call_details) + metadata = dict(handler_call_details.invocation_metadata) + # Pass whitelisted methods + if handler_call_details.method in ENDPOINT_WHITELIST: + return continuation(handler_call_details) + # Pass if the request comes from whitelisted user agents + user_agent = metadata.get('user-agent').split(' ')[0] + if user_agent in USER_AGENT_WHITELIST: + return continuation(handler_call_details) + + token = metadata.get('authorization') + if token is None: + return _unary_unary_rpc_terminator(grpc.StatusCode.UNAUTHENTICATED, 'Token is missing') + + if not token.startswith(FEDN_AUTH_SCHEME): + return _unary_unary_rpc_terminator(grpc.StatusCode.UNAUTHENTICATED, f'Invalid token scheme, expected {FEDN_AUTH_SCHEME}') + + token = token.split(' ')[1] + + try: + payload = jwt.decode(token, SECRET_KEY, algorithms=[FEDN_JWT_ALGORITHM]) + + if not check_role_claims(payload, handler_call_details.method): + return _unary_unary_rpc_terminator(grpc.StatusCode.PERMISSION_DENIED, 'Insufficient permissions') + + if not check_custom_claims(payload): + return _unary_unary_rpc_terminator(grpc.StatusCode.PERMISSION_DENIED, 'Insufficient permissions') + + return continuation(handler_call_details) + except jwt.InvalidTokenError: + return _unary_unary_rpc_terminator(grpc.StatusCode.UNAUTHENTICATED, 'Invalid token') + except jwt.ExpiredSignatureError: + return _unary_unary_rpc_terminator(grpc.StatusCode.UNAUTHENTICATED, 'Token expired') + except Exception as e: + logger.error(str(e)) + return _unary_unary_rpc_terminator(grpc.StatusCode.UNKNOWN, str(e)) diff --git a/fedn/fedn/network/grpc/server.py b/fedn/fedn/network/grpc/server.py index 59ed6b1ba..916b4756e 100644 --- a/fedn/fedn/network/grpc/server.py +++ b/fedn/fedn/network/grpc/server.py @@ -6,6 +6,7 @@ import fedn.network.grpc.fedn_pb2_grpc as rpc from fedn.common.log_config import (logger, set_log_level_from_string, set_log_stream) +from fedn.network.grpc.auth import JWTInterceptor class Server: @@ -16,7 +17,7 @@ def __init__(self, servicer, modelservicer, config): set_log_level_from_string(config.get('verbosity', "INFO")) set_log_stream(config.get('logfile', None)) - self.server = grpc.server(futures.ThreadPoolExecutor(max_workers=350)) + self.server = grpc.server(futures.ThreadPoolExecutor(max_workers=350), interceptors=[JWTInterceptor()]) self.certificate = None self.health_servicer = health.HealthServicer() diff --git a/fedn/fedn/tests/__init__.py b/fedn/fedn/tests/__init__.py deleted file mode 100644 index e69de29bb..000000000