Skip to content

Commit

Permalink
Feature/SK-574 | Add metadata to gRPC calls for python clients (#483)
Browse files Browse the repository at this point in the history
  • Loading branch information
Wrede authored Nov 13, 2023
1 parent ae06b84 commit c4e5557
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 9 deletions.
44 changes: 35 additions & 9 deletions fedn/fedn/network/clients/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,28 @@ def _assign(self):
print("Received combiner config: {}".format(client_config), flush=True)
return client_config

def _add_grpc_metadata(self, key, value):
"""Add metadata for gRPC calls.
:param key: The key of the metadata.
:type key: str
:param value: The value of the metadata.
:type value: str
"""
# Check if metadata exists and add if not
if not hasattr(self, 'metadata'):
self.metadata = ()

# Check if metadata key already exists and replace value if so
for i, (k, v) in enumerate(self.metadata):
if k == key:
# Replace value
self.metadata = self.metadata[:i] + ((key, value),) + self.metadata[i + 1:]
return

# Set metadata using tuple concatenation
self.metadata += ((key, value),)

def _connect(self, client_config):
"""Connect to assigned combiner.
Expand All @@ -137,6 +159,9 @@ def _connect(self, client_config):

# TODO use the client_config['certificate'] for setting up secure comms'
host = client_config['host']
# Add host to gRPC metadata
self._add_grpc_metadata('grpc-server', host)
print("CLIENT: Using metadata: {}".format(self.metadata), flush=True)
port = client_config['port']
secure = False
if client_config['fqdn'] is not None:
Expand Down Expand Up @@ -331,7 +356,7 @@ def get_model(self, id):
"""
data = BytesIO()

for part in self.modelStub.Download(fedn.ModelRequest(id=id)):
for part in self.modelStub.Download(fedn.ModelRequest(id=id), metadata=self.metadata):

if part.status == fedn.ModelStatus.IN_PROGRESS:
data.write(part.data)
Expand Down Expand Up @@ -386,7 +411,7 @@ def upload_request_generator(mdl):
if not b:
break

result = self.modelStub.Upload(upload_request_generator(bt))
result = self.modelStub.Upload(upload_request_generator(bt), metadata=self.metadata)

return result

Expand All @@ -400,11 +425,12 @@ def _listen_to_model_update_request_stream(self):
r = fedn.ClientAvailableMessage()
r.sender.name = self.name
r.sender.role = fedn.WORKER
metadata = [('client', r.sender.name)]
# Add client to metadata
self._add_grpc_metadata('client', self.name)

while True:
try:
for request in self.combinerStub.ModelUpdateRequestStream(r, metadata=metadata):
for request in self.combinerStub.ModelUpdateRequestStream(r, metadata=self.metadata):
if request.sender.role == fedn.COMBINER:
# Process training request
self._send_status("Received model update request.", log_level=fedn.Status.AUDIT,
Expand Down Expand Up @@ -438,7 +464,7 @@ def _listen_to_model_validation_request_stream(self):
r.sender.role = fedn.WORKER
while True:
try:
for request in self.combinerStub.ModelValidationRequestStream(r):
for request in self.combinerStub.ModelValidationRequestStream(r, metadata=self.metadata):
# Process validation request
_ = request.model_id
self._send_status("Recieved model validation request.", log_level=fedn.Status.AUDIT,
Expand Down Expand Up @@ -589,7 +615,7 @@ def process_request(self):
update.correlation_id = request.correlation_id
update.meta = json.dumps(meta)
# TODO: Check responses
_ = self.combinerStub.SendModelUpdate(update)
_ = self.combinerStub.SendModelUpdate(update, metadata=self.metadata)
self._send_status("Model update completed.", log_level=fedn.Status.AUDIT,
type=fedn.StatusType.MODEL_UPDATE, request=update)

Expand Down Expand Up @@ -618,7 +644,7 @@ def process_request(self):
validation.timestamp = self.str
validation.correlation_id = request.correlation_id
_ = self.combinerStub.SendModelValidation(
validation)
validation, metadata=self.metadata)

# Set status type
if request.is_inference:
Expand Down Expand Up @@ -655,7 +681,7 @@ def _send_heartbeat(self, update_frequency=2.0):
heartbeat = fedn.Heartbeat(sender=fedn.Client(
name=self.name, role=fedn.WORKER))
try:
self.connectorStub.SendHeartbeat(heartbeat)
self.connectorStub.SendHeartbeat(heartbeat, metadata=self.metadata)
self._missed_heartbeat = 0
except grpc.RpcError as e:
status_code = e.code()
Expand Down Expand Up @@ -694,7 +720,7 @@ def _send_status(self, msg, log_level=fedn.Status.INFO, type=None, request=None)
self.logs.append(
"{} {} LOG LEVEL {} MESSAGE {}".format(str(datetime.now()), status.sender.name, status.log_level,
status.status))
_ = self.connectorStub.SendStatus(status)
_ = self.connectorStub.SendStatus(status, metadata=self.metadata)

def run(self):
""" Run the client. """
Expand Down
45 changes: 45 additions & 0 deletions fedn/fedn/network/clients/test_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import unittest

from fedn.network.clients.client import Client


class TestClient(unittest.TestCase):
"""Test the Client class."""

def setUp(self):
self.client = Client()

def test_add_grpc_metadata(self):
"""Test the _add_grpc_metadata method."""

# Test adding metadata when it doesn't exist
self.client._add_grpc_metadata('key1', 'value1')
self.assertEqual(self.client.metadata, (('key1', 'value1'),))

# Test adding metadata when it already exists
self.client._add_grpc_metadata('key1', 'value2')
self.assertEqual(self.client.metadata, (('key1', 'value2'),))

# Test adding multiple metadata
self.client._add_grpc_metadata('key2', 'value3')
self.assertEqual(self.client.metadata, (('key1', 'value2'), ('key2', 'value3')))

# Test adding metadata with special characters
self.client._add_grpc_metadata('key3', 'value4!@#$%^&*()')
self.assertEqual(self.client.metadata, (('key1', 'value2'), ('key2', 'value3'), ('key3', 'value4!@#$%^&*()')))

# Test adding metadata with empty key
with self.assertRaises(ValueError):
self.client._add_grpc_metadata('', 'value5')

# Test adding metadata with empty value
with self.assertRaises(ValueError):
self.client._add_grpc_metadata('key4', '')

# Test adding metadata with None value
with self.assertRaises(ValueError):
self.client._add_grpc_metadata('key5', None)


if __name__ == '__main__':
unittest.main()

0 comments on commit c4e5557

Please sign in to comment.