Skip to content

Commit

Permalink
Bugfix/SK-675 | Model staging not in sync, TaskStream (#523)
Browse files Browse the repository at this point in the history
  • Loading branch information
Wrede authored Feb 9, 2024
1 parent 8e26c3d commit 7cef838
Show file tree
Hide file tree
Showing 10 changed files with 161 additions and 362 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/integration-tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,9 @@ jobs:
- name: run ${{ matrix.to_test }}
run: .ci/tests/examples/run.sh ${{ matrix.to_test }}

- name: run ${{ matrix.to_test }} inference
run: .ci/tests/examples/run_inference.sh ${{ matrix.to_test }}
if: ${{ matrix.os != 'macos-11' && matrix.to_test == 'mnist-keras keras' }} # example available for Keras
# - name: run ${{ matrix.to_test }} inference
# run: .ci/tests/examples/run_inference.sh ${{ matrix.to_test }}
# if: ${{ matrix.os != 'macos-11' && matrix.to_test == 'mnist-keras keras' }} # example available for Keras

- name: print logs
if: failure()
Expand Down
89 changes: 31 additions & 58 deletions fedn/fedn/network/clients/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,15 +284,11 @@ def _subscribe_to_combiner(self, config):

# Start sending heartbeats to the combiner.
threading.Thread(target=self._send_heartbeat, kwargs={
'update_frequency': config['heartbeat_interval']}, daemon=True).start()
'update_frequency': config['heartbeat_interval']}, daemon=True).start()

# Start listening for combiner training and validation messages
if config['trainer']:
threading.Thread(
target=self._listen_to_model_update_request_stream, daemon=True).start()
if config['validator']:
threading.Thread(
target=self._listen_to_model_validation_request_stream, daemon=True).start()
threading.Thread(
target=self._listen_to_task_stream, daemon=True).start()
self._attached = True

# Start processing the client message inbox
Expand Down Expand Up @@ -359,7 +355,7 @@ def _initialize_dispatcher(self, config):
copy_tree(from_path, self.run_path)
self.dispatcher = Dispatcher(dispatch_config, self.run_path)

def get_model_from_combiner(self, id):
def get_model_from_combiner(self, id, timeout=20):
"""Fetch a model from the assigned combiner.
Downloads the model update object via a gRPC streaming channel.
Expand All @@ -369,8 +365,12 @@ def get_model_from_combiner(self, id):
:rtype: BytesIO
"""
data = BytesIO()
time_start = time.time()
request = fedn.ModelRequest(id=id)
request.sender.name = self.name
request.sender.role = fedn.WORKER

for part in self.modelStub.Download(fedn.ModelRequest(id=id), metadata=self.metadata):
for part in self.modelStub.Download(request, metadata=self.metadata):

if part.status == fedn.ModelStatus.IN_PROGRESS:
data.write(part.data)
Expand All @@ -381,6 +381,11 @@ def get_model_from_combiner(self, id):
if part.status == fedn.ModelStatus.FAILED:
return None

if part.status == fedn.ModelStatus.UNKNOWN:
if time.time() - time_start >= timeout:
return None
continue

return data

def send_model_to_combiner(self, model, id):
Expand Down Expand Up @@ -408,7 +413,7 @@ def send_model_to_combiner(self, model, id):

return result

def _listen_to_model_update_request_stream(self):
def _listen_to_task_stream(self):
"""Subscribe to the model update request stream.
:return: None
Expand All @@ -423,16 +428,21 @@ def _listen_to_model_update_request_stream(self):

while self._attached:
try:
for request in self.combinerStub.ModelUpdateRequestStream(r, metadata=self.metadata):
for request in self.combinerStub.TaskStream(r, metadata=self.metadata):
if request:
logger.debug("Received model update request from combiner: {}.".format(request))
if request.sender.role == fedn.COMBINER:
# Process training request
self._send_status("Received model update request.", log_level=fedn.Status.AUDIT,
type=fedn.StatusType.MODEL_UPDATE_REQUEST, request=request)
logger.info("Received model update request.")
logger.info("Received model update request of type {} for model_id {}".format(request.type, request.model_id))

self.inbox.put(('train', request))
if request.type == fedn.StatusType.MODEL_UPDATE and self.config['trainer']:
self.inbox.put(('train', request))
elif request.type == fedn.StatusType.MODEL_VALIDATION and self.config['validator']:
self.inbox.put(('validate', request))
else:
logger.error("Unknown request type: {}".format(request.type))

except grpc.RpcError as e:
# Handle gRPC errors
Expand All @@ -453,45 +463,6 @@ def _listen_to_model_update_request_stream(self):
if not self._attached:
return

def _listen_to_model_validation_request_stream(self):
"""Subscribe to the model validation request stream.
:return: None
:rtype: None
"""

r = fedn.ClientAvailableMessage()
r.sender.name = self.name
r.sender.role = fedn.WORKER
while True:
try:
for request in self.combinerStub.ModelValidationRequestStream(r, metadata=self.metadata):
# Process validation request
model_id = request.model_id
self._send_status("Received model validation request for model_id {}".format(model_id),
log_level=fedn.Status.AUDIT, type=fedn.StatusType.MODEL_VALIDATION_REQUEST,
request=request)
logger.info("Received model validation request for model_id {}".format(model_id))
self.inbox.put(('validate', request))

except grpc.RpcError as e:
# Handle gRPC errors
status_code = e.code()
if status_code == grpc.StatusCode.UNAVAILABLE:
logger.warning("GRPC server unavailable during model validation request stream. Retrying.")
# Retry after a delay
time.sleep(5)
else:
# Log the error and continue
logger.error(f"An error occurred during model validation request stream: {e}")

except Exception as ex:
# Handle other exceptions
logger.error(f"An error occurred during model validation request stream: {ex}")

if not self._attached:
return

def _process_training_request(self, model_id):
"""Process a training (model update) request.
Expand All @@ -509,6 +480,9 @@ def _process_training_request(self, model_id):
meta = {}
tic = time.time()
mdl = self.get_model_from_combiner(str(model_id))
if mdl is None:
logger.error("Could not retrieve model from combiner. Aborting training request.")
return None, None
meta['fetch_model'] = time.time() - tic

inpath = self.helper.get_tmp_path()
Expand Down Expand Up @@ -573,6 +547,9 @@ def _process_validation_request(self, model_id, is_inference):
self.state = ClientState.validating
try:
model = self.get_model_from_combiner(str(model_id))
if model is None:
logger.error("Could not retrieve model from combiner. Aborting validation request.")
return None
inpath = self.helper.get_tmp_path()

with open(inpath, "wb") as fh:
Expand Down Expand Up @@ -641,7 +618,7 @@ def process_request(self):
elif task_type == 'validate':
self.state = ClientState.validating
metrics = self._process_validation_request(
request.model_id, request.is_inference)
request.model_id, False)

if metrics is not None:
# Send validation
Expand All @@ -658,11 +635,7 @@ def process_request(self):
_ = self.combinerStub.SendModelValidation(
validation, metadata=self.metadata)

# Set status type
if request.is_inference:
status_type = fedn.StatusType.INFERENCE
else:
status_type = fedn.StatusType.MODEL_VALIDATION
status_type = fedn.StatusType.MODEL_VALIDATION

self._send_status("Model validation completed.", log_level=fedn.Status.AUDIT,
type=status_type, request=validation)
Expand Down
117 changes: 20 additions & 97 deletions fedn/fedn/network/combiner/combiner.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,11 +169,12 @@ def request_model_update(self, config, clients=[]):
"""
# The request to be added to the client queue
request = fedn.ModelUpdateRequest()
request = fedn.TaskRequest()
request.model_id = config['model_id']
request.correlation_id = str(uuid.uuid4())
request.timestamp = str(datetime.now())
request.data = json.dumps(config)
request.type = fedn.StatusType.MODEL_UPDATE

request.sender.name = self.id
request.sender.role = fedn.COMBINER
Expand All @@ -184,7 +185,7 @@ def request_model_update(self, config, clients=[]):
for client in clients:
request.receiver.name = client
request.receiver.role = fedn.WORKER
self._put_request_to_client_queue(request, fedn.Channel.MODEL_UPDATE_REQUESTS)
self._put_request_to_client_queue(request, fedn.Queue.TASK_QUEUE)

if len(clients) < 20:
logger.info("Sent model update request for model {} to clients {}".format(
Expand All @@ -205,19 +206,23 @@ def request_model_validation(self, model_id, config, clients=[]):
"""
# The request to be added to the client queue
request = fedn.ModelValidationRequest()
request = fedn.TaskRequest()
request.model_id = model_id
request.correlation_id = str(uuid.uuid4()) # Obsolete?
request.correlation_id = str(uuid.uuid4())
request.timestamp = str(datetime.now())
request.is_inference = (config['task'] == 'inference')
# request.is_inference = (config['task'] == 'inference')
request.type = fedn.StatusType.MODEL_VALIDATION

request.sender.name = self.id
request.sender.role = fedn.COMBINER

if len(clients) == 0:
clients = self.get_active_validators()

for client in clients:
request.receiver.name = client
request.receiver.role = fedn.WORKER
self._put_request_to_client_queue(request, fedn.Channel.MODEL_VALIDATION_REQUESTS)
self._put_request_to_client_queue(request, fedn.Queue.TASK_QUEUE)

if len(clients) < 20:
logger.info("Sent model validation request for model {} to clients {}".format(
Expand All @@ -232,7 +237,7 @@ def get_active_trainers(self):
:return: the list of active trainers
:rtype: list
"""
trainers = self._list_active_clients(fedn.Channel.MODEL_UPDATE_REQUESTS)
trainers = self._list_active_clients(fedn.Queue.TASK_QUEUE)
return trainers

def get_active_validators(self):
Expand All @@ -241,7 +246,7 @@ def get_active_validators(self):
:return: the list of active validators
:rtype: list
"""
validators = self._list_active_clients(fedn.Channel.MODEL_VALIDATION_REQUESTS)
validators = self._list_active_clients(fedn.Queue.TASK_QUEUE)
return validators

def nr_active_trainers(self):
Expand Down Expand Up @@ -349,7 +354,7 @@ def _deamon_thread_client_status(self, timeout=10):
while True:
time.sleep(timeout)
# TODO: Also update validation clients
self._list_active_clients(fedn.Channel.MODEL_UPDATE_REQUESTS)
self._list_active_clients(fedn.Queue.TASK_QUEUE)

def _put_request_to_client_queue(self, request, queue_name):
""" Get a client specific queue and add a request to it.
Expand Down Expand Up @@ -545,7 +550,7 @@ def AcceptingClients(self, request: fedn.ConnectionRequest, context):
"""
response = fedn.ConnectionResponse()
active_clients = self._list_active_clients(
fedn.Channel.MODEL_UPDATE_REQUESTS)
fedn.Queue.TASK_QUEUE)

try:
requested = int(self.max_clients)
Expand Down Expand Up @@ -588,33 +593,7 @@ def SendHeartbeat(self, heartbeat: fedn.Heartbeat, context):

# Combiner Service

def ModelUpdateStream(self, update, context):
""" Model update stream RPC endpoint. Update status for client is connecting to stream.
:param update: the update message
:type update: :class:`fedn.network.grpc.fedn_pb2.ModelUpdate`
:param context: the context
:type context: :class:`grpc._server._Context`
"""
client = update.sender
status = fedn.Status(
status="Client {} connecting to ModelUpdateStream.".format(client.name))
status.log_level = fedn.Status.INFO
status.sender.name = self.id
status.sender.role = role_to_proto_role(self.role)

self._subscribe_client_to_queue(client, fedn.Channel.MODEL_UPDATES)
q = self.__get_queue(client, fedn.Channel.MODEL_UPDATES)

self._send_status(status)

while context.is_active():
try:
yield q.get(timeout=1.0)
except queue.Empty:
pass

def ModelUpdateRequestStream(self, response, context):
def TaskStream(self, response, context):
""" A server stream RPC endpoint (Update model). Messages from client stream.
:param response: the response
Expand All @@ -627,18 +606,18 @@ def ModelUpdateRequestStream(self, response, context):
metadata = context.invocation_metadata()
if metadata:
metadata = dict(metadata)
logger.info("grpc.Combiner.ModelUpdateRequestStream: Client connected: {}\n".format(metadata['client']))
logger.info("grpc.Combiner.TaskStream: Client connected: {}\n".format(metadata['client']))

status = fedn.Status(
status="Client {} connecting to ModelUpdateRequestStream.".format(client.name))
status="Client {} connecting to TaskStream.".format(client.name))
status.log_level = fedn.Status.INFO
status.timestamp.GetCurrentTime()

self.__whoami(status.sender, self)

self._subscribe_client_to_queue(
client, fedn.Channel.MODEL_UPDATE_REQUESTS)
q = self.__get_queue(client, fedn.Channel.MODEL_UPDATE_REQUESTS)
client, fedn.Queue.TASK_QUEUE)
q = self.__get_queue(client, fedn.Queue.TASK_QUEUE)

self._send_status(status)

Expand All @@ -657,62 +636,6 @@ def ModelUpdateRequestStream(self, response, context):
except Exception as e:
logger.error("Error in ModelUpdateRequestStream: {}".format(e))

def ModelValidationStream(self, update, context):
""" Model validation stream RPC endpoint. Update status for client is connecting to stream.
:param update: the update message
:type update: :class:`fedn.network.grpc.fedn_pb2.ModelValidation`
:param context: the context
:type context: :class:`grpc._server._Context`
"""
client = update.sender
status = fedn.Status(
status="Client {} connecting to ModelValidationStream.".format(client.name))
status.log_level = fedn.Status.INFO

status.sender.name = self.id
status.sender.role = role_to_proto_role(self.role)

self._subscribe_client_to_queue(client, fedn.Channel.MODEL_VALIDATIONS)
q = self.__get_queue(client, fedn.Channel.MODEL_VALIDATIONS)

self._send_status(status)

while context.is_active():
try:
yield q.get(timeout=1.0)
except queue.Empty:
pass

def ModelValidationRequestStream(self, response, context):
""" A server stream RPC endpoint (Validation). Messages from client stream.
:param response: the response
:type response: :class:`fedn.network.grpc.fedn_pb2.ModelValidationRequest`
:param context: the context
:type context: :class:`grpc._server._Context`
"""

client = response.sender
status = fedn.Status(
status="Client {} connecting to ModelValidationRequestStream.".format(client.name))
status.log_level = fedn.Status.INFO
status.sender.name = self.id
status.sender.role = role_to_proto_role(self.role)
status.timestamp.GetCurrentTime()

self._subscribe_client_to_queue(
client, fedn.Channel.MODEL_VALIDATION_REQUESTS)
q = self.__get_queue(client, fedn.Channel.MODEL_VALIDATION_REQUESTS)

self._send_status(status)

while context.is_active():
try:
yield q.get(timeout=1.0)
except queue.Empty:
pass

def SendModelUpdate(self, request, context):
""" Send a model update response.
Expand Down
Loading

0 comments on commit 7cef838

Please sign in to comment.