From a5717c43d4c12fb3f0e92c42416ad5ba99363379 Mon Sep 17 00:00:00 2001 From: Song Jiaming Date: Fri, 26 Aug 2022 09:25:09 +0800 Subject: [PATCH] add psi in nn tutorial (#5532) --- .../example/pytorch_nn_lr/pytorch-nn-lr-tutorial.md | 8 +++++++- .../ppml/example/pytorch_nn_lr/pytorch_nn_lr_1.py | 13 +++++-------- .../ppml/example/pytorch_nn_lr/pytorch_nn_lr_2.py | 10 +++++----- python/ppml/src/bigdl/ppml/fl/nn/fl_server.py | 6 +++--- 4 files changed, 20 insertions(+), 17 deletions(-) diff --git a/python/ppml/example/pytorch_nn_lr/pytorch-nn-lr-tutorial.md b/python/ppml/example/pytorch_nn_lr/pytorch-nn-lr-tutorial.md index 3e0bdb92bfc..74e012a7310 100644 --- a/python/ppml/example/pytorch_nn_lr/pytorch-nn-lr-tutorial.md +++ b/python/ppml/example/pytorch_nn_lr/pytorch-nn-lr-tutorial.md @@ -21,7 +21,13 @@ We use [Diabetes](https://www.kaggle.com/competitions/house-prices-advanced-regr The code is available in projects, including [Client 1 code](fgboost_regression_party_1.py) and [Client 2 code](fgboost_regression_party_2.py). You could directly start two different terminals are run them respectively to start a federated learning, and the order of start does not matter. Following is the detailed step-by-step tutorial to introduce how the code works. ### 2.1 Private Set Intersection -// TODO: add this section after Python version of PSI is done +We first need to get the intersection of the datasets across parties using the Private Set Intersection (PSI) algorithm. +```python +df_train['ID'] = df_train['ID'].astype(str) +psi = PSI() +intersection = psi.get_intersection(list(df_train['ID'])) +df_train = df_train[df_train['ID'].isin(intersection)] +``` ### 2.2 Data Preprocessing Since one party owns label data while another not, different operations should be done before training. 
diff --git a/python/ppml/example/pytorch_nn_lr/pytorch_nn_lr_1.py b/python/ppml/example/pytorch_nn_lr/pytorch_nn_lr_1.py index e760f42d7ed..c44903e513d 100644 --- a/python/ppml/example/pytorch_nn_lr/pytorch_nn_lr_1.py +++ b/python/ppml/example/pytorch_nn_lr/pytorch_nn_lr_1.py @@ -22,6 +22,7 @@ import torch from torch import Tensor, nn from bigdl.ppml.fl.estimator import Estimator +from bigdl.ppml.fl.psi.psi import PSI class LocalModel(nn.Module): @@ -50,16 +51,12 @@ def forward(self, x: List[Tensor]): @click.command() @click.option('--load_model', default=False) def run_client(load_model): - # fl_server = FLServer(2) - # fl_server.build() - # fl_server.start() df_train = pd.read_csv('.data/diabetes-vfl-1.csv') - # this should wait for the merge of 2 FLServer (Py4J Java gRPC and Python gRPC) - # df_train['ID'] = df_train['ID'].astype(str) - # psi = PSI() - # intersection = psi.get_intersection(list(df_train['ID'])) - # df_train = df_train[df_train['ID'].isin(intersection)] + df_train['ID'] = df_train['ID'].astype(str) + psi = PSI() + intersection = psi.get_intersection(list(df_train['ID'])) + df_train = df_train[df_train['ID'].isin(intersection)] df_x = df_train.drop('Outcome', 1) df_y = df_train['Outcome'] diff --git a/python/ppml/example/pytorch_nn_lr/pytorch_nn_lr_2.py b/python/ppml/example/pytorch_nn_lr/pytorch_nn_lr_2.py index 29a50d36e3b..466fd14d2e9 100644 --- a/python/ppml/example/pytorch_nn_lr/pytorch_nn_lr_2.py +++ b/python/ppml/example/pytorch_nn_lr/pytorch_nn_lr_2.py @@ -21,6 +21,7 @@ import torch from torch import nn from bigdl.ppml.fl.estimator import Estimator +from bigdl.ppml.fl.psi.psi import PSI class LocalModel(nn.Module): @@ -38,11 +39,10 @@ def forward(self, x): def run_client(load_model): df_train = pd.read_csv('.data/diabetes-vfl-2.csv') - # this should wait for the merge of 2 FLServer (Py4J Java gRPC and Python gRPC) - # df_train['ID'] = df_train['ID'].astype(str) - # psi = PSI() - # intersection = psi.get_intersection(list(df_train['ID'])) - 
# df_train = df_train[df_train['ID'].isin(intersection)] + df_train['ID'] = df_train['ID'].astype(str) + psi = PSI() + intersection = psi.get_intersection(list(df_train['ID'])) + df_train = df_train[df_train['ID'].isin(intersection)] df_x = df_train x = df_x.to_numpy(dtype="float32") diff --git a/python/ppml/src/bigdl/ppml/fl/nn/fl_server.py b/python/ppml/src/bigdl/ppml/fl/nn/fl_server.py index fa9517ab08c..209c81cf830 100644 --- a/python/ppml/src/bigdl/ppml/fl/nn/fl_server.py +++ b/python/ppml/src/bigdl/ppml/fl/nn/fl_server.py @@ -21,9 +21,9 @@ from bigdl.ppml.fl.nn.nn_service import NNServiceImpl import yaml -from ..psi.psi_service import PSIServiceImpl -from .generated.nn_service_pb2_grpc import * -from .generated.psi_service_pb2_grpc import * +from bigdl.ppml.fl.psi.psi_service import PSIServiceImpl +from bigdl.ppml.fl.nn.generated.nn_service_pb2_grpc import * +from bigdl.ppml.fl.nn.generated.psi_service_pb2_grpc import *