Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/SK-911 | Add terminate session to controller #641

Merged
merged 3 commits into from
Jun 24, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 33 additions & 5 deletions fedn/network/controller/control.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,20 @@ def __init__(self, message):
super().__init__(self.message)


class SessionTerminatedException(Exception):
"""Exception class for when session is terminated"""

def __init__(self, message):
"""Constructor method.

:param message: The exception message.
:type message: str

"""
self.message = message
super().__init__(self.message)


class Control(ControlBase):
"""Controller, implementing the overall global training, validation and inference logic.

Expand Down Expand Up @@ -122,6 +136,8 @@ def start_session(self, session_id: str, rounds: int) -> None:
current_round = round

try:
if self.get_session_status(session_id) == "Terminated":
break
_, round_data = self.round(session_config, str(current_round))
except TypeError as e:
logger.error("Failed to execute round: {0}".format(e))
Expand All @@ -130,7 +146,8 @@ def start_session(self, session_id: str, rounds: int) -> None:

session_config["model_id"] = self.statestore.get_latest_model()

self.set_session_status(session_id, "Finished")
if self.get_session_status(session_id) == "Started":
self.set_session_status(session_id, "Finished")
self._state = ReducerState.idle

def session(self, config: RoundConfig) -> None:
Expand Down Expand Up @@ -172,6 +189,8 @@ def session(self, config: RoundConfig) -> None:
current_round = round

try:
if self.get_session_status(config["session_id"]) == "Terminated":
break
_, round_data = self.round(config, str(current_round))
except TypeError as e:
logger.error("Failed to execute round: {0}".format(e))
Expand All @@ -181,7 +200,8 @@ def session(self, config: RoundConfig) -> None:
config["model_id"] = self.statestore.get_latest_model()

# TODO: Report completion of session
self.set_session_status(config["session_id"], "Finished")
if self.get_session_status(config["session_id"]) == "Started":
self.set_session_status(config["session_id"], "Finished")
self._state = ReducerState.idle

def inference_session(self, config: RoundConfig) -> None:
Expand Down Expand Up @@ -227,6 +247,7 @@ def round(self, session_config: RoundConfig, round_id: str):
: type round_id: str

"""
session_id = session_config["session_id"]
self.create_round({"round_id": round_id, "status": "Pending"})

if len(self.network.get_combiners()) < 1:
Expand All @@ -239,7 +260,7 @@ def round(self, session_config: RoundConfig, round_id: str):
round_config["rounds"] = 1
round_config["round_id"] = round_id
round_config["task"] = "training"
round_config["session_id"] = session_config["session_id"]
round_config["session_id"] = session_id

self.set_round_config(round_id, round_config)

Expand All @@ -263,7 +284,11 @@ def round(self, session_config: RoundConfig, round_id: str):
# Wait until participating combiners have produced an updated global model,
# or round times out.
def do_if_round_times_out(result):
logger.warning("Round timed out!")
if isinstance(result.outcome.exception(), SessionTerminatedException):
logger.warning("Session terminated!")
return None, self.statestore.get_round(round_id)
else:
logger.warning("Round timed out!")

@retry(
wait=wait_random(min=1.0, max=2.0),
Expand All @@ -273,6 +298,9 @@ def do_if_round_times_out(result):
)
def combiners_done():
round = self.statestore.get_round(round_id)
session_status = self.get_session_status(session_id)
if session_status == "Terminated":
raise SessionTerminatedException("Session terminated!")
if "combiners" not in round:
logger.info("Waiting for combiners to update model...")
raise CombinersNotDoneException("Combiners have not yet reported.")
Expand All @@ -283,7 +311,7 @@ def combiners_done():

return True

combiners_done()
_ = combiners_done()

# Due to the distributed nature of the computation, there might be a
# delay before combiners have reported the round data to the db,
Expand Down
10 changes: 10 additions & 0 deletions fedn/network/controller/controlbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,16 @@ def set_session_status(self, session_id, status):
"""
self.statestore.set_session_status(session_id, status)

def get_session_status(self, session_id):
"""Get the status of a session.

:param session_id: The session unique identifier
:type session_id: str
:return: The status
:rtype: str
"""
return self.statestore.get_session_status(session_id)

def create_round(self, round_data):
"""Initialize a new round in backend db."""
self.statestore.create_round(round_data)
Expand Down
11 changes: 11 additions & 0 deletions fedn/network/storage/statestore/mongostatestore.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,17 @@ def get_session(self, session_id):
"""
return self.sessions.find_one({"session_id": session_id})

def get_session_status(self, session_id):
"""Get the session status.

:param session_id: The session id.
:type session_id: str
:return: The session status.
:rtype: str
"""
session = self.sessions.find_one({"session_id": session_id})
return session["status"]

def set_latest_model(self, model_id, session_id=None):
"""Set the latest model id.

Expand Down
Loading