Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/SK-613 | Display model trail in studio #501

Merged
merged 4 commits into from
Jan 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 98 additions & 0 deletions fedn/fedn/network/api/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -768,6 +768,24 @@ def get_models(self, session_id: str = None, limit: str = None, skip: str = None

return jsonify(result)

def get_model(self, model_id: str):
result = self.statestore.get_model(model_id)

if result is None:
return (
jsonify({"success": False, "message": "No model found."}),
404,
)

payload = {
"committed_at": result["committed_at"],
"parent_model": result["parent_model"],
"model": result["model"],
"session_id": result["session_id"],
}

return jsonify(payload)

def get_model_trail(self):
"""Get the model trail for a given session.

Expand All @@ -784,6 +802,86 @@ def get_model_trail(self):
{"success": False, "message": "No model trail available."}
)

def get_model_ancestors(self, model_id: str, limit: str = None):
"""Get the model ancestors for a given model.

:param model_id: The model id to get the model ancestors for.
:type model_id: str
:param limit: The number of ancestors to return.
:type limit: str
:return: The model ancestors for the given model as a json response.
:rtype: :class:`flask.Response`
"""
if model_id is None:
return jsonify(
{"success": False, "message": "No model id provided."}
)

limit: int = int(limit) if limit is not None else 10 # if limit is None, default to 10

response = self.statestore.get_model_ancestors(model_id, limit)
if response:

arr: list = []

for element in response:
obj = {
"model": element["model"],
"committed_at": element["committed_at"],
"session_id": element["session_id"],
"parent_model": element["parent_model"],
}
arr.append(obj)

result = {"result": arr}

return jsonify(result)
else:
return jsonify(
{"success": False, "message": "No model ancestors available."}
)

def get_model_descendants(self, model_id: str, limit: str = None):
"""Get the model descendants for a given model.

:param model_id: The model id to get the model descendants for.
:type model_id: str
:param limit: The number of descendants to return.
:type limit: str
:return: The model descendants for the given model as a json response.
:rtype: :class:`flask.Response`
"""

if model_id is None:
return jsonify(
{"success": False, "message": "No model id provided."}
)

limit: int = int(limit) if limit is not None else 10

response: list = self.statestore.get_model_descendants(model_id, limit)

if response:

arr: list = []

for element in response:
obj = {
"model": element["model"],
"committed_at": element["committed_at"],
"session_id": element["session_id"],
"parent_model": element["parent_model"],
}
arr.append(obj)

result = {"result": arr}

return jsonify(result)
else:
return jsonify(
{"success": False, "message": "No model descendants available."}
)

def get_all_rounds(self):
"""Get all rounds.

Expand Down
47 changes: 47 additions & 0 deletions fedn/fedn/network/api/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,38 @@ def get_model_trail():
return api.get_model_trail()


@app.route("/get_model_ancestors", methods=["GET"])
def get_model_ancestors():
"""Get the ancestors of a model.
param: model: The model id to get the ancestors for.
type: model: str
param: limit: The maximum number of ancestors to return.
type: limit: int
return: A list of model objects that the model derives from.
rtype: json
"""
model = request.args.get("model", None)
limit = request.args.get("limit", None)

return api.get_model_ancestors(model, limit)


@app.route("/get_model_descendants", methods=["GET"])
def get_model_descendants():
"""Get the ancestors of a model.
param: model: The model id to get the child for.
type: model: str
param: limit: The maximum number of descendants to return.
type: limit: int
return: A list of model objects that are descendents of the provided model id.
rtype: json
"""
model = request.args.get("model", None)
limit = request.args.get("limit", None)

return api.get_model_descendants(model, limit)


@app.route("/list_models", methods=["GET"])
def list_models():
"""Get models from the statestore.
Expand All @@ -50,6 +82,21 @@ def list_models():
return api.get_models(session_id, limit, skip, include_active)


@app.route("/get_model", methods=["GET"])
def get_model():
"""Get a model from the statestore.
param: model: The model id to get.
type: model: str
return: The model as a json object.
rtype: json
"""
model = request.args.get("model", None)
if model is None:
return jsonify({"success": False, "message": "Missing model id."}), 400

return api.get_model(model)


@app.route("/delete_model_trail", methods=["GET", "POST"])
def delete_model_trail():
"""Delete the model trail for a given session.
Expand Down
74 changes: 74 additions & 0 deletions fedn/fedn/network/storage/statestore/mongostatestore.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,11 +184,20 @@ def set_latest_model(self, model_id, session_id=None):
"""

committed_at = datetime.now()
current_model = self.model.find_one({"key": "current_model"})
parent_model = None

# if session_id is set the it means the model is generated from a session
# and we need to set the parent model
# if not the model is uploaded by the user and we don't need to set the parent model
if session_id is not None:
parent_model = current_model["model"] if current_model and "model" in current_model else None

self.model.insert_one(
{
"key": "models",
"model": model_id,
"parent_model": parent_model,
"session_id": session_id,
"committed_at": committed_at,
}
Expand Down Expand Up @@ -534,6 +543,71 @@ def get_model_trail(self):
except (KeyError, IndexError):
return None

def get_model_ancestors(self, model_id: str, limit: int):
"""Get the model ancestors.

:param model_id: The model id.
:type model_id: str
:param limit: The maximum number of ancestors to return.
:type limit: int
:return: List of model ancestors.
:rtype: list
"""
model = self.model.find_one({"key": "models", "model": model_id})
current_model_id = model["parent_model"] if model is not None else None
result = []

for _ in range(limit):
if current_model_id is None:
break

model = self.model.find_one({"key": "models", "model": current_model_id})

if model is not None:
result.append(model)
current_model_id = model["parent_model"]

return result

def get_model_descendants(self, model_id: str, limit: int):
"""Get the model descendants.

:param model_id: The model id.
:type model_id: str
:param limit: The maximum number of descendants to return.
:type limit: int
:return: List of model descendants.
:rtype: list
"""

model: object = self.model.find_one({"key": "models", "model": model_id})
current_model_id: str = model["model"] if model is not None else None
result: list = []

for _ in range(limit):
if current_model_id is None:
break

model: str = self.model.find_one({"key": "models", "parent_model": current_model_id})

if model is not None:
result.append(model)
current_model_id = model["model"]

result.reverse()

return result

def get_model(self, model_id):
"""Get model with id.

:param model_id: id of model to get
:type model_id: str
:return: model with id
:rtype: ObjectId
"""
return self.model.find_one({"key": "models", "model": model_id})

def get_events(self, **kwargs):
"""Get events from the database.

Expand Down
Loading