Skip to content

Commit

Permalink
Added get_model_ancestors endpoint to api
Browse files Browse the repository at this point in the history
  • Loading branch information
niklastheman committed Jan 5, 2024
1 parent e23421a commit 58e1596
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 0 deletions.
39 changes: 39 additions & 0 deletions fedn/fedn/network/api/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -784,6 +784,45 @@ def get_model_trail(self):
{"success": False, "message": "No model trail available."}
)

def get_model_ancestors(self, model_id: str, limit: str = None):
"""Get the model ancestors for a given model.
:param model_id: The model id to get the model ancestors for.
:type model_id: str
:param limit: The number of ancestors to return.
:type limit: str
:return: The model ancestors for the given model as a json response.
:rtype: :class:`flask.Response`
"""
if model_id is None:
return jsonify(
{"success": False, "message": "No model id provided."}
)

limit = int(limit) if limit is not None else 10 # if limit is None, default to 10

response = self.statestore.get_model_ancestors(model_id, limit)
if response:

arr = []

for element in response:
obj = {
"model": element["model"],
"committed_at": element["committed_at"],
"session_id": element["session_id"],
"parent_model": element["parent_model"],
}
arr.append(obj)

result = {"result": arr}

return jsonify(result)
else:
return jsonify(
{"success": False, "message": "No model ancestors available."}
)

def get_all_rounds(self):
"""Get all rounds.
Expand Down
15 changes: 15 additions & 0 deletions fedn/fedn/network/api/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,21 @@ def get_model_trail():
return api.get_model_trail()


@app.route("/get_model_ancestors", methods=["GET"])
def get_model_ancestors():
"""Get the ancestors of a model.
param: model: The model id to get the ancestors for.
type: model: str
type: limit: int
return: The a list of model objects that the model derives from.
rtype: json
"""
model = request.args.get("model", None)
limit = request.args.get("limit", None)

return api.get_model_ancestors(model, limit)


@app.route("/list_models", methods=["GET"])
def list_models():
"""Get models from the statestore.
Expand Down
38 changes: 38 additions & 0 deletions fedn/fedn/network/storage/statestore/mongostatestore.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,10 +185,21 @@ def set_latest_model(self, model_id, session_id=None):

committed_at = datetime.now()

current_model = self.model.find_one({"key": "current_model"})

parent_model = None

# if session_id is set the it means the model is generated from a session
# and we need to set the parent model
# if not the model is uploaded by the user and we don't need to set the parent model
if session_id is not None:
parent_model = current_model["model"] if current_model and "model" in current_model else None

self.model.insert_one(
{
"key": "models",
"model": model_id,
"parent_model": parent_model,
"session_id": session_id,
"committed_at": committed_at,
}
Expand Down Expand Up @@ -534,6 +545,33 @@ def get_model_trail(self):
except (KeyError, IndexError):
return None

def get_model_ancestors(self, model_id: str, limit: int):
"""Get the model ancestors.
:param model_id: The model id.
:type model_id: str
:param limit: The maximum number of ancestors to return.
:type limit: int
:return: List of model ancestors.
:rtype: list
"""

model = self.model.find_one({"key": "models", "model": model_id})
current_model_id = model["parent_model"] if model is not None else None
result = []

for _ in range(limit):
if current_model_id is None:
break

model = self.model.find_one({"key": "models", "model": current_model_id})

if model is not None:
result.append(model)
current_model_id = model["parent_model"]

return result

def get_events(self, **kwargs):
"""Get events from the database.
Expand Down

0 comments on commit 58e1596

Please sign in to comment.