PPML add pytorch incremental training and tutorial (intel-analytics#5379)
Litchilitchy authored and ForJadeForest committed Sep 20, 2022
1 parent 79aa6e6 commit f765583
Showing 15 changed files with 528 additions and 58 deletions.
48 changes: 46 additions & 2 deletions python/ppml/example/pytorch_nn_lr/pytorch-nn-lr-tutorial.md
@@ -88,18 +88,39 @@ Then call `fit` method to train
```python
response = ppl.fit(x, y)
```
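
The updated example scripts in this commit also pass a third positional argument to `fit`, presumably the number of epochs; a minimal sketch under that assumption:
```python
# Train for 5 epochs; the third positional argument follows the updated
# example scripts and is assumed here to be the epoch count.
response = ppl.fit(x, y, 5)
```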

### 2.6 Predict
```python
result = ppl.predict(x)
```

### 2.7 Save/Load
After training, save the client model and the server model by
```python
torch.save(ppl.model, model_path)
ppl.save_server_model(server_model_path)
```
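
For example, with the paths used by the example scripts in this commit:
```python
model_path = '/tmp/pytorch_client_model_1.pt'    # client model file (client 1)
server_model_path = '/tmp/pytorch_server_model'  # server-side model path
torch.save(ppl.model, model_path)
ppl.save_server_model(server_model_path)
```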
To continue training in a new application, load the saved client model and create the estimator again
```python
model = torch.load(model_path)
# we do not pass the server model this time; instead, it is loaded directly on the server machine
ppl = Estimator.from_torch(client_model=model,
client_id=client_id,
loss_fn=loss_fn,
optimizer_cls=torch.optim.SGD,
optimizer_args={'lr':1e-3},
target='localhost:8980')
ppl.load_server_model(server_model_path)
```
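
Training can then resume as before, and the refreshed models can be saved back to the same paths; a minimal sketch (the epoch argument mirrors the updated example scripts):
```python
# continue training (third argument assumed to be the epoch count)
response = ppl.fit(x, y, 5)

# persist the updated models so the next run can pick up from here
torch.save(ppl.model, model_path)
ppl.save_server_model(server_model_path)
```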

## 3 Run PyTorch NN LR Example
The FL Server must be running before any federated application starts. Check the [Start FL Server]() section for details.
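
For a quick local test outside SGX, the server can also be started from Python, as the commented-out lines in the example scripts suggest; a minimal sketch:
```python
from bigdl.ppml.fl.nn.fl_server import FLServer

fl_server = FLServer(2)   # expect 2 federated clients, matching this tutorial
fl_server.build()
fl_server.start()
```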
### 3.1 Start FL Server in SGX

#### 3.1.1 Start the container
Before running FL Server in SGX, please prepare keys and start the BigDL PPML container first. Check [3.1 BigDL PPML Hello World](https://github.com/intel-analytics/BigDL/tree/main/ppml#31-bigdl-ppml-hello-world) for details.
#### 3.1.2 Run FL Server in SGX
You can run FL Server in SGX with the following command:

```bash
bash start-python-fl-server-sgx.sh -p 8980 -c 2
```
@@ -129,3 +150,26 @@ The first 5 predict results are printed
[1.2120417e-23]
[0.0000000e+00]]
```
### 3.4 Incremental Training
Incremental training is supported: simply reuse the same configurations and start the FL Server again.

In the SGX container, start the FL Server
```bash
./ppml/scripts/start-fl-server.sh
```
For the client applications, we switch from creating a new model to loading the saved one. This is already implemented in the example code; we only need to run the client applications with the `--load_model` argument (see the sketch after the commands below).

```bash
# run the following commands in 2 different terminals
python pytorch_nn_lr_1.py --load_model true
python pytorch_nn_lr_2.py --load_model true
```
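
Under the hood, the `--load_model` flag switches each client from building a new model to restoring the saved one; a sketch condensed from `pytorch_nn_lr_1.py` in this commit (`x`, `y`, and `load_model` come from the surrounding script):
```python
import torch
from torch import nn
from bigdl.ppml.fl.estimator import Estimator

loss_fn = nn.BCELoss()
if load_model:
    # restore the client model saved by the previous run
    model = torch.load('/tmp/pytorch_client_model_1.pt')
    ppl = Estimator.from_torch(client_model=model,
                               client_id='1',
                               loss_fn=loss_fn,
                               optimizer_cls=torch.optim.SGD,
                               optimizer_args={'lr': 1e-5},
                               target='localhost:8980',
                               server_model_path='/tmp/pytorch_server_model',
                               client_model_path='/tmp/pytorch_client_model_1.pt')
    # the server-side model is loaded directly on the FL Server machine
    ppl.load_server_model('/tmp/pytorch_server_model')
    response = ppl.fit(x, y, 5)
```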
The first 5 prediction results of the continued model are printed
```
[[1.8799074e-36]
[1.7512805e-25]
[4.6501680e-30]
[1.4828590e-27]
[0.0000000e+00]]
```
You can see from the log that the loss continues to drop compared with [Section 3.3](#33-get-results).
49 changes: 35 additions & 14 deletions python/ppml/example/pytorch_nn_lr/pytorch_nn_lr_1.py
@@ -17,13 +17,11 @@
from typing import List
import numpy as np
import pandas as pd
import click

import torch
from torch import Tensor, nn
from bigdl.ppml.fl.estimator import Estimator
from bigdl.ppml.fl.algorithms.psi import PSI
from bigdl.ppml.fl.nn.fl_server import FLServer
from bigdl.ppml.fl.nn.pytorch.utils import set_one_like_parameter


class LocalModel(nn.Module):
@@ -48,7 +46,10 @@ def forward(self, x: List[Tensor]):
return x


if __name__ == '__main__':

@click.command()
@click.option('--load_model', default=False)
def run_client(load_model):
# fl_server = FLServer(2)
# fl_server.build()
# fl_server.start()
@@ -66,16 +67,36 @@ def forward(self, x: List[Tensor]):
x = df_x.to_numpy(dtype="float32")
y = np.expand_dims(df_y.to_numpy(dtype="float32"), axis=1)

model = LocalModel(len(df_x.columns))
loss_fn = nn.BCELoss()
server_model = ServerModel()
ppl = Estimator.from_torch(client_model=model,
client_id='1',
loss_fn=loss_fn,
optimizer_cls=torch.optim.SGD,
optimizer_args={'lr':1e-5},
target='localhost:8980',
server_model=server_model)
response = ppl.fit(x, y)

if load_model:
model = torch.load('/tmp/pytorch_client_model_1.pt')
ppl = Estimator.from_torch(client_model=model,
client_id='1',
loss_fn=loss_fn,
optimizer_cls=torch.optim.SGD,
optimizer_args={'lr':1e-5},
target='localhost:8980',
server_model_path='/tmp/pytorch_server_model',
client_model_path='/tmp/pytorch_client_model_1.pt')
ppl.load_server_model('/tmp/pytorch_server_model')
response = ppl.fit(x, y, 5)
else:
model = LocalModel(len(df_x.columns))

server_model = ServerModel()
ppl = Estimator.from_torch(client_model=model,
client_id='1',
loss_fn=loss_fn,
optimizer_cls=torch.optim.SGD,
optimizer_args={'lr':1e-5},
target='localhost:8980',
server_model=server_model,
server_model_path='/tmp/pytorch_server_model',
client_model_path='/tmp/pytorch_client_model_1.pt')
response = ppl.fit(x, y, 5)
result = ppl.predict(x)
print(result[:5])

if __name__ == '__main__':
run_client()
41 changes: 29 additions & 12 deletions python/ppml/example/pytorch_nn_lr/pytorch_nn_lr_2.py
@@ -16,12 +16,11 @@

import numpy as np
import pandas as pd
import click

import torch
from torch import nn
from bigdl.ppml.fl.estimator import Estimator
from bigdl.ppml.fl.algorithms.psi import PSI
from bigdl.ppml.fl.nn.pytorch.utils import set_one_like_parameter


class LocalModel(nn.Module):
@@ -34,7 +33,9 @@ def forward(self, x):
return x


if __name__ == '__main__':
@click.command()
@click.option('--load_model', default=False)
def run_client(load_model):
df_train = pd.read_csv('.data/diabetes-vfl-2.csv')

# this should wait for the merge of 2 FLServer (Py4J Java gRPC and Python gRPC)
@@ -45,16 +46,32 @@

df_x = df_train
x = df_x.to_numpy(dtype="float32")
    y = None

model = LocalModel(len(df_x.columns))
loss_fn = nn.BCELoss()
ppl = Estimator.from_torch(client_model=model,
client_id='2',
loss_fn=loss_fn,
optimizer_cls=torch.optim.SGD,
optimizer_args={'lr':1e-5},
target='localhost:8980')
response = ppl.fit(x, y)

if load_model:
model = torch.load('/tmp/pytorch_client_model_2.pt')
ppl = Estimator.from_torch(client_model=model,
client_id='2',
loss_fn=loss_fn,
optimizer_cls=torch.optim.SGD,
optimizer_args={'lr':1e-5},
target='localhost:8980',
client_model_path='/tmp/pytorch_client_model_2.pt')
response = ppl.fit(x, y, 5)
else:
model = LocalModel(len(df_x.columns))
ppl = Estimator.from_torch(client_model=model,
client_id='2',
loss_fn=loss_fn,
optimizer_cls=torch.optim.SGD,
optimizer_args={'lr':1e-5},
target='localhost:8980',
client_model_path='/tmp/pytorch_client_model_2.pt')
response = ppl.fit(x, y, 5)
result = ppl.predict(x)
print(result[:5])

if __name__ == '__main__':
run_client()
10 changes: 7 additions & 3 deletions python/ppml/src/bigdl/ppml/fl/estimator.py
@@ -35,14 +35,18 @@ def from_torch(client_model: nn.Module,
optimizer_cls,
optimizer_args={},
target="localhost:8980",
server_model=None):
server_model=None,
client_model_path=None,
server_model_path=None):
estimator = PytorchEstimator(model=client_model,
loss_fn=loss_fn,
loss_fn=loss_fn,
optimizer_cls=optimizer_cls,
optimizer_args=optimizer_args,
client_id=client_id,
target=target,
server_model=server_model)
server_model=server_model,
client_model_path=client_model_path,
server_model_path=server_model_path)
return estimator

@staticmethod
19 changes: 14 additions & 5 deletions python/ppml/src/bigdl/ppml/fl/nn/fl_server.py
@@ -24,19 +24,22 @@


class FLServer(object):
def __init__(self, client_num=1):
def __init__(self, client_num=None):
self.server = grpc.server(futures.ThreadPoolExecutor(max_workers=5))
self.port = 8980
self.client_num = client_num
self.secure = False
self.load_config()
# a chance to overwrite client num
if client_num is not None:
self.conf['clientNum'] = client_num

def set_port(self, port):
self.port = port

def build(self):
add_NNServiceServicer_to_server(
NNServiceImpl(client_num=self.client_num),
NNServiceImpl(conf=self.conf),
self.server)
if self.secure:
self.server.add_secure_port(f'[::]:{self.port}', self.server_credentials)
@@ -65,11 +68,17 @@ def load_config(self):
( (private_key, certificate_chain), ) )
if 'serverPort' in conf:
self.port = conf['serverPort']
self.generate_conf(conf)

except yaml.YAMLError as e:
logging.warn('Loading config failed, using default config ')
except Exception as e:
logging.warn('Failed to find config file "ppml-conf.yaml", using default config')
logging.warn('Failed to load config file "ppml-conf.yaml", using default config')
self.generate_conf({})

def generate_conf(self, conf: dict):
self.conf = conf
# set default parameters if not specified in config
if 'clientNum' not in conf.keys():
self.conf['clientNum'] = 1

def wait_for_termination(self):
self.server.wait_for_termination()
46 changes: 43 additions & 3 deletions python/ppml/src/bigdl/ppml/fl/nn/generated/nn_service_pb2.py

Some generated files are not rendered by default.
