Skip to content

Commit

Permalink
Added get_model and get_model_descendants to api
Browse files Browse the repository at this point in the history
  • Loading branch information
niklastheman committed Jan 8, 2024
1 parent 58e1596 commit 771b3c7
Show file tree
Hide file tree
Showing 3 changed files with 133 additions and 3 deletions.
63 changes: 61 additions & 2 deletions fedn/fedn/network/api/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -768,6 +768,24 @@ def get_models(self, session_id: str = None, limit: str = None, skip: str = None

return jsonify(result)

def get_model(self, model_id: str):
result = self.statestore.get_model(model_id)

if result is None:
return (
jsonify({"success": False, "message": "No model found."}),
404,
)

payload = {
"committed_at": result["committed_at"],
"parent_model": result["parent_model"],
"model": result["model"],
"session_id": result["session_id"],
}

return jsonify(payload)

def get_model_trail(self):
"""Get the model trail for a given session.
Expand Down Expand Up @@ -799,12 +817,12 @@ def get_model_ancestors(self, model_id: str, limit: str = None):
{"success": False, "message": "No model id provided."}
)

limit = int(limit) if limit is not None else 10 # if limit is None, default to 10
limit: int = int(limit) if limit is not None else 10 # if limit is None, default to 10

response = self.statestore.get_model_ancestors(model_id, limit)
if response:

arr = []
arr: list = []

for element in response:
obj = {
Expand All @@ -823,6 +841,47 @@ def get_model_ancestors(self, model_id: str, limit: str = None):
{"success": False, "message": "No model ancestors available."}
)

def get_model_descendants(self, model_id: str, limit: str = None):
"""Get the model descendants for a given model.
:param model_id: The model id to get the model descendants for.
:type model_id: str
:param limit: The number of descendants to return.
:type limit: str
:return: The model descendants for the given model as a json response.
:rtype: :class:`flask.Response`
"""

if model_id is None:
return jsonify(
{"success": False, "message": "No model id provided."}
)

limit: int = int(limit) if limit is not None else 10

response: list = self.statestore.get_model_descendants(model_id, limit)

if response:

arr: list = []

for element in response:
obj = {
"model": element["model"],
"committed_at": element["committed_at"],
"session_id": element["session_id"],
"parent_model": element["parent_model"],
}
arr.append(obj)

result = {"result": arr}

return jsonify(result)
else:
return jsonify(
{"success": False, "message": "No model descendants available."}
)

def get_all_rounds(self):
"""Get all rounds.
Expand Down
34 changes: 33 additions & 1 deletion fedn/fedn/network/api/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,9 @@ def get_model_ancestors():
"""Get the ancestors of a model.
param: model: The model id to get the ancestors for.
type: model: str
param: limit: The maximum number of ancestors to return.
type: limit: int
return: The a list of model objects that the model derives from.
return: A list of model objects that the model derives from.
rtype: json
"""
model = request.args.get("model", None)
Expand All @@ -44,6 +45,22 @@ def get_model_ancestors():
return api.get_model_ancestors(model, limit)


@app.route("/get_model_descendants", methods=["GET"])
def get_model_descendants():
"""Get the ancestors of a model.
param: model: The model id to get the child for.
type: model: str
param: limit: The maximum number of descendants to return.
type: limit: int
return: A list of model objects that are descendents of the provided model id.
rtype: json
"""
model = request.args.get("model", None)
limit = request.args.get("limit", None)

return api.get_model_descendants(model, limit)


@app.route("/list_models", methods=["GET"])
def list_models():
"""Get models from the statestore.
Expand All @@ -65,6 +82,21 @@ def list_models():
return api.get_models(session_id, limit, skip, include_active)


@app.route("/get_model", methods=["GET"])
def get_model():
"""Get a model from the statestore.
param: model: The model id to get.
type: model: str
return: The model as a json object.
rtype: json
"""
model = request.args.get("model", None)
if model is None:
return jsonify({"success": False, "message": "Missing model id."}), 400

return api.get_model(model)


@app.route("/delete_model_trail", methods=["GET", "POST"])
def delete_model_trail():
"""Delete the model trail for a given session.
Expand Down
39 changes: 39 additions & 0 deletions fedn/fedn/network/storage/statestore/mongostatestore.py
Original file line number Diff line number Diff line change
Expand Up @@ -572,6 +572,45 @@ def get_model_ancestors(self, model_id: str, limit: int):

return result

def get_model_descendants(self, model_id: str, limit: int):
"""Get the model descendants.
:param model_id: The model id.
:type model_id: str
:param limit: The maximum number of descendants to return.
:type limit: int
:return: List of model descendants.
:rtype: list
"""

model: object = self.model.find_one({"key": "models", "model": model_id})
current_model_id: str = model["model"] if model is not None else None
result: list = []

for _ in range(limit):
if current_model_id is None:
break

model: str = self.model.find_one({"key": "models", "parent_model": current_model_id})

if model is not None:
result.append(model)
current_model_id = model["model"]

result.reverse()

return result

def get_model(self, model_id):
"""Get model with id.
:param model_id: id of model to get
:type model_id: str
:return: model with id
:rtype: ObjectId
"""
return self.model.find_one({"key": "models", "model": model_id})

def get_events(self, **kwargs):
"""Get events from the database.
Expand Down

0 comments on commit 771b3c7

Please sign in to comment.