Feature/SK-898 | Add presigned put url to inference workflow #631

Merged · 4 commits · Jun 14, 2024
4 changes: 3 additions & 1 deletion examples/mnist-pytorch/client/fedn.yaml
@@ -7,4 +7,6 @@ entry_points:
train:
command: python train.py
validate:
command: python validate.py
command: python validate.py
predict:
command: python predict.py
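Note: the client's dispatcher resolves this entry point by name, so an inference task triggers the `predict` command with an input model path and an output artifact path (see `_process_inference_request` below). A minimal sketch of the equivalent direct invocation, with placeholder paths:

```python
import subprocess

# Equivalent of the dispatcher running "predict <in_model_path> <out_artifact_path>";
# both paths below are placeholders for client-local temp files.
subprocess.run(["python", "predict.py", "/tmp/in_model", "/tmp/out_artifact"], check=True)
```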
37 changes: 37 additions & 0 deletions examples/mnist-pytorch/client/predict.py
@@ -0,0 +1,37 @@
import os
import sys

import torch
from data import load_data
from model import load_parameters

dir_path = os.path.dirname(os.path.realpath(__file__))
sys.path.append(os.path.abspath(dir_path))


def predict(in_model_path, out_artifact_path, data_path=None):
"""Validate model.

:param in_model_path: The path to the input model.
:type in_model_path: str
:param out_artifact_path: The path to save the predict output to.
:type out_artifact_path: str
:param data_path: The path to the data file.
:type data_path: str
"""
# Load data
x_test, y_test = load_data(data_path, is_train=False)

# Load model
model = load_parameters(in_model_path)
model.eval()

# Predict
with torch.no_grad():
y_pred = model(x_test)
# Save the predictions to file/artifact; the artifact will be uploaded to the object store by the client
torch.save(y_pred, out_artifact_path)


if __name__ == "__main__":
predict(sys.argv[1], sys.argv[2])
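Usage note: the artifact written by `predict` is a serialized tensor, so it can be read back with `torch.load`. A minimal sketch, with a placeholder path:

```python
import torch

# Read back the prediction artifact produced by predict.py (path is a placeholder).
y_pred = torch.load("/tmp/out_artifact")
print(y_pred.shape)  # one row of class scores per test example
```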
76 changes: 71 additions & 5 deletions fedn/network/clients/client.py
@@ -14,6 +14,7 @@
from shutil import copytree

import grpc
import requests
from cryptography.hazmat.primitives.serialization import Encoding
from google.protobuf.json_format import MessageToJson
from OpenSSL import SSL
@@ -22,13 +23,11 @@
import fedn.network.grpc.fedn_pb2 as fedn
import fedn.network.grpc.fedn_pb2_grpc as rpc
from fedn.common.config import FEDN_AUTH_SCHEME, FEDN_PACKAGE_EXTRACT_DIR
from fedn.common.log_config import (logger, set_log_level_from_string,
set_log_stream)
from fedn.common.log_config import logger, set_log_level_from_string, set_log_stream
from fedn.network.clients.connect import ConnectorClient, Status
from fedn.network.clients.package import PackageRuntime
from fedn.network.clients.state import ClientState, ClientStateToString
from fedn.network.combiner.modelservice import (get_tmp_path,
upload_request_generator)
from fedn.network.combiner.modelservice import get_tmp_path, upload_request_generator
from fedn.utils.dispatcher import Dispatcher
from fedn.utils.helpers.helpers import get_helper

@@ -438,12 +437,18 @@ def _listen_to_task_stream(self):
request=request,
sesssion_id=request.session_id,
)
logger.info("Received model update request of type {} for model_id {}".format(request.type, request.model_id))
logger.info("Received task request of type {} for model_id {}".format(request.type, request.model_id))

if request.type == fedn.StatusType.MODEL_UPDATE and self.config["trainer"]:
self.inbox.put(("train", request))
elif request.type == fedn.StatusType.MODEL_VALIDATION and self.config["validator"]:
self.inbox.put(("validate", request))
elif request.type == fedn.StatusType.INFERENCE and self.config["validator"]:
logger.info("Received inference request for model_id {}".format(request.model_id))
presigned_url = json.loads(request.data)
presigned_url = presigned_url["presigned_url"]
logger.info("Inference presigned URL: {}".format(presigned_url))
self.inbox.put(("infer", request))
else:
logger.error("Unknown request type: {}".format(request.type))

@@ -586,6 +591,51 @@ def _process_validation_request(self, model_id: str, is_inference: bool, session
self.state = ClientState.idle
return validation

def _process_inference_request(self, model_id: str, session_id: str, presigned_url: str):
"""Process an inference request.

:param model_id: The model id of the model to be used for inference.
:type model_id: str
:param session_id: The id of the current session.
:type session_id: str
:param presigned_url: The presigned URL for the data to be used for inference.
:type presigned_url: str
:return: None
"""
self.send_status(f"Processing inference request for model_id {model_id}", sesssion_id=session_id)
try:
model = self.get_model_from_combiner(str(model_id))
if model is None:
logger.error("Could not retrieve model from combiner. Aborting inference request.")
return
inpath = self.helper.get_tmp_path()

with open(inpath, "wb") as fh:
fh.write(model.getbuffer())

outpath = get_tmp_path()
self.dispatcher.run_cmd(f"predict {inpath} {outpath}")

# Upload the inference result to the presigned URL
with open(outpath, "rb") as fh:
response = requests.put(presigned_url, data=fh.read())

os.unlink(inpath)
os.unlink(outpath)

if response.status_code != 200:
logger.warning("Inference upload failed with status code {}".format(response.status_code))
self.state = ClientState.idle
return

except Exception as e:
logger.warning("Inference failed with exception {}".format(e))
self.state = ClientState.idle
return

self.state = ClientState.idle
return
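Note: the upload above reads the whole artifact into memory via `fh.read()` before the PUT. `requests` can also stream the file object directly, which avoids materializing large artifacts; a sketch of that variant, with placeholder URL and path:

```python
import requests

# Streaming variant: passing the file object lets requests upload in chunks
# instead of loading the whole artifact into memory. URL/path are placeholders.
with open("/tmp/out_artifact", "rb") as fh:
    response = requests.put("https://minio.example.com/presigned-put-url", data=fh)
if response.status_code != 200:
    print(f"Inference upload failed with status code {response.status_code}")
```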

def process_request(self):
"""Process training and validation tasks."""
while True:
@@ -682,6 +732,22 @@ def process_request(self):

self.state = ClientState.idle
self.inbox.task_done()
elif task_type == "infer":
self.state = ClientState.inferencing
try:
task_data = json.loads(request.data)
except json.JSONDecodeError as e:
logger.error(f"Failed to decode inference request data: {e}")
self.state = ClientState.idle
continue

if "presigned_url" not in task_data:
logger.error("Inference request missing presigned_url.")
self.state = ClientState.idle
continue
presigned_url = task_data["presigned_url"]
_ = self._process_inference_request(request.model_id, request.session_id, presigned_url)
self.state = ClientState.idle
except queue.Empty:
pass
except grpc.RpcError as e:
1 change: 1 addition & 0 deletions fedn/network/clients/state.py
@@ -7,6 +7,7 @@ class ClientState(Enum):
idle = 1
training = 2
validating = 3
inferencing = 4


def ClientStateToString(state):
66 changes: 32 additions & 34 deletions fedn/network/combiner/combiner.py
@@ -169,12 +169,12 @@ def request_model_update(self, session_id, model_id, config, clients=[]):
:type clients: list

"""
request, clients = self._send_request_type(fedn.StatusType.MODEL_UPDATE, session_id, model_id, config, clients)
clients = self._send_request_type(fedn.StatusType.MODEL_UPDATE, session_id, model_id, config, clients)

if len(clients) < 20:
logger.info("Sent model update request for model {} to clients {}".format(request.model_id, clients))
logger.info("Sent model update request for model {} to clients {}".format(model_id, clients))
else:
logger.info("Sent model update request for model {} to {} clients".format(request.model_id, len(clients)))
logger.info("Sent model update request for model {} to {} clients".format(model_id, len(clients)))

def request_model_validation(self, session_id, model_id, clients=[]):
"""Ask clients to validate the current global model.
@@ -187,12 +187,12 @@ def request_model_validation(self, session_id, model_id, clients=[]):
:type clients: list

"""
request, clients = self._send_request_type(fedn.StatusType.MODEL_VALIDATION, session_id, model_id, clients)
clients = self._send_request_type(fedn.StatusType.MODEL_VALIDATION, session_id, model_id, clients=clients)

if len(clients) < 20:
logger.info("Sent model validation request for model {} to clients {}".format(request.model_id, clients))
logger.info("Sent model validation request for model {} to clients {}".format(model_id, clients))
else:
logger.info("Sent model validation request for model {} to {} clients".format(request.model_id, len(clients)))
logger.info("Sent model validation request for model {} to {} clients".format(model_id, len(clients)))

def request_model_inference(self, session_id: str, model_id: str, clients: list = []) -> None:
"""Ask clients to perform inference on the model.
@@ -205,12 +205,12 @@ def request_model_inference(self, session_id: str, model_id: str, clients: list = []) -> None:
:type clients: list

"""
request, clients = self._send_request_type(fedn.StatusType.INFERENCE, session_id, model_id, clients)
clients = self._send_request_type(fedn.StatusType.INFERENCE, session_id, model_id, clients=clients)

if len(clients) < 20:
logger.info("Sent model inference request for model {} to clients {}".format(request.model_id, clients))
logger.info("Sent model inference request for model {} to clients {}".format(model_id, clients))
else:
logger.info("Sent model inference request for model {} to {} clients".format(request.model_id, len(clients)))
logger.info("Sent model inference request for model {} to {} clients".format(model_id, len(clients)))

def _send_request_type(self, request_type, session_id, model_id, config=None, clients=[]):
"""Send a request of a specific type to clients.
@@ -223,40 +223,38 @@ def _send_request_type(self, request_type, session_id, model_id, config=None, clients=[]):
:type config: dict
:param clients: the clients to send the request to
:type clients: list
:return: the request and the clients
:rtype: tuple
:return: the clients
:rtype: list
"""
request = fedn.TaskRequest()
request.model_id = model_id
request.correlation_id = str(uuid.uuid4())
request.timestamp = str(datetime.now())
request.type = request_type
request.session_id = session_id

request.sender.name = self.id
request.sender.role = fedn.COMBINER

if request_type == fedn.StatusType.MODEL_UPDATE:
request.data = json.dumps(config)
if len(clients) == 0:
if len(clients) == 0:
if request_type == fedn.StatusType.MODEL_UPDATE:
clients = self.get_active_trainers()
elif request_type == fedn.StatusType.MODEL_VALIDATION:
if len(clients) == 0:
elif request_type == fedn.StatusType.MODEL_VALIDATION:
clients = self.get_active_validators()
elif request_type == fedn.StatusType.INFERENCE:
request.data = json.dumps(config)
if len(clients) == 0:
elif request_type == fedn.StatusType.INFERENCE:
# TODO: add inference clients type
clients = self.get_active_validators()

# TODO: if inference, request.data should be user-defined data/parameters

for client in clients:
request = fedn.TaskRequest()
request.model_id = model_id
request.correlation_id = str(uuid.uuid4())
request.timestamp = str(datetime.now())
request.type = request_type
request.session_id = session_id

request.sender.name = self.id
request.sender.role = fedn.COMBINER
request.receiver.name = client
request.receiver.role = fedn.WORKER
# Set the request data, not used in validation
if request_type == fedn.StatusType.INFERENCE:
presigned_url = self.repository.presigned_put_url(self.repository.inference_bucket, f"{client}/{session_id}")
# TODO: in inference, request.data should also contain user-defined data/parameters
request.data = json.dumps({"presigned_url": presigned_url})
elif request_type == fedn.StatusType.MODEL_UPDATE:
request.data = json.dumps(config)
self._put_request_to_client_queue(request, fedn.Queue.TASK_QUEUE)

return request, clients
return clients
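Design note: building one `TaskRequest` per client (rather than one shared request) is what allows each receiver to get its own presigned URL. For an INFERENCE task, `request.data` carries a small JSON payload; a sketch of what the client-side handler in `process_request` decodes, with a placeholder URL:

```python
import json

# Shape of request.data for an INFERENCE task, as built in _send_request_type.
# The URL is a placeholder for a MinIO presigned PUT URL.
request_data = json.dumps({"presigned_url": "https://minio.example.com/fedn-inference/client-1/session-42"})

payload = json.loads(request_data)
presigned_url = payload["presigned_url"]
```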

def get_active_trainers(self):
"""Get a list of active trainers.
34 changes: 34 additions & 0 deletions fedn/network/storage/s3/repository.py
@@ -1,3 +1,4 @@
import datetime
import uuid

from fedn.common.log_config import logger
@@ -10,12 +11,17 @@ class Repository:
def __init__(self, config):
self.model_bucket = config["storage_bucket"]
self.context_bucket = config["context_bucket"]
try:
self.inference_bucket = config["inference_bucket"]
except KeyError:
self.inference_bucket = "fedn-inference"

# TODO: Make a plug-in solution
self.client = MINIORepository(config)

self.client.create_bucket(self.context_bucket)
self.client.create_bucket(self.model_bucket)
self.client.create_bucket(self.inference_bucket)

def get_model(self, model_id):
"""Retrieve a model with id model_id.
@@ -104,3 +110,31 @@ def delete_compute_package(self, compute_package):
except Exception:
logger.error("Failed to delete compute_package from repository.")
raise

def presigned_put_url(self, bucket: str, object_name: str, expires: datetime.timedelta = datetime.timedelta(hours=1)):
"""Generate a presigned URL for an upload object request.

:param bucket: The bucket name
:type bucket: str
:param object_name: The object name
:type object_name: str
:param expires: The time the URL is valid
:type expires: datetime.timedelta
:return: The URL
:rtype: str
"""
return self.client.client.presigned_put_object(bucket, object_name, expires)

def presigned_get_url(self, bucket: str, object_name: str, expires: datetime.timedelta = datetime.timedelta(hours=1)) -> str:
"""Generate a presigned URL for a download object request.

:param bucket: The bucket name
:type bucket: str
:param object_name: The object name
:type object_name: str
:param expires: The time the URL is valid
:type expires: datetime.timedelta
:return: The URL
:rtype: str
"""
return self.client.client.presigned_get_object(bucket, object_name, expires)
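These helpers delegate to the MinIO Python SDK (`presigned_put_object` / `presigned_get_object`). A minimal end-to-end sketch of generating and exercising a presigned PUT URL, assuming a local MinIO instance with placeholder credentials, bucket, and object name:

```python
from datetime import timedelta

import requests
from minio import Minio

# Placeholder endpoint/credentials for a local MinIO instance.
client = Minio("localhost:9000", access_key="minio", secret_key="minio123", secure=False)

# One-hour expiry, mirroring the default in presigned_put_url above.
url = client.presigned_put_object("fedn-inference", "client-1/session-42", expires=timedelta(hours=1))

with open("predictions.pt", "rb") as fh:
    response = requests.put(url, data=fh)
print(response.status_code)  # 200 on success
```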