From f76558323ba1964a58e29c1ac5f4abd583518734 Mon Sep 17 00:00:00 2001 From: Song Jiaming Date: Wed, 24 Aug 2022 13:05:11 +0800 Subject: [PATCH] PPML add pytorch incremental training and tutorial (#5379) --- .../pytorch_nn_lr/pytorch-nn-lr-tutorial.md | 48 ++++- .../example/pytorch_nn_lr/pytorch_nn_lr_1.py | 49 +++-- .../example/pytorch_nn_lr/pytorch_nn_lr_2.py | 41 ++-- python/ppml/src/bigdl/ppml/fl/estimator.py | 10 +- python/ppml/src/bigdl/ppml/fl/nn/fl_server.py | 19 +- .../ppml/fl/nn/generated/nn_service_pb2.py | 46 +++- .../fl/nn/generated/nn_service_pb2_grpc.py | 66 ++++++ .../ppml/src/bigdl/ppml/fl/nn/nn_service.py | 20 +- .../bigdl/ppml/fl/nn/pytorch/aggregator.py | 38 +++- .../src/bigdl/ppml/fl/nn/pytorch/estimator.py | 21 +- .../bigdl/ppml/fl/nn/tensorflow/aggregator.py | 7 +- python/ppml/src/bigdl/ppml/fl/nn/utils.py | 3 + .../bigdl/ppml/fl/nn/pytorch/test_mnist.py | 2 +- .../ppml/fl/nn/pytorch/test_save_load.py | 197 ++++++++++++++++++ scala/ppml/src/main/proto/nn_service.proto | 19 ++ 15 files changed, 528 insertions(+), 58 deletions(-) create mode 100644 python/ppml/test/bigdl/ppml/fl/nn/pytorch/test_save_load.py diff --git a/python/ppml/example/pytorch_nn_lr/pytorch-nn-lr-tutorial.md b/python/ppml/example/pytorch_nn_lr/pytorch-nn-lr-tutorial.md index 86a6c844da0..3e0bdb92bfc 100644 --- a/python/ppml/example/pytorch_nn_lr/pytorch-nn-lr-tutorial.md +++ b/python/ppml/example/pytorch_nn_lr/pytorch-nn-lr-tutorial.md @@ -88,18 +88,39 @@ Then call `fit` method to train ```python response = ppl.fit(x, y) ``` -### 2.7 Predict + +### 2.6 Predict ```python result = ppl.predict(x) ``` +### 2.7 Save/Load +After training, save the client and server model by +```python +torch.save(ppl.model, model_path) +ppl.save_server_model(server_model_path) +``` +To start a new application to continue training +```python +client_model = torch.load(model_path) +# we do not pass server model this time, instead, we load it directly from server machine +ppl = Estimator.from_torch(client_model=model, + client_id=client_id, + loss_fn=loss_fn, + optimizer_cls=torch.optim.SGD, + optimizer_args={'lr':1e-3}, + target='localhost:8980') +ppl.load_server_model(server_model_path) + ## 3 Run FGBoost FL Server is required before running any federated applications. Check [Start FL Server]() section for details. ### 3.1 Start FL Server in SGX + #### 3.1.1 Start the container -Before running FL Server in SGX, please prepaer keys and start the BigDL PPML container first. Check [3.1 BigDL PPML Hello World](https://github.com/intel-analytics/BigDL/tree/main/ppml#31-bigdl-ppml-hello-world) for details. +Before running FL Server in SGX, please prepare keys and start the BigDL PPML container first. Check [3.1 BigDL PPML Hello World](https://github.com/intel-analytics/BigDL/tree/main/ppml#31-bigdl-ppml-hello-world) for details. #### 3.1.2 Run FL Server in SGX You can run FL Server in SGX with the following command: + ```bash bash start-python-fl-server-sgx.sh -p 8980 -c 2 ``` @@ -129,3 +150,26 @@ The first 5 predict results are printed [1.2120417e-23] [0.0000000e+00]] ``` +### 3.4 Incremental Training +Incremental training is supported, we just need to use the same configurations and start FL Server again. + +In SGX container, start FL Server +``` +./ppml/scripts/start-fl-server.sh +``` +For client applications, we change from creating model to directly loading. This is already implemented in example code, we just need to run client applications with an argument + +```bash +# run following commands in 2 different terminals +python pytorch_nn_lr_1.py true +python pytorch_nn_lr_2.py true +``` +The result based on new boosted trees are printed +``` +[[1.8799074e-36] + [1.7512805e-25] + [4.6501680e-30] + [1.4828590e-27] + [0.0000000e+00]] +``` +and you can see the loss continues to drop from the log of [Section 3.3](#33-get-results) \ No newline at end of file diff --git a/python/ppml/example/pytorch_nn_lr/pytorch_nn_lr_1.py b/python/ppml/example/pytorch_nn_lr/pytorch_nn_lr_1.py index 98b7a4aa446..e760f42d7ed 100644 --- a/python/ppml/example/pytorch_nn_lr/pytorch_nn_lr_1.py +++ b/python/ppml/example/pytorch_nn_lr/pytorch_nn_lr_1.py @@ -17,13 +17,11 @@ from typing import List import numpy as np import pandas as pd +import click import torch from torch import Tensor, nn from bigdl.ppml.fl.estimator import Estimator -from bigdl.ppml.fl.algorithms.psi import PSI -from bigdl.ppml.fl.nn.fl_server import FLServer -from bigdl.ppml.fl.nn.pytorch.utils import set_one_like_parameter class LocalModel(nn.Module): @@ -48,7 +46,10 @@ def forward(self, x: List[Tensor]): return x -if __name__ == '__main__': + +@click.command() +@click.option('--load_model', default=False) +def run_client(load_model): # fl_server = FLServer(2) # fl_server.build() # fl_server.start() @@ -66,16 +67,36 @@ def forward(self, x: List[Tensor]): x = df_x.to_numpy(dtype="float32") y = np.expand_dims(df_y.to_numpy(dtype="float32"), axis=1) - model = LocalModel(len(df_x.columns)) loss_fn = nn.BCELoss() - server_model = ServerModel() - ppl = Estimator.from_torch(client_model=model, - client_id='1', - loss_fn=loss_fn, - optimizer_cls=torch.optim.SGD, - optimizer_args={'lr':1e-5}, - target='localhost:8980', - server_model=server_model) - response = ppl.fit(x, y) + + if load_model: + model = torch.load('/tmp/pytorch_client_model_1.pt') + ppl = Estimator.from_torch(client_model=model, + client_id='1', + loss_fn=loss_fn, + optimizer_cls=torch.optim.SGD, + optimizer_args={'lr':1e-5}, + target='localhost:8980', + server_model_path='/tmp/pytorch_server_model', + client_model_path='/tmp/pytorch_client_model_1.pt') + ppl.load_server_model('/tmp/pytorch_server_model') + response = ppl.fit(x, y, 5) + else: + model = LocalModel(len(df_x.columns)) + + server_model = ServerModel() + ppl = Estimator.from_torch(client_model=model, + client_id='1', + loss_fn=loss_fn, + optimizer_cls=torch.optim.SGD, + optimizer_args={'lr':1e-5}, + target='localhost:8980', + server_model=server_model, + server_model_path='/tmp/pytorch_server_model', + client_model_path='/tmp/pytorch_client_model_1.pt') + response = ppl.fit(x, y, 5) result = ppl.predict(x) print(result[:5]) + +if __name__ == '__main__': + run_client() \ No newline at end of file diff --git a/python/ppml/example/pytorch_nn_lr/pytorch_nn_lr_2.py b/python/ppml/example/pytorch_nn_lr/pytorch_nn_lr_2.py index f8b4004a3ab..29a50d36e3b 100644 --- a/python/ppml/example/pytorch_nn_lr/pytorch_nn_lr_2.py +++ b/python/ppml/example/pytorch_nn_lr/pytorch_nn_lr_2.py @@ -16,12 +16,11 @@ import numpy as np import pandas as pd +import click import torch from torch import nn from bigdl.ppml.fl.estimator import Estimator -from bigdl.ppml.fl.algorithms.psi import PSI -from bigdl.ppml.fl.nn.pytorch.utils import set_one_like_parameter class LocalModel(nn.Module): @@ -34,7 +33,9 @@ def forward(self, x): return x -if __name__ == '__main__': +@click.command() +@click.option('--load_model', default=False) +def run_client(load_model): df_train = pd.read_csv('.data/diabetes-vfl-2.csv') # this should wait for the merge of 2 FLServer (Py4J Java gRPC and Python gRPC) @@ -45,16 +46,32 @@ def forward(self, x): df_x = df_train x = df_x.to_numpy(dtype="float32") - y = None + y = None - model = LocalModel(len(df_x.columns)) loss_fn = nn.BCELoss() - ppl = Estimator.from_torch(client_model=model, - client_id='2', - loss_fn=loss_fn, - optimizer_cls=torch.optim.SGD, - optimizer_args={'lr':1e-5}, - target='localhost:8980') - response = ppl.fit(x, y) + + if load_model: + model = torch.load('/tmp/pytorch_client_model_2.pt') + ppl = Estimator.from_torch(client_model=model, + client_id='2', + loss_fn=loss_fn, + optimizer_cls=torch.optim.SGD, + optimizer_args={'lr':1e-5}, + target='localhost:8980', + client_model_path='/tmp/pytorch_client_model_2.pt') + response = ppl.fit(x, y, 5) + else: + model = LocalModel(len(df_x.columns)) + ppl = Estimator.from_torch(client_model=model, + client_id='2', + loss_fn=loss_fn, + optimizer_cls=torch.optim.SGD, + optimizer_args={'lr':1e-5}, + target='localhost:8980', + client_model_path='/tmp/pytorch_client_model_2.pt') + response = ppl.fit(x, y, 5) result = ppl.predict(x) print(result[:5]) + +if __name__ == '__main__': + run_client() \ No newline at end of file diff --git a/python/ppml/src/bigdl/ppml/fl/estimator.py b/python/ppml/src/bigdl/ppml/fl/estimator.py index 6d2d928b431..23d420f46aa 100644 --- a/python/ppml/src/bigdl/ppml/fl/estimator.py +++ b/python/ppml/src/bigdl/ppml/fl/estimator.py @@ -35,14 +35,18 @@ def from_torch(client_model: nn.Module, optimizer_cls, optimizer_args={}, target="localhost:8980", - server_model=None): + server_model=None, + client_model_path=None, + server_model_path=None): estimator = PytorchEstimator(model=client_model, - loss_fn=loss_fn, + loss_fn=loss_fn, optimizer_cls=optimizer_cls, optimizer_args=optimizer_args, client_id=client_id, target=target, - server_model=server_model) + server_model=server_model, + client_model_path=client_model_path, + server_model_path=server_model_path) return estimator @staticmethod diff --git a/python/ppml/src/bigdl/ppml/fl/nn/fl_server.py b/python/ppml/src/bigdl/ppml/fl/nn/fl_server.py index a62a2a7962e..0dee411e31b 100644 --- a/python/ppml/src/bigdl/ppml/fl/nn/fl_server.py +++ b/python/ppml/src/bigdl/ppml/fl/nn/fl_server.py @@ -24,19 +24,22 @@ class FLServer(object): - def __init__(self, client_num=1): + def __init__(self, client_num=None): self.server = grpc.server(futures.ThreadPoolExecutor(max_workers=5)) self.port = 8980 self.client_num = client_num self.secure = False self.load_config() + # a chance to overwrite client num + if client_num is not None: + self.conf['clientNum'] = client_num def set_port(self, port): self.port = port def build(self): add_NNServiceServicer_to_server( - NNServiceImpl(client_num=self.client_num), + NNServiceImpl(conf=self.conf), self.server) if self.secure: self.server.add_secure_port(f'[::]:{self.port}', self.server_credentials) @@ -65,11 +68,17 @@ def load_config(self): ( (private_key, certificate_chain), ) ) if 'serverPort' in conf: self.port = conf['serverPort'] + self.generate_conf(conf) - except yaml.YAMLError as e: - logging.warn('Loading config failed, using default config ') except Exception as e: - logging.warn('Failed to find config file "ppml-conf.yaml", using default config') + logging.warn('Failed to load config file "ppml-conf.yaml", using default config') + self.generate_conf({}) + + def generate_conf(self, conf: dict): + self.conf = conf + # set default parameters if not specified in config + if 'clientNum' not in conf.keys(): + self.conf['clientNum'] = 1 def wait_for_termination(self): self.server.wait_for_termination() diff --git a/python/ppml/src/bigdl/ppml/fl/nn/generated/nn_service_pb2.py b/python/ppml/src/bigdl/ppml/fl/nn/generated/nn_service_pb2.py index f389aabd5ec..b4ea28a680d 100644 --- a/python/ppml/src/bigdl/ppml/fl/nn/generated/nn_service_pb2.py +++ b/python/ppml/src/bigdl/ppml/fl/nn/generated/nn_service_pb2.py @@ -15,7 +15,7 @@ import fl_base_pb2 as fl__base__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x10nn_service.proto\x12\x02nn\x1a\rfl_base.proto\"O\n\x0cTrainRequest\x12\x12\n\nclientuuid\x18\x01 \x01(\t\x12\x18\n\x04\x64\x61ta\x18\x02 \x01(\x0b\x32\n.TensorMap\x12\x11\n\talgorithm\x18\x03 \x01(\t\"I\n\rTrainResponse\x12\x10\n\x08response\x18\x01 \x01(\t\x12\x18\n\x04\x64\x61ta\x18\x02 \x01(\x0b\x32\n.TensorMap\x12\x0c\n\x04\x63ode\x18\x03 \x01(\x05\"b\n\x0f\x45valuateRequest\x12\x12\n\nclientuuid\x18\x01 \x01(\t\x12\x18\n\x04\x64\x61ta\x18\x02 \x01(\x0b\x32\n.TensorMap\x12\x11\n\talgorithm\x18\x03 \x01(\t\x12\x0e\n\x06return\x18\x04 \x01(\x08\"]\n\x10\x45valuateResponse\x12\x10\n\x08response\x18\x01 \x01(\t\x12\x18\n\x04\x64\x61ta\x18\x02 \x01(\x0b\x32\n.TensorMap\x12\x0c\n\x04\x63ode\x18\x03 \x01(\x05\x12\x0f\n\x07message\x18\x04 \x01(\t\"Q\n\x0ePredictRequest\x12\x12\n\nclientuuid\x18\x01 \x01(\t\x12\x18\n\x04\x64\x61ta\x18\x02 \x01(\x0b\x32\n.TensorMap\x12\x11\n\talgorithm\x18\x03 \x01(\t\"K\n\x0fPredictResponse\x12\x10\n\x08response\x18\x01 \x01(\t\x12\x18\n\x04\x64\x61ta\x18\x02 \x01(\x0b\x32\n.TensorMap\x12\x0c\n\x04\x63ode\x18\x03 \x01(\x05\"r\n\x11UploadMetaRequest\x12\x13\n\x0b\x63lient_uuid\x18\x01 \x01(\t\x12\x0f\n\x07loss_fn\x18\x02 \x01(\x0c\x12#\n\toptimizer\x18\x03 \x01(\x0b\x32\x10.nn.ClassAndArgs\x12\x12\n\naggregator\x18\x04 \x01(\t\"\x1b\n\tByteChunk\x12\x0e\n\x06\x62uffer\x18\x01 \x01(\x0c\")\n\x0c\x43lassAndArgs\x12\x0b\n\x03\x63ls\x18\x01 \x01(\x0c\x12\x0c\n\x04\x61rgs\x18\x02 \x01(\x0c\"%\n\x12UploadMetaResponse\x12\x0f\n\x07message\x18\x01 \x01(\t2\xa4\x02\n\tNNService\x12.\n\x05train\x12\x10.nn.TrainRequest\x1a\x11.nn.TrainResponse\"\x00\x12\x37\n\x08\x65valuate\x12\x13.nn.EvaluateRequest\x1a\x14.nn.EvaluateResponse\"\x00\x12\x34\n\x07predict\x12\x12.nn.PredictRequest\x1a\x13.nn.PredictResponse\"\x00\x12>\n\x0bupload_meta\x12\x15.nn.UploadMetaRequest\x1a\x16.nn.UploadMetaResponse\"\x00\x12\x38\n\x0bupload_file\x12\r.nn.ByteChunk\x1a\x16.nn.UploadMetaResponse\"\x00(\x01\x42=\n+com.intel.analytics.bigdl.ppml.fl.generatedB\x0eNNServiceProtob\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x10nn_service.proto\x12\x02nn\x1a\rfl_base.proto\"O\n\x0cTrainRequest\x12\x12\n\nclientuuid\x18\x01 \x01(\t\x12\x18\n\x04\x64\x61ta\x18\x02 \x01(\x0b\x32\n.TensorMap\x12\x11\n\talgorithm\x18\x03 \x01(\t\"I\n\rTrainResponse\x12\x10\n\x08response\x18\x01 \x01(\t\x12\x18\n\x04\x64\x61ta\x18\x02 \x01(\x0b\x32\n.TensorMap\x12\x0c\n\x04\x63ode\x18\x03 \x01(\x05\"b\n\x0f\x45valuateRequest\x12\x12\n\nclientuuid\x18\x01 \x01(\t\x12\x18\n\x04\x64\x61ta\x18\x02 \x01(\x0b\x32\n.TensorMap\x12\x11\n\talgorithm\x18\x03 \x01(\t\x12\x0e\n\x06return\x18\x04 \x01(\x08\"]\n\x10\x45valuateResponse\x12\x10\n\x08response\x18\x01 \x01(\t\x12\x18\n\x04\x64\x61ta\x18\x02 \x01(\x0b\x32\n.TensorMap\x12\x0c\n\x04\x63ode\x18\x03 \x01(\x05\x12\x0f\n\x07message\x18\x04 \x01(\t\"Q\n\x0ePredictRequest\x12\x12\n\nclientuuid\x18\x01 \x01(\t\x12\x18\n\x04\x64\x61ta\x18\x02 \x01(\x0b\x32\n.TensorMap\x12\x11\n\talgorithm\x18\x03 \x01(\t\"K\n\x0fPredictResponse\x12\x10\n\x08response\x18\x01 \x01(\t\x12\x18\n\x04\x64\x61ta\x18\x02 \x01(\x0b\x32\n.TensorMap\x12\x0c\n\x04\x63ode\x18\x03 \x01(\x05\"r\n\x11UploadMetaRequest\x12\x13\n\x0b\x63lient_uuid\x18\x01 \x01(\t\x12\x0f\n\x07loss_fn\x18\x02 \x01(\x0c\x12#\n\toptimizer\x18\x03 \x01(\x0b\x32\x10.nn.ClassAndArgs\x12\x12\n\naggregator\x18\x04 \x01(\t\"\x1b\n\tByteChunk\x12\x0e\n\x06\x62uffer\x18\x01 \x01(\x0c\")\n\x0c\x43lassAndArgs\x12\x0b\n\x03\x63ls\x18\x01 \x01(\x0c\x12\x0c\n\x04\x61rgs\x18\x02 \x01(\x0c\"%\n\x12UploadMetaResponse\x12\x0f\n\x07message\x18\x01 \x01(\t\"J\n\x10LoadModelRequest\x12\x11\n\tclient_id\x18\x01 \x01(\t\x12\x0f\n\x07\x62\x61\x63kend\x18\x02 \x01(\t\x12\x12\n\nmodel_path\x18\x03 \x01(\t\"$\n\x11LoadModelResponse\x12\x0f\n\x07message\x18\x01 \x01(\t\"J\n\x10SaveModelRequest\x12\x11\n\tclient_id\x18\x01 \x01(\t\x12\x0f\n\x07\x62\x61\x63kend\x18\x02 \x01(\t\x12\x12\n\nmodel_path\x18\x03 \x01(\t\"$\n\x11SaveModelResponse\x12\x0f\n\x07message\x18\x01 \x01(\t2\xac\x03\n\tNNService\x12.\n\x05train\x12\x10.nn.TrainRequest\x1a\x11.nn.TrainResponse\"\x00\x12\x37\n\x08\x65valuate\x12\x13.nn.EvaluateRequest\x1a\x14.nn.EvaluateResponse\"\x00\x12\x34\n\x07predict\x12\x12.nn.PredictRequest\x1a\x13.nn.PredictResponse\"\x00\x12>\n\x0bupload_meta\x12\x15.nn.UploadMetaRequest\x1a\x16.nn.UploadMetaResponse\"\x00\x12\x38\n\x0bupload_file\x12\r.nn.ByteChunk\x1a\x16.nn.UploadMetaResponse\"\x00(\x01\x12\x42\n\x11save_server_model\x12\x14.nn.SaveModelRequest\x1a\x15.nn.SaveModelResponse\"\x00\x12\x42\n\x11load_server_model\x12\x14.nn.LoadModelRequest\x1a\x15.nn.LoadModelResponse\"\x00\x42=\n+com.intel.analytics.bigdl.ppml.fl.generatedB\x0eNNServiceProtob\x06proto3') @@ -29,6 +29,10 @@ _BYTECHUNK = DESCRIPTOR.message_types_by_name['ByteChunk'] _CLASSANDARGS = DESCRIPTOR.message_types_by_name['ClassAndArgs'] _UPLOADMETARESPONSE = DESCRIPTOR.message_types_by_name['UploadMetaResponse'] +_LOADMODELREQUEST = DESCRIPTOR.message_types_by_name['LoadModelRequest'] +_LOADMODELRESPONSE = DESCRIPTOR.message_types_by_name['LoadModelResponse'] +_SAVEMODELREQUEST = DESCRIPTOR.message_types_by_name['SaveModelRequest'] +_SAVEMODELRESPONSE = DESCRIPTOR.message_types_by_name['SaveModelResponse'] TrainRequest = _reflection.GeneratedProtocolMessageType('TrainRequest', (_message.Message,), { 'DESCRIPTOR' : _TRAINREQUEST, '__module__' : 'nn_service_pb2' @@ -99,6 +103,34 @@ }) _sym_db.RegisterMessage(UploadMetaResponse) +LoadModelRequest = _reflection.GeneratedProtocolMessageType('LoadModelRequest', (_message.Message,), { + 'DESCRIPTOR' : _LOADMODELREQUEST, + '__module__' : 'nn_service_pb2' + # @@protoc_insertion_point(class_scope:nn.LoadModelRequest) + }) +_sym_db.RegisterMessage(LoadModelRequest) + +LoadModelResponse = _reflection.GeneratedProtocolMessageType('LoadModelResponse', (_message.Message,), { + 'DESCRIPTOR' : _LOADMODELRESPONSE, + '__module__' : 'nn_service_pb2' + # @@protoc_insertion_point(class_scope:nn.LoadModelResponse) + }) +_sym_db.RegisterMessage(LoadModelResponse) + +SaveModelRequest = _reflection.GeneratedProtocolMessageType('SaveModelRequest', (_message.Message,), { + 'DESCRIPTOR' : _SAVEMODELREQUEST, + '__module__' : 'nn_service_pb2' + # @@protoc_insertion_point(class_scope:nn.SaveModelRequest) + }) +_sym_db.RegisterMessage(SaveModelRequest) + +SaveModelResponse = _reflection.GeneratedProtocolMessageType('SaveModelResponse', (_message.Message,), { + 'DESCRIPTOR' : _SAVEMODELRESPONSE, + '__module__' : 'nn_service_pb2' + # @@protoc_insertion_point(class_scope:nn.SaveModelResponse) + }) +_sym_db.RegisterMessage(SaveModelResponse) + _NNSERVICE = DESCRIPTOR.services_by_name['NNService'] if _descriptor._USE_C_DESCRIPTORS == False: @@ -124,6 +156,14 @@ _CLASSANDARGS._serialized_end=736 _UPLOADMETARESPONSE._serialized_start=738 _UPLOADMETARESPONSE._serialized_end=775 - _NNSERVICE._serialized_start=778 - _NNSERVICE._serialized_end=1070 + _LOADMODELREQUEST._serialized_start=777 + _LOADMODELREQUEST._serialized_end=851 + _LOADMODELRESPONSE._serialized_start=853 + _LOADMODELRESPONSE._serialized_end=889 + _SAVEMODELREQUEST._serialized_start=891 + _SAVEMODELREQUEST._serialized_end=965 + _SAVEMODELRESPONSE._serialized_start=967 + _SAVEMODELRESPONSE._serialized_end=1003 + _NNSERVICE._serialized_start=1006 + _NNSERVICE._serialized_end=1434 # @@protoc_insertion_point(module_scope) diff --git a/python/ppml/src/bigdl/ppml/fl/nn/generated/nn_service_pb2_grpc.py b/python/ppml/src/bigdl/ppml/fl/nn/generated/nn_service_pb2_grpc.py index 80fec8b4936..6e789b1154c 100644 --- a/python/ppml/src/bigdl/ppml/fl/nn/generated/nn_service_pb2_grpc.py +++ b/python/ppml/src/bigdl/ppml/fl/nn/generated/nn_service_pb2_grpc.py @@ -39,6 +39,16 @@ def __init__(self, channel): request_serializer=nn__service__pb2.ByteChunk.SerializeToString, response_deserializer=nn__service__pb2.UploadMetaResponse.FromString, ) + self.save_server_model = channel.unary_unary( + '/nn.NNService/save_server_model', + request_serializer=nn__service__pb2.SaveModelRequest.SerializeToString, + response_deserializer=nn__service__pb2.SaveModelResponse.FromString, + ) + self.load_server_model = channel.unary_unary( + '/nn.NNService/load_server_model', + request_serializer=nn__service__pb2.LoadModelRequest.SerializeToString, + response_deserializer=nn__service__pb2.LoadModelResponse.FromString, + ) class NNServiceServicer(object): @@ -74,6 +84,18 @@ def upload_file(self, request_iterator, context): context.set_details('Method not implemented!') raise NotImplementedError('Method not implemented!') + def save_server_model(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def load_server_model(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + def add_NNServiceServicer_to_server(servicer, server): rpc_method_handlers = { @@ -102,6 +124,16 @@ def add_NNServiceServicer_to_server(servicer, server): request_deserializer=nn__service__pb2.ByteChunk.FromString, response_serializer=nn__service__pb2.UploadMetaResponse.SerializeToString, ), + 'save_server_model': grpc.unary_unary_rpc_method_handler( + servicer.save_server_model, + request_deserializer=nn__service__pb2.SaveModelRequest.FromString, + response_serializer=nn__service__pb2.SaveModelResponse.SerializeToString, + ), + 'load_server_model': grpc.unary_unary_rpc_method_handler( + servicer.load_server_model, + request_deserializer=nn__service__pb2.LoadModelRequest.FromString, + response_serializer=nn__service__pb2.LoadModelResponse.SerializeToString, + ), } generic_handler = grpc.method_handlers_generic_handler( 'nn.NNService', rpc_method_handlers) @@ -196,3 +228,37 @@ def upload_file(request_iterator, nn__service__pb2.UploadMetaResponse.FromString, options, channel_credentials, insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def save_server_model(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/nn.NNService/save_server_model', + nn__service__pb2.SaveModelRequest.SerializeToString, + nn__service__pb2.SaveModelResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def load_server_model(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/nn.NNService/load_server_model', + nn__service__pb2.LoadModelRequest.SerializeToString, + nn__service__pb2.LoadModelResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) diff --git a/python/ppml/src/bigdl/ppml/fl/nn/nn_service.py b/python/ppml/src/bigdl/ppml/fl/nn/nn_service.py index 79a1ee0f963..50fe8183598 100644 --- a/python/ppml/src/bigdl/ppml/fl/nn/nn_service.py +++ b/python/ppml/src/bigdl/ppml/fl/nn/nn_service.py @@ -22,7 +22,7 @@ import bigdl.ppml.fl.nn.tensorflow.aggregator as tf_agg from bigdl.ppml.fl.nn.generated.fl_base_pb2 import TensorMap -from bigdl.ppml.fl.nn.generated.nn_service_pb2 import TrainRequest, TrainResponse, PredictResponse, UploadMetaResponse +from bigdl.ppml.fl.nn.generated.nn_service_pb2 import LoadModelResponse, SaveModelResponse, TrainRequest, TrainResponse, PredictResponse, UploadMetaResponse from bigdl.ppml.fl.nn.generated.nn_service_pb2_grpc import * from bigdl.ppml.fl.nn.utils import tensor_map_to_ndarray_map import tensorflow as tf @@ -35,11 +35,11 @@ class NNServiceImpl(NNServiceServicer): - def __init__(self, client_num, **kargs) -> None: - self.client_num = client_num + def __init__(self, conf, **kargs) -> None: + self.client_num = conf['clientNum'] self.aggregator_map = { - 'tf': tf_agg.Aggregator(client_num, **kargs), - 'pt': pt_agg.Aggregator(client_num, **kargs)} + 'tf': tf_agg.Aggregator(conf, **kargs), + 'pt': pt_agg.Aggregator(conf, **kargs)} self.model_dir = tempfile.mkdtemp() # store tmp file dir self.model_path = os.path.join(self.model_dir, "vfl_server_model") @@ -131,3 +131,13 @@ def validate_client_id(self, client_id): if client_id <= 0 or client_id > self.client_num: invalidInputError(False, f"invalid client ID received: {client_id}, \ must be in range of client number [1, {self.client_num}]") + + def save_server_model(self, request, context): + aggregator = self.aggregator_map[request.backend] + aggregator.save_server_model(request.model_path) + return SaveModelResponse(message=f"Server model saved to {request.model_path}") + + def load_server_model(self, request, context): + aggregator = self.aggregator_map[request.backend] + aggregator.load_server_model(request.model_path) + return LoadModelResponse(message=f"Server model loaded from {request.model_path}") \ No newline at end of file diff --git a/python/ppml/src/bigdl/ppml/fl/nn/pytorch/aggregator.py b/python/ppml/src/bigdl/ppml/fl/nn/pytorch/aggregator.py index 4880b5d662f..9a47bb794df 100644 --- a/python/ppml/src/bigdl/ppml/fl/nn/pytorch/aggregator.py +++ b/python/ppml/src/bigdl/ppml/fl/nn/pytorch/aggregator.py @@ -22,28 +22,29 @@ from bigdl.ppml.fl.nn.utils import ndarray_map_to_tensor_map from bigdl.dllib.utils.log4Error import invalidInputError from threading import Condition +import os class Aggregator(object): - def __init__(self, - client_num=1) -> None: + def __init__(self, conf) -> None: self.model = None self.client_data = {'train':{}, 'eval':{}, 'pred':{}} self.server_data = {'train':{}, 'eval':{}, 'pred':{}} - self.client_num = client_num + self.client_num = conf['clientNum'] self.condition = Condition() self._lock = threading.Lock() - logging.info(f"Initialized Pytorch aggregator [client_num: {client_num}]") - + self.optimizer_cls = None + self.optimizer_args = None + logging.info(f"Initialized Pytorch aggregator [client_num: {self.client_num}]") def set_meta(self, loss_fn, optimizer): with self._lock: self.set_loss_fn(loss_fn) optimizer_cls = pickle.loads(optimizer.cls) optimizer_args = pickle.loads(optimizer.args) + self.optimizer_cls, self.optimizer_args = optimizer_cls, optimizer_args self.set_optimizer(optimizer_cls, optimizer_args) - def set_loss_fn(self, loss_fn): self.loss_fn = loss_fn @@ -72,6 +73,7 @@ def put_client_data(self, client_id, data, phase): def aggregate(self, phase): + input, target = [], None # to record the order of tensors with client ID for cid, ndarray_map in self.client_data[phase].items(): @@ -112,7 +114,8 @@ def sort_by_key(kv_tuple): for cid, input_tensor in input: grad_map = {'grad': input_tensor.grad.numpy(), 'loss': loss.detach().numpy()} - self.server_data['train'][cid] = ndarray_map_to_tensor_map(grad_map) + self.server_data['train'][cid] = ndarray_map_to_tensor_map(grad_map) + elif phase == 'eval': pass elif phase == 'pred': @@ -123,3 +126,24 @@ def sort_by_key(kv_tuple): else: invalidInputError(False, f'Invalid phase: {phase}, should be train/eval/pred') + + def save_server_model(self, model_path): + if not os.path.exists(f"{model_path}/model.meta"): + os.makedirs(f"{model_path}", exist_ok=True) + with open(f"{model_path}/model.meta", 'wb') as meta_file: + pickle.dump({'loss': self.loss_fn, + 'optimizer': (self.optimizer_cls, self.optimizer_args)}, + meta_file) + m = torch.jit.script(self.model) + torch.jit.save(m, f"{model_path}/model.pt") + # save meta to file if not saved yet + + + def load_server_model(self, model_path): + logging.info(f"Trying to load model from {model_path}") + self.model = torch.jit.load(f"{model_path}/model.pt") + # if loaded, set meta here to make the optimizer bind the model + with open(f"{model_path}/model.meta", "rb") as meta_file: + meta = pickle.load(meta_file) + self.loss_fn = meta['loss'] + self.set_optimizer(meta['optimizer'][0], meta['optimizer'][1]) diff --git a/python/ppml/src/bigdl/ppml/fl/nn/pytorch/estimator.py b/python/ppml/src/bigdl/ppml/fl/nn/pytorch/estimator.py index dc8ea50fd71..ae578b032a6 100644 --- a/python/ppml/src/bigdl/ppml/fl/nn/pytorch/estimator.py +++ b/python/ppml/src/bigdl/ppml/fl/nn/pytorch/estimator.py @@ -24,7 +24,8 @@ from bigdl.ppml.fl.nn.utils import file_chunk_generate, tensor_map_to_ndarray_map import os import tempfile -from torch.testing._internal.jit_utils import clear_class_registry + +from nn_service_pb2 import LoadModelRequest, SaveModelRequest @@ -38,18 +39,30 @@ def __init__(self, bigdl_type="float", target="localhost:8980", fl_client=None, - server_model=None): + server_model=None, + client_model_path=None, + server_model_path=None): self.bigdl_type = bigdl_type self.model = model self.loss_fn = loss_fn self.optimizer = optimizer_cls(model.parameters(), **optimizer_args) self.version = 0 + self.client_model_path = client_model_path + self.server_model_path = server_model_path self.fl_client = fl_client if fl_client is not None \ else FLClient(client_id=client_id, aggregator='pt', target=target) self.loss_history = [] if server_model is not None: self.__add_server_model(server_model, loss_fn, optimizer_cls, optimizer_args) + def save_server_model(self, model_path): + self.fl_client.nn_stub.save_server_model( + SaveModelRequest(model_path=model_path, backend='pt')) + + def load_server_model(self, model_path): + self.fl_client.nn_stub.load_server_model( + LoadModelRequest(model_path=model_path, backend='pt')) + @staticmethod def load_model_as_bytes(model): model_path = os.path.join(tempfile.mkdtemp(), "vfl_server_model") @@ -129,6 +142,10 @@ def fit(self, x, y=None, epoch=1, batch_size=4): else: invalidInputError(False, f'got unsupported data input type: {type(x)}') + if self.server_model_path is not None: + self.save_server_model(self.server_model_path) + if self.client_model_path is not None: + torch.save(self.model, self.client_model_path) def predict(self, x): diff --git a/python/ppml/src/bigdl/ppml/fl/nn/tensorflow/aggregator.py b/python/ppml/src/bigdl/ppml/fl/nn/tensorflow/aggregator.py index 526c8c521bc..ac56c9d6efe 100644 --- a/python/ppml/src/bigdl/ppml/fl/nn/tensorflow/aggregator.py +++ b/python/ppml/src/bigdl/ppml/fl/nn/tensorflow/aggregator.py @@ -26,15 +26,14 @@ # TODO: tf and pytorch aggregator could be integrated to one using inherit class Aggregator(object): - def __init__(self, - client_num=1) -> None: + def __init__(self, conf) -> None: self.model = None self.client_data = {'train':{}, 'eval':{}, 'pred':{}} self.server_data = {'train':{}, 'eval':{}, 'pred':{}} - self.client_num = client_num + self.client_num = conf['clientNum'] self.condition = Condition() self._lock = threading.Lock() - logging.info(f"Initialized Tensorflow aggregator [client_num: {client_num}]") + logging.info(f"Initialized Tensorflow aggregator [client_num: {self.client_num}]") def set_meta(self, loss_fn, optimizer): diff --git a/python/ppml/src/bigdl/ppml/fl/nn/utils.py b/python/ppml/src/bigdl/ppml/fl/nn/utils.py index 8ca99429628..498dd43a118 100644 --- a/python/ppml/src/bigdl/ppml/fl/nn/utils.py +++ b/python/ppml/src/bigdl/ppml/fl/nn/utils.py @@ -16,6 +16,7 @@ import logging import pickle +import stat import numpy as np from bigdl.ppml.fl.nn.generated.nn_service_pb2 import * @@ -32,6 +33,8 @@ def to_protobuf(self): args = pickle.dumps(self.args) return ClassAndArgs(cls=cls, args=args) + + import numpy as np from bigdl.dllib.utils.log4Error import invalidInputError from bigdl.ppml.fl.nn.generated.fl_base_pb2 import FloatTensor, TensorMap diff --git a/python/ppml/test/bigdl/ppml/fl/nn/pytorch/test_mnist.py b/python/ppml/test/bigdl/ppml/fl/nn/pytorch/test_mnist.py index 1e0eef11d70..56f700af157 100644 --- a/python/ppml/test/bigdl/ppml/fl/nn/pytorch/test_mnist.py +++ b/python/ppml/test/bigdl/ppml/fl/nn/pytorch/test_mnist.py @@ -38,7 +38,7 @@ resource_path = os.path.join(os.path.dirname(__file__), "../../resources") -class TestCorrectness(FLTest): +class TestMnist(FLTest): fmt = '%(asctime)s %(levelname)s {%(module)s:%(lineno)d} - %(message)s' logging.basicConfig(format=fmt, level=logging.INFO) def setUp(self) -> None: diff --git a/python/ppml/test/bigdl/ppml/fl/nn/pytorch/test_save_load.py b/python/ppml/test/bigdl/ppml/fl/nn/pytorch/test_save_load.py new file mode 100644 index 00000000000..5411ce547a2 --- /dev/null +++ b/python/ppml/test/bigdl/ppml/fl/nn/pytorch/test_save_load.py @@ -0,0 +1,197 @@ +# +# Copyright 2016 The BigDL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from multiprocessing import Process +import time +from typing import List +import unittest +import numpy as np +import pandas as pd +import os + +from bigdl.ppml.fl import * +from bigdl.ppml.fl.nn.fl_server import FLServer +from bigdl.ppml.fl.nn.fl_client import FLClient +from bigdl.ppml.fl.nn.pytorch.utils import set_one_like_parameter +from bigdl.ppml.fl.utils import init_fl_context +from bigdl.ppml.fl.estimator import Estimator + +from torch import Tensor, nn +import torch +from torch.utils.data import DataLoader +from torchvision import datasets +from torchvision.transforms import ToTensor +from bigdl.ppml.fl.utils import FLTest +import shutil + +resource_path = os.path.join(os.path.dirname(__file__), "../../resources") + + +class TestSaveLoad(FLTest): + fmt = '%(asctime)s %(levelname)s {%(module)s:%(lineno)d} - %(message)s' + logging.basicConfig(format=fmt, level=logging.INFO) + server_model_path = '/tmp/vfl_server_model' + client_model_path = '/tmp/vfl_client_model' + def setUp(self) -> None: + self.fl_server = FLServer(client_num=1) + self.fl_server.set_port(self.port) + self.fl_server.build() + self.fl_server.start() + + def tearDown(self) -> None: + self.fl_server.stop() + if os.path.exists(TestSaveLoad.server_model_path): + shutil.rmtree(TestSaveLoad.server_model_path) + if os.path.exists(TestSaveLoad.client_model_path): + os.remove(TestSaveLoad.client_model_path) + + def test_mnist(self) -> None: + """ + following code is copied from pytorch quick start + link: https://pytorch.org/tutorials/beginner/basics/quickstart_tutorial.html + """ + + training_data = datasets.FashionMNIST( + root="data", + train=True, + download=True, + transform=ToTensor(), + ) + + # Download test data from open datasets. + test_data = datasets.FashionMNIST( + root="data", + train=False, + download=True, + transform=ToTensor(), + ) + batch_size = 64 + + # Create data loaders. + train_dataloader = DataLoader(training_data, batch_size=batch_size) + test_dataloader = DataLoader(test_data, batch_size=batch_size) + + for X, y in test_dataloader: + print(f"Shape of X [N, C, H, W]: {X.shape}") + print(f"Shape of y: {y.shape} {y.dtype}") + break + + model = NeuralNetwork() + loss_fn = nn.CrossEntropyLoss() + optimizer = torch.optim.SGD(model.parameters(), lr=1e-3) + + # list for result validation + pytorch_loss_list = [] + def train(dataloader, model, loss_fn, optimizer): + size = len(dataloader.dataset) + model.train() + for batch, (X, y) in enumerate(dataloader): + # Compute prediction error + pred = model(X) + loss = loss_fn(pred, y) + + # Backpropagation + optimizer.zero_grad() + loss.backward() + optimizer.step() + + if batch % 100 == 0: + loss, current = loss.item(), batch * len(X) + pytorch_loss_list.append(np.array(loss)) + print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]") + + # for i in range(2): + # train(train_dataloader, model, loss_fn, optimizer) + vfl_model_1 = NeuralNetworkPart1() + vfl_model_2 = NeuralNetworkPart2() + vfl_client_ppl = Estimator.from_torch(client_model=vfl_model_1, + client_id="1", + loss_fn=loss_fn, + optimizer_cls=torch.optim.SGD, + optimizer_args={'lr':1e-3}, + target=self.target, + server_model=vfl_model_2, + server_model_path=TestSaveLoad.server_model_path, + client_model_path=TestSaveLoad.client_model_path) + vfl_client_ppl.fit(train_dataloader) + self.fl_server.stop() + self.setUp() + client_model_loaded = torch.load(TestSaveLoad.client_model_path) + ppl_from_file = Estimator.from_torch(client_model=client_model_loaded, + client_id="1", + loss_fn=loss_fn, + optimizer_cls=torch.optim.SGD, + optimizer_args={'lr':1e-3}, + target=self.target) + ppl_from_file.load_server_model(TestSaveLoad.server_model_path) + ppl_from_file.fit(train_dataloader) + + assert ppl_from_file.loss_history[-1] < 2, \ + f"Validation failed, incremental training loss does not meet requirement, \ + required < 2, current {ppl_from_file.loss_history[-1]}" + + +class NeuralNetwork(nn.Module): + def __init__(self): + super(NeuralNetwork, self).__init__() + self.flatten = nn.Flatten() + self.sequential_1 = nn.Sequential( + nn.Linear(28*28, 512), + nn.ReLU() + ) + self.sequential_2 = nn.Sequential( + nn.Linear(512, 512), + nn.ReLU(), + nn.Linear(512, 10) + ) + + def forward(self, x): + x = self.flatten(x) + x = self.sequential_1(x) + x = self.sequential_2(x) + return x + +class NeuralNetworkPart1(nn.Module): + def __init__(self): + super().__init__() + self.flatten = nn.Flatten() + self.sequential_1 = nn.Sequential( + nn.Linear(28*28, 512), + nn.ReLU() + ) + + def forward(self, x): + x = self.flatten(x) + x = self.sequential_1(x) + return x + +class NeuralNetworkPart2(nn.Module): + def __init__(self): + super().__init__() + self.sequential_2 = nn.Sequential( + nn.Linear(512, 512), + nn.ReLU(), + nn.Linear(512, 10) + ) + + def forward(self, x: List[Tensor]): + x = x[0] # this act as interactive layer, take the first tensor + x = self.sequential_2(x) + return x + + +if __name__ == '__main__': + unittest.main() diff --git a/scala/ppml/src/main/proto/nn_service.proto b/scala/ppml/src/main/proto/nn_service.proto index d131c160527..4062b38b444 100644 --- a/scala/ppml/src/main/proto/nn_service.proto +++ b/scala/ppml/src/main/proto/nn_service.proto @@ -26,6 +26,8 @@ service NNService { rpc predict(PredictRequest) returns (PredictResponse) {} rpc upload_meta(UploadMetaRequest) returns (UploadMetaResponse) {} rpc upload_file(stream ByteChunk) returns (UploadMetaResponse) {} + rpc save_server_model(SaveModelRequest) returns (SaveModelResponse) {} + rpc load_server_model(LoadModelRequest) returns (LoadModelResponse) {} } message TrainRequest { @@ -81,4 +83,21 @@ message UploadMetaResponse { string message = 1; } +message LoadModelRequest { + string client_id = 1; + string backend = 2; + string model_path = 3; +} + +message LoadModelResponse { + string message = 1; +} +message SaveModelRequest { + string client_id = 1; + string backend = 2; + string model_path = 3; +} +message SaveModelResponse { + string message = 1; +}