From ee493aacbfaa5ac4e85723f43be94265e4e0af2b Mon Sep 17 00:00:00 2001 From: Jin Hanyu <476099001@qq.com> Date: Mon, 15 Aug 2022 17:55:52 +0800 Subject: [PATCH] Enable FL Server in SGX --- .../example/pytorch_nn_lr/pytorch_nn_lr_1.py | 2 +- .../example/pytorch_nn_lr/pytorch_nn_lr_2.py | 2 +- python/ppml/scripts/setup-env.sh | 2 +- python/ppml/scripts/start-fl-server.py | 43 ++++++++++++++++++- python/ppml/src/bigdl/ppml/fl/nn/fl_server.py | 4 +- .../nn/generated/fgboost_service_pb2_grpc.py | 16 +++---- .../fl/nn/generated/psi_service_pb2_grpc.py | 7 ++- scala/ppml/pom.xml | 14 ------ 8 files changed, 55 insertions(+), 35 deletions(-) diff --git a/python/ppml/example/pytorch_nn_lr/pytorch_nn_lr_1.py b/python/ppml/example/pytorch_nn_lr/pytorch_nn_lr_1.py index b69522174c9a..98b7a4aa446a 100644 --- a/python/ppml/example/pytorch_nn_lr/pytorch_nn_lr_1.py +++ b/python/ppml/example/pytorch_nn_lr/pytorch_nn_lr_1.py @@ -52,7 +52,7 @@ def forward(self, x: List[Tensor]): # fl_server = FLServer(2) # fl_server.build() # fl_server.start() - df_train = pd.read_csv('./python/ppml/example/pytorch_nn_lr/data/diabetes-vfl-1.csv') + df_train = pd.read_csv('.data/diabetes-vfl-1.csv') # this should wait for the merge of 2 FLServer (Py4J Java gRPC and Python gRPC) # df_train['ID'] = df_train['ID'].astype(str) diff --git a/python/ppml/example/pytorch_nn_lr/pytorch_nn_lr_2.py b/python/ppml/example/pytorch_nn_lr/pytorch_nn_lr_2.py index e3aa9c2dc29f..f8b4004a3ab2 100644 --- a/python/ppml/example/pytorch_nn_lr/pytorch_nn_lr_2.py +++ b/python/ppml/example/pytorch_nn_lr/pytorch_nn_lr_2.py @@ -35,7 +35,7 @@ def forward(self, x): if __name__ == '__main__': - df_train = pd.read_csv('./python/ppml/example/pytorch_nn_lr/data/diabetes-vfl-2.csv') + df_train = pd.read_csv('.data/diabetes-vfl-2.csv') # this should wait for the merge of 2 FLServer (Py4J Java gRPC and Python gRPC) # df_train['ID'] = df_train['ID'].astype(str) diff --git a/python/ppml/scripts/setup-env.sh b/python/ppml/scripts/setup-env.sh index 314917927746..f82f58f237da 100644 --- a/python/ppml/scripts/setup-env.sh +++ b/python/ppml/scripts/setup-env.sh @@ -2,4 +2,4 @@ PYTHON_ZIP=$(find lib -name *-python-api.zip) JAR=$(find lib -name *-jar-with-dependencies.jar) export PYTHONPATH=$PYTHONPATH:$(pwd)/$PYTHON_ZIP export PYTHONPATH=$PYTHONPATH:$(pwd)/$PYTHON_ZIP/bigdl/ppml/fl/nn/generated -export BIGDL_CLASSPATH=$JAR \ No newline at end of file +export BIGDL_CLASSPATH=$JAR diff --git a/python/ppml/scripts/start-fl-server.py b/python/ppml/scripts/start-fl-server.py index fe0e4396d7f9..c72ac47ec34f 100644 --- a/python/ppml/scripts/start-fl-server.py +++ b/python/ppml/scripts/start-fl-server.py @@ -14,10 +14,49 @@ # limitations under the License. # +import sys +import os +import fnmatch +import getopt + +for files in os.listdir('/ppml/trusted-big-data-ml/work/bigdl-2.1.0-SNAPSHOT/python/'): + if fnmatch.fnmatch(files, 'bigdl-ppml-*-python-api.zip'): + sys.path.append('/ppml/trusted-big-data-ml/work/bigdl-2.1.0-SNAPSHOT/python/' + files) + sys.path.append('/ppml/trusted-big-data-ml/work/bigdl-2.1.0-SNAPSHOT/python/' + files + '/bigdl/ppml/fl/nn/generated') + +if '/usr/lib/python3.6' in sys.path: + sys.path.remove('/usr/lib/python3.6') +if '/usr/lib/python3.6/lib-dynload' in sys.path: + sys.path.remove('/usr/lib/python3.6/lib-dynload') +if '/usr/local/lib/python3.6/dist-packages' in sys.path: + sys.path.remove('/usr/local/lib/python3.6/dist-packages') +if '/usr/lib/python3/dist-packages' in sys.path: + sys.path.remove('/usr/lib/python3/dist-packages') + from bigdl.ppml.fl.nn.fl_server import FLServer if __name__ == '__main__': - fl_server = FLServer() + + client_num = 2 + port = 8980 + + try: + opts, args = getopt.getopt(sys.argv[1:], "hc:p:", ["client-num=", "port="]) + except getopt.GetoptError: + print("start_fl_server.py -c -p ") + sys.exit(2) + + for opt, arg in opts: + if opt == '-h': + print("start_fl_server.py -c -p ") + elif opt in ("-c", "--client-num"): + client_num = arg + elif opt in ("-p", "--port"): + port = arg + + fl_server = FLServer(client_num) + fl_server.set_port(port) fl_server.build() fl_server.start() - fl_server.wait_for_termination() \ No newline at end of file + + fl_server.wait_for_termination() diff --git a/python/ppml/src/bigdl/ppml/fl/nn/fl_server.py b/python/ppml/src/bigdl/ppml/fl/nn/fl_server.py index cfe291bf3628..a62a2a7962eb 100644 --- a/python/ppml/src/bigdl/ppml/fl/nn/fl_server.py +++ b/python/ppml/src/bigdl/ppml/fl/nn/fl_server.py @@ -65,8 +65,6 @@ def load_config(self): ( (private_key, certificate_chain), ) ) if 'serverPort' in conf: self.port = conf['serverPort'] - if 'clientNum' in conf: - self.client_num = conf['clientNum'] except yaml.YAMLError as e: logging.warn('Loading config failed, using default config ') @@ -81,4 +79,4 @@ def wait_for_termination(self): fl_server = FLServer(2) fl_server.build() fl_server.start() - fl_server.wait_for_termination() \ No newline at end of file + fl_server.wait_for_termination() diff --git a/python/ppml/src/bigdl/ppml/fl/nn/generated/fgboost_service_pb2_grpc.py b/python/ppml/src/bigdl/ppml/fl/nn/generated/fgboost_service_pb2_grpc.py index 7c8546245b57..ddfcf12c06cb 100644 --- a/python/ppml/src/bigdl/ppml/fl/nn/generated/fgboost_service_pb2_grpc.py +++ b/python/ppml/src/bigdl/ppml/fl/nn/generated/fgboost_service_pb2_grpc.py @@ -3,8 +3,6 @@ import grpc import fgboost_service_pb2 as fgboost__service__pb2 -from bigdl.dllib.utils.log4Error import invalidInputError - class FGBoostServiceStub(object): """Missing associated documentation comment in .proto file.""" @@ -59,43 +57,43 @@ def uploadLabel(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) context.set_details('Method not implemented!') - invalidInputError(False, 'Method not implemented!') + raise NotImplementedError('Method not implemented!') def downloadLabel(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) context.set_details('Method not implemented!') - invalidInputError(False, 'Method not implemented!') + raise NotImplementedError('Method not implemented!') def split(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) context.set_details('Method not implemented!') - invalidInputError(False, 'Method not implemented!') + raise NotImplementedError('Method not implemented!') def register(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) context.set_details('Method not implemented!') - invalidInputError(False, 'Method not implemented!') + raise NotImplementedError('Method not implemented!') def uploadTreeLeaf(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) context.set_details('Method not implemented!') - invalidInputError(False, 'Method not implemented!') + raise NotImplementedError('Method not implemented!') def evaluate(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) context.set_details('Method not implemented!') - invalidInputError(False, 'Method not implemented!') + raise NotImplementedError('Method not implemented!') def predict(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) context.set_details('Method not implemented!') - invalidInputError(False, 'Method not implemented!') + raise NotImplementedError('Method not implemented!') def add_FGBoostServiceServicer_to_server(servicer, server): diff --git a/python/ppml/src/bigdl/ppml/fl/nn/generated/psi_service_pb2_grpc.py b/python/ppml/src/bigdl/ppml/fl/nn/generated/psi_service_pb2_grpc.py index 226393cba353..9916c0356fbd 100644 --- a/python/ppml/src/bigdl/ppml/fl/nn/generated/psi_service_pb2_grpc.py +++ b/python/ppml/src/bigdl/ppml/fl/nn/generated/psi_service_pb2_grpc.py @@ -3,7 +3,6 @@ import grpc import psi_service_pb2 as psi__service__pb2 -from bigdl.dllib.utils.log4Error import invalidInputError class PSIServiceStub(object): @@ -40,19 +39,19 @@ def getSalt(self, request, context): """ context.set_code(grpc.StatusCode.UNIMPLEMENTED) context.set_details('Method not implemented!') - invalidInputError(False, 'Method not implemented!') + raise NotImplementedError('Method not implemented!') def uploadSet(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) context.set_details('Method not implemented!') - invalidInputError(False, 'Method not implemented!') + raise NotImplementedError('Method not implemented!') def downloadIntersection(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) context.set_details('Method not implemented!') - invalidInputError(False, 'Method not implemented!') + raise NotImplementedError('Method not implemented!') def add_PSIServiceServicer_to_server(servicer, server): diff --git a/scala/ppml/pom.xml b/scala/ppml/pom.xml index 8e0ad945c375..61c22b446469 100644 --- a/scala/ppml/pom.xml +++ b/scala/ppml/pom.xml @@ -414,20 +414,6 @@ single - - assembly - false - package - - single - - - - ${project.basedir}/src/assembly/ppml-assembly.xml - - - -