diff --git a/fedn/fedn/network/api/interface.py b/fedn/fedn/network/api/interface.py index 6f3b993b6..fd147d9f5 100644 --- a/fedn/fedn/network/api/interface.py +++ b/fedn/fedn/network/api/interface.py @@ -784,6 +784,45 @@ def get_model_trail(self): {"success": False, "message": "No model trail available."} ) + def get_model_ancestors(self, model_id: str, limit: str = None): + """Get the model ancestors for a given model. + + :param model_id: The model id to get the model ancestors for. + :type model_id: str + :param limit: The number of ancestors to return. + :type limit: str + :return: The model ancestors for the given model as a json response. + :rtype: :class:`flask.Response` + """ + if model_id is None: + return jsonify( + {"success": False, "message": "No model id provided."} + ) + + limit = int(limit) if limit is not None else 10 # if limit is None, default to 10 + + response = self.statestore.get_model_ancestors(model_id, limit) + if response: + + arr = [] + + for element in response: + obj = { + "model": element["model"], + "committed_at": element["committed_at"], + "session_id": element["session_id"], + "parent_model": element["parent_model"], + } + arr.append(obj) + + result = {"result": arr} + + return jsonify(result) + else: + return jsonify( + {"success": False, "message": "No model ancestors available."} + ) + def get_all_rounds(self): """Get all rounds. diff --git a/fedn/fedn/network/api/server.py b/fedn/fedn/network/api/server.py index 0b385c566..304e7c57c 100644 --- a/fedn/fedn/network/api/server.py +++ b/fedn/fedn/network/api/server.py @@ -29,6 +29,21 @@ def get_model_trail(): return api.get_model_trail() +@app.route("/get_model_ancestors", methods=["GET"]) +def get_model_ancestors(): + """Get the ancestors of a model. + param: model: The model id to get the ancestors for. + type: model: str + type: limit: int + return: The a list of model objects that the model derives from. + rtype: json + """ + model = request.args.get("model", None) + limit = request.args.get("limit", None) + + return api.get_model_ancestors(model, limit) + + @app.route("/list_models", methods=["GET"]) def list_models(): """Get models from the statestore. diff --git a/fedn/fedn/network/storage/statestore/mongostatestore.py b/fedn/fedn/network/storage/statestore/mongostatestore.py index fe6f93c51..b735d7635 100644 --- a/fedn/fedn/network/storage/statestore/mongostatestore.py +++ b/fedn/fedn/network/storage/statestore/mongostatestore.py @@ -185,10 +185,21 @@ def set_latest_model(self, model_id, session_id=None): committed_at = datetime.now() + current_model = self.model.find_one({"key": "current_model"}) + + parent_model = None + + # if session_id is set the it means the model is generated from a session + # and we need to set the parent model + # if not the model is uploaded by the user and we don't need to set the parent model + if session_id is not None: + parent_model = current_model["model"] if current_model and "model" in current_model else None + self.model.insert_one( { "key": "models", "model": model_id, + "parent_model": parent_model, "session_id": session_id, "committed_at": committed_at, } @@ -534,6 +545,33 @@ def get_model_trail(self): except (KeyError, IndexError): return None + def get_model_ancestors(self, model_id: str, limit: int): + """Get the model ancestors. + + :param model_id: The model id. + :type model_id: str + :param limit: The maximum number of ancestors to return. + :type limit: int + :return: List of model ancestors. + :rtype: list + """ + + model = self.model.find_one({"key": "models", "model": model_id}) + current_model_id = model["parent_model"] if model is not None else None + result = [] + + for _ in range(limit): + if current_model_id is None: + break + + model = self.model.find_one({"key": "models", "model": current_model_id}) + + if model is not None: + result.append(model) + current_model_id = model["parent_model"] + + return result + def get_events(self, **kwargs): """Get events from the database.