diff --git a/fedn/fedn/network/api/interface.py b/fedn/fedn/network/api/interface.py index fd147d9f5..4370c0428 100644 --- a/fedn/fedn/network/api/interface.py +++ b/fedn/fedn/network/api/interface.py @@ -768,6 +768,24 @@ def get_models(self, session_id: str = None, limit: str = None, skip: str = None return jsonify(result) + def get_model(self, model_id: str): + result = self.statestore.get_model(model_id) + + if result is None: + return ( + jsonify({"success": False, "message": "No model found."}), + 404, + ) + + payload = { + "committed_at": result["committed_at"], + "parent_model": result["parent_model"], + "model": result["model"], + "session_id": result["session_id"], + } + + return jsonify(payload) + def get_model_trail(self): """Get the model trail for a given session. @@ -799,12 +817,12 @@ def get_model_ancestors(self, model_id: str, limit: str = None): {"success": False, "message": "No model id provided."} ) - limit = int(limit) if limit is not None else 10 # if limit is None, default to 10 + limit: int = int(limit) if limit is not None else 10 # if limit is None, default to 10 response = self.statestore.get_model_ancestors(model_id, limit) if response: - arr = [] + arr: list = [] for element in response: obj = { @@ -823,6 +841,47 @@ def get_model_ancestors(self, model_id: str, limit: str = None): {"success": False, "message": "No model ancestors available."} ) + def get_model_descendants(self, model_id: str, limit: str = None): + """Get the model descendants for a given model. + + :param model_id: The model id to get the model descendants for. + :type model_id: str + :param limit: The number of descendants to return. + :type limit: str + :return: The model descendants for the given model as a json response. + :rtype: :class:`flask.Response` + """ + + if model_id is None: + return jsonify( + {"success": False, "message": "No model id provided."} + ) + + limit: int = int(limit) if limit is not None else 10 + + response: list = self.statestore.get_model_descendants(model_id, limit) + + if response: + + arr: list = [] + + for element in response: + obj = { + "model": element["model"], + "committed_at": element["committed_at"], + "session_id": element["session_id"], + "parent_model": element["parent_model"], + } + arr.append(obj) + + result = {"result": arr} + + return jsonify(result) + else: + return jsonify( + {"success": False, "message": "No model descendants available."} + ) + def get_all_rounds(self): """Get all rounds. diff --git a/fedn/fedn/network/api/server.py b/fedn/fedn/network/api/server.py index 304e7c57c..0f2cc8177 100644 --- a/fedn/fedn/network/api/server.py +++ b/fedn/fedn/network/api/server.py @@ -34,8 +34,9 @@ def get_model_ancestors(): """Get the ancestors of a model. param: model: The model id to get the ancestors for. type: model: str + param: limit: The maximum number of ancestors to return. type: limit: int - return: The a list of model objects that the model derives from. + return: A list of model objects that the model derives from. rtype: json """ model = request.args.get("model", None) @@ -44,6 +45,22 @@ def get_model_ancestors(): return api.get_model_ancestors(model, limit) +@app.route("/get_model_descendants", methods=["GET"]) +def get_model_descendants(): + """Get the ancestors of a model. + param: model: The model id to get the child for. + type: model: str + param: limit: The maximum number of descendants to return. + type: limit: int + return: A list of model objects that are descendents of the provided model id. + rtype: json + """ + model = request.args.get("model", None) + limit = request.args.get("limit", None) + + return api.get_model_descendants(model, limit) + + @app.route("/list_models", methods=["GET"]) def list_models(): """Get models from the statestore. @@ -65,6 +82,21 @@ def list_models(): return api.get_models(session_id, limit, skip, include_active) +@app.route("/get_model", methods=["GET"]) +def get_model(): + """Get a model from the statestore. + param: model: The model id to get. + type: model: str + return: The model as a json object. + rtype: json + """ + model = request.args.get("model", None) + if model is None: + return jsonify({"success": False, "message": "Missing model id."}), 400 + + return api.get_model(model) + + @app.route("/delete_model_trail", methods=["GET", "POST"]) def delete_model_trail(): """Delete the model trail for a given session. diff --git a/fedn/fedn/network/storage/statestore/mongostatestore.py b/fedn/fedn/network/storage/statestore/mongostatestore.py index b735d7635..5e229707e 100644 --- a/fedn/fedn/network/storage/statestore/mongostatestore.py +++ b/fedn/fedn/network/storage/statestore/mongostatestore.py @@ -572,6 +572,45 @@ def get_model_ancestors(self, model_id: str, limit: int): return result + def get_model_descendants(self, model_id: str, limit: int): + """Get the model descendants. + + :param model_id: The model id. + :type model_id: str + :param limit: The maximum number of descendants to return. + :type limit: int + :return: List of model descendants. + :rtype: list + """ + + model: object = self.model.find_one({"key": "models", "model": model_id}) + current_model_id: str = model["model"] if model is not None else None + result: list = [] + + for _ in range(limit): + if current_model_id is None: + break + + model: str = self.model.find_one({"key": "models", "parent_model": current_model_id}) + + if model is not None: + result.append(model) + current_model_id = model["model"] + + result.reverse() + + return result + + def get_model(self, model_id): + """Get model with id. + + :param model_id: id of model to get + :type model_id: str + :return: model with id + :rtype: ObjectId + """ + return self.model.find_one({"key": "models", "model": model_id}) + def get_events(self, **kwargs): """Get events from the database.