Enable FL Server in SGX #5396

Merged
merged 11 commits into from
Aug 22, 2022
4 changes: 3 additions & 1 deletion ppml/trusted-big-data-ml/python/docker-graphene/Dockerfile
@@ -165,7 +165,8 @@ RUN apt-get update --fix-missing && \
RUN wget https://raw.githubusercontent.com/intel-analytics/analytics-zoo/bigdl-2.0/docker/hyperzoo/download-bigdl.sh && \
chmod a+x ./download-bigdl.sh
RUN ./download-bigdl.sh && \
rm bigdl*.zip
rm bigdl*.zip && \
cp ${BIGDL_HOME}/python/start-fl-server.py /ppml/trusted-big-data-ml/fl

# stage.4 ppml
FROM ubuntu:20.04
@@ -323,3 +324,4 @@ ADD azure /ppml/trusted-big-data-ml/azure
WORKDIR /ppml/trusted-big-data-ml

ENTRYPOINT [ "/opt/entrypoint.sh" ]

Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#!/bin/bash

port=8980
client_num=2

while getopts "p:c:" opt
do
    case $opt in
        p)
            port=$OPTARG
            ;;
        c)
            client_num=$OPTARG
            ;;
    esac
done
cd /ppml/trusted-big-data-ml
/graphene/Tools/argv_serializer bash -c " /opt/jdk8/bin/java \
    -cp '/ppml/trusted-big-data-ml/work/spark-3.1.2/conf/:/ppml/trusted-big-data-ml/work/spark-3.1.2/jars/*' \
    -Xmx10g org.apache.spark.deploy.SparkSubmit \
    --master 'local[4]' \
    /ppml/trusted-big-data-ml/fl/start-fl-server.py -p $port -c $client_num" > /ppml/trusted-big-data-ml/secured-argvs
./init.sh
SGX=1 ./pal_loader bash 2>&1 | tee fl-server-sgx.log

18 changes: 6 additions & 12 deletions python/ppml/example/pytorch_nn_lr/pytorch-nn-lr-tutorial.md
@@ -96,20 +96,14 @@ result = ppl.predict(x)
## 3 Run FGBoost
FL Server is required before running any federated applications. Check [Start FL Server]() section for details.
### 3.1 Start FL Server in SGX
// TODO: add this section after running FL Server in SGX successfully in this example.

Modify the config file `ppml-conf.yaml`
```yaml
# the port server gRPC uses
serverPort: 8980

# the number of clients in this federated learning application
clientNum: 2
```
Then start the FL Server
#### 3.1.1 Start the container
Before running FL Server in SGX, prepare the keys and start the BigDL PPML container first. See [3.1 BigDL PPML Hello World](https://github.com/intel-analytics/BigDL/tree/main/ppml#31-bigdl-ppml-hello-world) for details.
#### 3.1.2 Run FL Server in SGX
You can run FL Server in SGX with the following command:
```bash
python BigDL/python/ppml/src/bigdl/ppml/fl/nn/fl_server.py
bash start-python-fl-server-sgx.sh -p 8980 -c 2
```
Set the port with `-p` and the number of clients with `-c`. The defaults are `port=8980` and `client-num=2`.
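As a rough illustration of how the two flags and their defaults behave, here is a minimal Python sketch using `argparse`. The function name `parse_fl_server_args` is hypothetical and not part of BigDL; the script added in this PR uses `getopt` instead, so treat this as an equivalent reading of the options rather than the actual implementation.

```python
import argparse


def parse_fl_server_args(argv):
    """Parse -p/-c with the defaults stated in the tutorial (8980 and 2).

    Hypothetical helper for illustration only.
    """
    parser = argparse.ArgumentParser(description="FL Server options (illustrative)")
    # type=int converts the command-line strings to integers
    parser.add_argument("-p", "--port", type=int, default=8980,
                        help="port the server gRPC uses")
    parser.add_argument("-c", "--client-num", type=int, default=2,
                        help="number of clients in the federated application")
    return parser.parse_args(argv)


# Only the port is overridden here; client_num falls back to its default.
args = parse_fl_server_args(["-p", "9000"])
print(args.port, args.client_num)
```

Declaring `type=int` keeps both values numeric, which avoids passing strings on to code that may compare them with integers.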
### 3.2 Start FGBoost Clients
Modify the config file `ppml-conf.yaml`
```yaml
2 changes: 1 addition & 1 deletion python/ppml/example/pytorch_nn_lr/pytorch_nn_lr_1.py
@@ -52,7 +52,7 @@ def forward(self, x: List[Tensor]):
# fl_server = FLServer(2)
# fl_server.build()
# fl_server.start()
df_train = pd.read_csv('./python/ppml/example/pytorch_nn_lr/data/diabetes-vfl-1.csv')
df_train = pd.read_csv('./data/diabetes-vfl-1.csv')

# this should wait for the merge of 2 FLServer (Py4J Java gRPC and Python gRPC)
# df_train['ID'] = df_train['ID'].astype(str)
2 changes: 1 addition & 1 deletion python/ppml/example/pytorch_nn_lr/pytorch_nn_lr_2.py
@@ -35,7 +35,7 @@ def forward(self, x):


if __name__ == '__main__':
df_train = pd.read_csv('./python/ppml/example/pytorch_nn_lr/data/diabetes-vfl-2.csv')
df_train = pd.read_csv('./data/diabetes-vfl-2.csv')

# this should wait for the merge of 2 FLServer (Py4J Java gRPC and Python gRPC)
# df_train['ID'] = df_train['ID'].astype(str)
9 changes: 9 additions & 0 deletions python/ppml/scripts/ppml-conf.yaml
@@ -0,0 +1,9 @@
##### Server property
# clientNum: 2
# privateKeyFilePath: /ppml/trusted-big-data-ml/work/keys/server.pem
# certChainFilePath: /ppml/trusted-big-data-ml/work/keys/server.crt
# serverPort:

##### Client property
# clientTarget:
# taskID:
5 changes: 5 additions & 0 deletions python/ppml/scripts/setup-env.sh
@@ -0,0 +1,5 @@
# Quote the patterns so the shell passes them to find instead of expanding them itself
PYTHON_ZIP=$(find lib -name '*-python-api.zip')
JAR=$(find lib -name '*-jar-with-dependencies.jar')
export PYTHONPATH=$PYTHONPATH:$(pwd)/$PYTHON_ZIP
export PYTHONPATH=$PYTHONPATH:$(pwd)/$PYTHON_ZIP/bigdl/ppml/fl/nn/generated
export BIGDL_CLASSPATH=$JAR
23 changes: 23 additions & 0 deletions python/ppml/scripts/start-fgboost-server.py
@@ -0,0 +1,23 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from bigdl.ppml.fl.fl_server import FLServer

if __name__ == '__main__':
    fl_server = FLServer()
    fl_server.build()
    fl_server.start()
    fl_server.wait_for_termination()
62 changes: 62 additions & 0 deletions python/ppml/scripts/start-fl-server.py
@@ -0,0 +1,62 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import sys
import os
import fnmatch
import getopt

for files in os.listdir('/ppml/trusted-big-data-ml/work/bigdl-2.1.0-SNAPSHOT/python/'):
    if fnmatch.fnmatch(files, 'bigdl-ppml-*-python-api.zip'):
        sys.path.append('/ppml/trusted-big-data-ml/work/bigdl-2.1.0-SNAPSHOT/python/' + files)
        sys.path.append('/ppml/trusted-big-data-ml/work/bigdl-2.1.0-SNAPSHOT/python/' + files + '/bigdl/ppml/fl/nn/generated')

if '/usr/lib/python3.6' in sys.path:
    sys.path.remove('/usr/lib/python3.6')
if '/usr/lib/python3.6/lib-dynload' in sys.path:
    sys.path.remove('/usr/lib/python3.6/lib-dynload')
if '/usr/local/lib/python3.6/dist-packages' in sys.path:
    sys.path.remove('/usr/local/lib/python3.6/dist-packages')
if '/usr/lib/python3/dist-packages' in sys.path:
    sys.path.remove('/usr/lib/python3/dist-packages')

from bigdl.ppml.fl.nn.fl_server import FLServer

if __name__ == '__main__':

    client_num = 2
    port = 8980

    try:
        opts, args = getopt.getopt(sys.argv[1:], "hc:p:", ["client-num=", "port="])
    except getopt.GetoptError:
        print("start-fl-server.py -c <client-num> -p <port>")
        sys.exit(2)

    for opt, arg in opts:
        if opt == '-h':
            print("start-fl-server.py -c <client-num> -p <port>")
            sys.exit(0)  # show usage only; do not start the server
        elif opt in ("-c", "--client-num"):
            client_num = int(arg)  # getopt returns strings
        elif opt in ("-p", "--port"):
            port = int(arg)

    fl_server = FLServer(client_num)
    fl_server.set_port(port)
    fl_server.build()
    fl_server.start()

    fl_server.wait_for_termination()
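
Note on the path setup at the top of this script: the lookup of the `bigdl-ppml-*-python-api.zip` archive could be factored into a small helper, which makes it testable outside the container. The sketch below is only an illustration under that assumption; `find_python_api_paths` and the archive file name in the demo are hypothetical, not part of this PR.

```python
import fnmatch
import os
import tempfile


def find_python_api_paths(lib_dir, pattern="bigdl-ppml-*-python-api.zip"):
    """Return the sys.path entries for every matching archive in lib_dir.

    Hypothetical helper: for each archive it yields the archive itself and
    the generated gRPC stubs directory inside it, as the script above does.
    """
    paths = []
    for name in sorted(os.listdir(lib_dir)):
        if fnmatch.fnmatch(name, pattern):
            archive = os.path.join(lib_dir, name)
            paths.append(archive)
            paths.append(os.path.join(archive, "bigdl", "ppml", "fl", "nn", "generated"))
    return paths


# Demo against a throwaway directory; the archive name is made up.
with tempfile.TemporaryDirectory() as demo_dir:
    open(os.path.join(demo_dir, "bigdl-ppml-demo-python-api.zip"), "w").close()
    open(os.path.join(demo_dir, "unrelated.zip"), "w").close()
    for entry in find_python_api_paths(demo_dir):
        print(entry)
```

The caller would then extend `sys.path` with the returned list before importing `bigdl.ppml`.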
6 changes: 3 additions & 3 deletions python/ppml/src/bigdl/ppml/fl/nn/fl_server.py
@@ -20,13 +20,13 @@
from bigdl.ppml.fl.nn.generated.nn_service_pb2_grpc import *
from bigdl.ppml.fl.nn.nn_service import NNServiceImpl
import yaml

import logging


class FLServer(object):
def __init__(self, client_num=1):
self.server = grpc.server(futures.ThreadPoolExecutor(max_workers=5))
self.port = 8980 # TODO: set from config file
self.port = 8980
self.client_num = client_num
self.secure = False
self.load_config()
@@ -79,4 +79,4 @@ def wait_for_termination(self):
fl_server = FLServer(2)
fl_server.build()
fl_server.start()
fl_server.wait_for_termination()
fl_server.wait_for_termination()
Original file line number Diff line number Diff line change
@@ -3,8 +3,6 @@
import grpc

import fgboost_service_pb2 as fgboost__service__pb2
from bigdl.dllib.utils.log4Error import invalidInputError


class FGBoostServiceStub(object):
"""Missing associated documentation comment in .proto file."""
@@ -59,43 +57,43 @@ def uploadLabel(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
invalidInputError(False, 'Method not implemented!')
raise NotImplementedError('Method not implemented!')

def downloadLabel(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
invalidInputError(False, 'Method not implemented!')
raise NotImplementedError('Method not implemented!')

def split(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
invalidInputError(False, 'Method not implemented!')
raise NotImplementedError('Method not implemented!')

def register(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
invalidInputError(False, 'Method not implemented!')
raise NotImplementedError('Method not implemented!')

def uploadTreeLeaf(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
invalidInputError(False, 'Method not implemented!')
raise NotImplementedError('Method not implemented!')

def evaluate(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
invalidInputError(False, 'Method not implemented!')
raise NotImplementedError('Method not implemented!')

def predict(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
invalidInputError(False, 'Method not implemented!')
raise NotImplementedError('Method not implemented!')


def add_FGBoostServiceServicer_to_server(servicer, server):
Original file line number Diff line number Diff line change
@@ -3,7 +3,6 @@
import grpc

import psi_service_pb2 as psi__service__pb2
from bigdl.dllib.utils.log4Error import invalidInputError


class PSIServiceStub(object):
@@ -40,19 +39,19 @@ def getSalt(self, request, context):
"""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
invalidInputError(False, 'Method not implemented!')
raise NotImplementedError('Method not implemented!')

def uploadSet(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
invalidInputError(False, 'Method not implemented!')
raise NotImplementedError('Method not implemented!')

def downloadIntersection(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
invalidInputError(False, 'Method not implemented!')
raise NotImplementedError('Method not implemented!')


def add_PSIServiceServicer_to_server(servicer, server):
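
Note on the `invalidInputError` to `NotImplementedError` change in the two generated files above: the base servicer marks the call as unimplemented on the context and then raises, and a concrete service overrides the method. The sketch below illustrates that pattern without depending on `grpc`; `FakeContext`, `BaseServicer`, and `EchoServicer` are hypothetical stand-ins, and the string `"UNIMPLEMENTED"` stands in for `grpc.StatusCode.UNIMPLEMENTED`.

```python
class FakeContext:
    """Hypothetical stand-in for a gRPC servicer context; records what is set."""

    def __init__(self):
        self.code = None
        self.details = None

    def set_code(self, code):
        self.code = code

    def set_details(self, details):
        self.details = details


class BaseServicer:
    """Mirrors the shape of a generated servicer method."""

    def predict(self, request, context):
        context.set_code("UNIMPLEMENTED")  # stand-in for grpc.StatusCode.UNIMPLEMENTED
        context.set_details("Method not implemented!")
        raise NotImplementedError("Method not implemented!")


class EchoServicer(BaseServicer):
    """A concrete service overrides the method, so the base never raises."""

    def predict(self, request, context):
        return request


context = FakeContext()
try:
    BaseServicer().predict("payload", context)
except NotImplementedError as err:
    print(context.code, "-", err)
```

Raising the built-in exception keeps the generated stubs free of the `bigdl.dllib` import that this PR removes.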
8 changes: 8 additions & 0 deletions scala/assembly/src/main/assembly/assembly.xml
@@ -148,6 +148,13 @@
<include>bigdl-ppml*-jar-with-dependencies.jar</include>
</includes>
</fileSet>
<fileSet>
<outputDirectory>/python</outputDirectory>
<directory>${project.parent.basedir}/../../../../../python/ppml/scripts</directory>
<includes>
<include>start-fl-server.py</include>
</includes>
</fileSet>
</fileSets>
<dependencySets>
<dependencySet>
@@ -176,3 +183,4 @@
</dependencySet>
</dependencySets>
</assembly>