From 2c7362d24d7634dfaef9375424cac97b3764ea80 Mon Sep 17 00:00:00 2001 From: Niklas Date: Thu, 20 Jun 2024 11:09:33 +0200 Subject: [PATCH] added PUT and PUSH to /sessions API + changed validation of sessions to handle old typ as well --- fedn/network/api/v1/model_routes.py | 10 +- fedn/network/api/v1/session_routes.py | 121 ++++++++++++++++++ .../statestore/stores/session_store.py | 23 +++- 3 files changed, 142 insertions(+), 12 deletions(-) diff --git a/fedn/network/api/v1/model_routes.py b/fedn/network/api/v1/model_routes.py index 5b2ebf925..aaea9d733 100644 --- a/fedn/network/api/v1/model_routes.py +++ b/fedn/network/api/v1/model_routes.py @@ -404,13 +404,13 @@ def patch_model(id: str): continue model[key] = value - success = model_store.update(_id, model) + success, message = model_store.update(_id, model) if success: response = model return jsonify(response), 200 - return jsonify({"message": "Failed to update model"}), 500 + return jsonify({"message": f"Failed to update model: {message}"}), 500 except EntityNotFound: return jsonify({"message": f"Entity with id: {id} not found"}), 404 except Exception: @@ -420,7 +420,7 @@ def patch_model(id: str): @bp.route("/", methods=["PUT"]) @jwt_auth_required(role="admin") def put_model(id: str): - """Patch model + """Put model Updates a model based on the provided id. All fields will be updated with the new data. --- tags: @@ -461,13 +461,13 @@ def put_model(id: str): data = request.get_json() _id = model["id"] - success = model_store.update(_id, data) + success, message = model_store.update(_id, data) if success: response = model return jsonify(response), 200 - return jsonify({"message": "Failed to update model"}), 500 + return jsonify({"message": f"Failed to update model: {message}"}), 500 except EntityNotFound: return jsonify({"message": f"Entity with id: {id} not found"}), 404 except Exception: diff --git a/fedn/network/api/v1/session_routes.py b/fedn/network/api/v1/session_routes.py index f3f81fac0..97d4904af 100644 --- a/fedn/network/api/v1/session_routes.py +++ b/fedn/network/api/v1/session_routes.py @@ -389,3 +389,124 @@ def start_session(): return jsonify({"message": "Session started"}), 200 except Exception: return jsonify({"message": "An unexpected error occurred"}), 500 + +@bp.route("/", methods=["PATCH"]) +@jwt_auth_required(role="admin") +def patch_session(id: str): + """Patch session + Updates a session based on the provided id. Only the fields that are present in the request will be updated. + --- + tags: + - Sessions + parameters: + - name: id + in: path + required: true + type: string + description: The id or session property of the session + - name: session + in: body + required: true + type: object + description: The session data to update + responses: + 200: + description: The updated session + schema: + $ref: '#/definitions/Session' + 404: + description: The session was not found + schema: + type: object + properties: + message: + type: string + 500: + description: An error occurred + schema: + type: object + properties: + message: + type: string + """ + try: + session = session_store.get(id, use_typing=False) + + data = request.get_json() + _id = session["id"] + + # Update the session with the new data + # Only update the fields that are present in the request + for key, value in data.items(): + if key in ["_id", "session_id"]: + continue + session[key] = value + + success, message = session_store.update(_id, session) + + if success: + response = session + return jsonify(response), 200 + + return jsonify({"message": f"Failed to update session: {message}"}), 500 + except EntityNotFound: + return jsonify({"message": f"Entity with id: {id} not found"}), 404 + except Exception: + return jsonify({"message": "An unexpected error occurred"}), 500 + + +@bp.route("/", methods=["PUT"]) +@jwt_auth_required(role="admin") +def put_session(id: str): + """Put session + Updates a session based on the provided id. All fields will be updated with the new data. + --- + tags: + - Sessions + parameters: + - name: id + in: path + required: true + type: string + description: The id or session property of the session + - name: session + in: body + required: true + type: object + description: The session data to update + responses: + 200: + description: The updated session + schema: + $ref: '#/definitions/Session' + 404: + description: The session was not found + schema: + type: object + properties: + message: + type: string + 500: + description: An error occurred + schema: + type: object + properties: + message: + type: string + """ + try: + session = session_store.get(id, use_typing=False) + data = request.get_json() + _id = session["id"] + + success, message = session_store.update(_id, data) + + if success: + response = session + return jsonify(response), 200 + + return jsonify({"message": f"Failed to update session: {message}"}), 500 + except EntityNotFound: + return jsonify({"message": f"Entity with id: {id} not found"}), 404 + except Exception: + return jsonify({"message": "An unexpected error occurred"}), 500 \ No newline at end of file diff --git a/fedn/network/storage/statestore/stores/session_store.py b/fedn/network/storage/statestore/stores/session_store.py index b25a34319..f2675c912 100644 --- a/fedn/network/storage/statestore/stores/session_store.py +++ b/fedn/network/storage/statestore/stores/session_store.py @@ -38,7 +38,7 @@ def _validate_session_config(self, session_config: dict) -> Tuple[bool, str]: if "round_timeout" not in session_config: return False, "session_config.round_timeout is required" - if not isinstance(session_config["round_timeout"], int): + if not isinstance(session_config["round_timeout"], (int, float)): return False, "session_config.round_timeout must be an integer" if "buffer_size" not in session_config: @@ -82,10 +82,15 @@ def _validate_session_config(self, session_config: dict) -> Tuple[bool, str]: def _validate(self, item: Session) -> Tuple[bool, str]: if "session_config" not in item or item["session_config"] is None: return False, "session_config is required" - elif not isinstance(item["session_config"], dict): - return False, "session_config must be a dict" - session_config = item["session_config"] + session_config = None + + if isinstance(item["session_config"], dict): + session_config = item["session_config"] + elif isinstance(item["session_config"], list): + session_config = item["session_config"][0] + else: + return False, "session_config must be a dict" return self._validate_session_config(session_config) @@ -117,10 +122,14 @@ def get(self, id: str, use_typing: bool = False) -> Session: return Session.from_dict(document) if use_typing else from_document(document) - def update(self, id: str, item: Session) -> bool: - raise NotImplementedError("Update not implemented for SessionStore") + def update(self, id: str, item: Session) -> Tuple[bool, Any]: + valid, message = self._validate(item) + if not valid: + return False, message + + return super().update(id, item) - def add(self, item: Session)-> Tuple[bool, Any]: + def add(self, item: Session) -> Tuple[bool, Any]: """Add an entity param item: The entity to add type: Session