Skip to content

Commit

Permalink
add psi in nn tutorial (intel-analytics#5532)
Browse files Browse the repository at this point in the history
  • Loading branch information
Litchilitchy authored and ForJadeForest committed Sep 20, 2022
1 parent 2c2cdd9 commit a5717c4
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 17 deletions.
8 changes: 7 additions & 1 deletion python/ppml/example/pytorch_nn_lr/pytorch-nn-lr-tutorial.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,13 @@ We use [Diabetes](https://www.kaggle.com/competitions/house-prices-advanced-regr
The code is available in projects, including [Client 1 code](fgboost_regression_party_1.py) and [Client 2 code](fgboost_regression_party_2.py). You could directly start two different terminals are run them respectively to start a federated learning, and the order of start does not matter. Following is the detailed step-by-step tutorial to introduce how the code works.

### 2.1 Private Set Intersection
// TODO: add this section after Python version of PSI is done
We first need to get the intersection of datasets across parties by Private Set Intersection algorithm.
```python
df_train['ID'] = df_train['ID'].astype(str)
psi = PSI()
intersection = psi.get_intersection(list(df_train['ID']))
df_train = df_train[df_train['ID'].isin(intersection)]
```

### 2.2 Data Preprocessing
Since one party owns label data while another not, different operations should be done before training.
Expand Down
13 changes: 5 additions & 8 deletions python/ppml/example/pytorch_nn_lr/pytorch_nn_lr_1.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import torch
from torch import Tensor, nn
from bigdl.ppml.fl.estimator import Estimator
from bigdl.ppml.fl.psi.psi import PSI


class LocalModel(nn.Module):
Expand Down Expand Up @@ -50,16 +51,12 @@ def forward(self, x: List[Tensor]):
@click.command()
@click.option('--load_model', default=False)
def run_client(load_model):
# fl_server = FLServer(2)
# fl_server.build()
# fl_server.start()
df_train = pd.read_csv('.data/diabetes-vfl-1.csv')

# this should wait for the merge of 2 FLServer (Py4J Java gRPC and Python gRPC)
# df_train['ID'] = df_train['ID'].astype(str)
# psi = PSI()
# intersection = psi.get_intersection(list(df_train['ID']))
# df_train = df_train[df_train['ID'].isin(intersection)]
df_train['ID'] = df_train['ID'].astype(str)
psi = PSI()
intersection = psi.get_intersection(list(df_train['ID']))
df_train = df_train[df_train['ID'].isin(intersection)]

df_x = df_train.drop('Outcome', 1)
df_y = df_train['Outcome']
Expand Down
10 changes: 5 additions & 5 deletions python/ppml/example/pytorch_nn_lr/pytorch_nn_lr_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import torch
from torch import nn
from bigdl.ppml.fl.estimator import Estimator
from bigdl.ppml.fl.psi.psi import PSI


class LocalModel(nn.Module):
Expand All @@ -38,11 +39,10 @@ def forward(self, x):
def run_client(load_model):
df_train = pd.read_csv('.data/diabetes-vfl-2.csv')

# this should wait for the merge of 2 FLServer (Py4J Java gRPC and Python gRPC)
# df_train['ID'] = df_train['ID'].astype(str)
# psi = PSI()
# intersection = psi.get_intersection(list(df_train['ID']))
# df_train = df_train[df_train['ID'].isin(intersection)]
df_train['ID'] = df_train['ID'].astype(str)
psi = PSI()
intersection = psi.get_intersection(list(df_train['ID']))
df_train = df_train[df_train['ID'].isin(intersection)]

df_x = df_train
x = df_x.to_numpy(dtype="float32")
Expand Down
6 changes: 3 additions & 3 deletions python/ppml/src/bigdl/ppml/fl/nn/fl_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@
from bigdl.ppml.fl.nn.nn_service import NNServiceImpl
import yaml

from ..psi.psi_service import PSIServiceImpl
from .generated.nn_service_pb2_grpc import *
from .generated.psi_service_pb2_grpc import *
from bigdl.ppml.fl.psi.psi_service import PSIServiceImpl
from bigdl.ppml.fl.nn.generated.nn_service_pb2_grpc import *
from bigdl.ppml.fl.nn.generated.psi_service_pb2_grpc import *



Expand Down

0 comments on commit a5717c4

Please sign in to comment.