Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/SK-871 | Enable updating models via the API #632

Merged
merged 2 commits into from
Jun 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
180 changes: 151 additions & 29 deletions fedn/network/api/v1/model_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
from flask import Blueprint, jsonify, request, send_file

from fedn.network.api.auth import jwt_auth_required
from fedn.network.api.v1.shared import api_version, get_limit, get_post_data_to_kwargs, get_reverse, get_typed_list_headers, mdb
from fedn.network.api.shared import modelstorage_config
from fedn.network.api.v1.shared import api_version, get_limit, get_post_data_to_kwargs, get_reverse, get_typed_list_headers, mdb
from fedn.network.storage.s3.base import RepositoryBase
from fedn.network.storage.s3.miniorepository import MINIORepository
from fedn.network.storage.statestore.stores.model_store import ModelStore
Expand Down Expand Up @@ -117,8 +117,8 @@ def get_models():
response = {"count": models["count"], "result": result}

return jsonify(response), 200
except Exception as e:
return jsonify({"message": str(e)}), 500
except Exception:
return jsonify({"message": "An unexpected error occurred"}), 500


@bp.route("/list", methods=["POST"])
Expand Down Expand Up @@ -202,8 +202,8 @@ def list_models():
response = {"count": models["count"], "result": result}

return jsonify(response), 200
except Exception as e:
return jsonify({"message": str(e)}), 500
except Exception:
return jsonify({"message": "An unexpected error occurred"}), 500


@bp.route("/count", methods=["GET"])
Expand Down Expand Up @@ -250,8 +250,8 @@ def get_models_count():
count = model_store.count(**kwargs)
response = count
return jsonify(response), 200
except Exception as e:
return jsonify({"message": str(e)}), 500
except Exception:
return jsonify({"message": "An unexpected error occurred"}), 500


@bp.route("/count", methods=["POST"])
Expand Down Expand Up @@ -302,8 +302,8 @@ def models_count():
count = model_store.count(**kwargs)
response = count
return jsonify(response), 200
except Exception as e:
return jsonify({"message": str(e)}), 500
except Exception:
return jsonify({"message": "An unexpected error occurred"}), 500


@bp.route("/<string:id>", methods=["GET"])
Expand Down Expand Up @@ -346,10 +346,132 @@ def get_model(id: str):
response = model

return jsonify(response), 200
except EntityNotFound as e:
return jsonify({"message": str(e)}), 404
except Exception as e:
return jsonify({"message": str(e)}), 500
except EntityNotFound:
return jsonify({"message": f"Entity with id: {id} not found"}), 404
except Exception:
return jsonify({"message": "An unexpected error occurred"}), 500


@bp.route("/<string:id>", methods=["PATCH"])
@jwt_auth_required(role="admin")
def patch_model(id: str):
"""Patch model
Updates a model based on the provided id. Only the fields that are present in the request will be updated.
---
tags:
- Models
parameters:
- name: id
in: path
required: true
type: string
description: The id or model property of the model
- name: model
in: body
required: true
type: object
description: The model data to update
responses:
200:
description: The updated model
schema:
$ref: '#/definitions/Model'
404:
description: The model was not found
schema:
type: object
properties:
message:
type: string
500:
description: An error occurred
schema:
type: object
properties:
message:
type: string
"""
try:
model = model_store.get(id, use_typing=False)

data = request.get_json()
_id = model["id"]

# Update the model with the new data
# Only update the fields that are present in the request
for key, value in data.items():
if key in ["_id", "model"]:
continue
model[key] = value

success = model_store.update(_id, model)

if success:
response = model
return jsonify(response), 200

return jsonify({"message": "Failed to update model"}), 500
except EntityNotFound:
return jsonify({"message": f"Entity with id: {id} not found"}), 404
except Exception:
return jsonify({"message": "An unexpected error occurred"}), 500


@bp.route("/<string:id>", methods=["PUT"])
@jwt_auth_required(role="admin")
def put_model(id: str):
"""Patch model
Updates a model based on the provided id. All fields will be updated with the new data.
---
tags:
- Models
parameters:
- name: id
in: path
required: true
type: string
description: The id or model property of the model
- name: model
in: body
required: true
type: object
description: The model data to update
responses:
200:
description: The updated model
schema:
$ref: '#/definitions/Model'
404:
description: The model was not found
schema:
type: object
properties:
message:
type: string
500:
description: An error occurred
schema:
type: object
properties:
message:
type: string
"""
try:
model = model_store.get(id, use_typing=False)
data = request.get_json()
_id = model["id"]

success = model_store.update(_id, data)

if success:
response = model
return jsonify(response), 200

return jsonify({"message": "Failed to update model"}), 500
except EntityNotFound:
return jsonify({"message": f"Entity with id: {id} not found"}), 404
except Exception:
return jsonify({"message": "An unexpected error occurred"}), 500


@bp.route("/<string:id>/descendants", methods=["GET"])
Expand Down Expand Up @@ -400,10 +522,10 @@ def get_descendants(id: str):
response = descendants

return jsonify(response), 200
except EntityNotFound as e:
return jsonify({"message": str(e)}), 404
except Exception as e:
return jsonify({"message": str(e)}), 500
except EntityNotFound:
return jsonify({"message": f"Entity with id: {id} not found"}), 404
except Exception:
return jsonify({"message": "An unexpected error occurred"}), 500


@bp.route("/<string:id>/ancestors", methods=["GET"])
Expand Down Expand Up @@ -469,10 +591,10 @@ def get_ancestors(id: str):
response = ancestors

return jsonify(response), 200
except EntityNotFound as e:
return jsonify({"message": str(e)}), 404
except Exception as e:
return jsonify({"message": str(e)}), 500
except EntityNotFound:
return jsonify({"message": f"Entity with id: {id} not found"}), 404
except Exception:
return jsonify({"message": "An unexpected error occurred"}), 500


@bp.route("/<string:id>/download", methods=["GET"])
Expand Down Expand Up @@ -517,10 +639,10 @@ def download(id: str):
return send_file(file, as_attachment=True, download_name=model_id)
else:
return jsonify({"message": "No model storage configured"}), 500
except EntityNotFound as e:
return jsonify({"message": str(e)}), 404
except Exception as e:
return jsonify({"message": str(e)}), 500
except EntityNotFound:
return jsonify({"message": f"Entity with id: {id} not found"}), 404
except Exception:
return jsonify({"message": "An unexpected error occurred"}), 500


@bp.route("/<string:id>/parameters", methods=["GET"])
Expand Down Expand Up @@ -581,7 +703,7 @@ def get_parameters(id: str):
return jsonify(array=weights), 200
else:
return jsonify({"message": "No model storage configured"}), 500
except EntityNotFound as e:
return jsonify({"message": str(e)}), 404
except Exception as e:
return jsonify({"message": str(e)}), 500
except EntityNotFound:
return jsonify({"message": f"Entity with id: {id} not found"}), 404
except Exception:
return jsonify({"message": "An unexpected error occurred"}), 500
20 changes: 18 additions & 2 deletions fedn/network/storage/statestore/stores/model_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,24 @@ def get(self, id: str, use_typing: bool = False) -> Model:

return Model.from_dict(document) if use_typing else from_document(document)

def update(self, id: str, item: Model) -> bool:
raise NotImplementedError("Update not implemented for ModelStore")
def _validate(self, item: Model) -> Tuple[bool, str]:
if "model" not in item or not item["model"]:
return False, "Model is required"

return True, ""

def _complement(self, item: Model) -> Model:
if "key" not in item or item["key"] is None:
item["key"] = "models"

def update(self, id: str, item: Model) -> Tuple[bool, Any]:
valid, message = self._validate(item)
if not valid:
return False, message

self._complement(item)

return super().update(id, item)

def add(self, item: Model)-> Tuple[bool, Any]:
raise NotImplementedError("Add not implemented for ModelStore")
Expand Down
12 changes: 10 additions & 2 deletions fedn/network/storage/statestore/stores/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,16 @@ def get(self, id: str, use_typing: bool = False) -> T:

return from_document(document) if not use_typing else document

def update(self, id: str, item: T) -> bool:
pass
def update(self, id: str, item: T) -> Tuple[bool, Any]:
try:
result = self.database[self.collection].update_one({"_id": ObjectId(id)}, {"$set": item})
if result.modified_count == 1:
document = self.database[self.collection].find_one({"_id": ObjectId(id)})
return True, from_document(document)
else:
return False, "Entity not found"
except Exception as e:
return False, str(e)

def add(self, item: T) -> Tuple[bool, Any]:
try:
Expand Down
Loading