-
Notifications
You must be signed in to change notification settings - Fork 1.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Pytorch VFL NN refactor, refine and example test #5607
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
#!/usr/bin/env bash | ||
|
||
# | ||
# Copyright 2016 The BigDL Authors. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
|
||
set -ex | ||
SCRIPT_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" &> /dev/null && pwd) | ||
source $SCRIPT_DIR/../prepare_env.sh | ||
|
||
|
||
cd "`dirname $0`" | ||
echo "Running PPML tests" | ||
cd ../../ | ||
|
||
rm -rf /tmp/pytorch_server_model* | ||
rm -rf /tmp/pytorch_client_model* | ||
rm -rf /tmp/vfl_server_model* | ||
python src/bigdl/ppml/fl/nn/fl_server.py --client_num 2 & | ||
python example/pytorch_nn_lr/pytorch_nn_lr_1.py --data_path example/pytorch_nn_lr/data/diabetes-vfl-1.csv & | ||
python example/pytorch_nn_lr/pytorch_nn_lr_2.py --data_path example/pytorch_nn_lr/data/diabetes-vfl-2.csv & |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,6 +14,7 @@ | |
# limitations under the License. | ||
# | ||
|
||
import logging | ||
from typing import List | ||
import numpy as np | ||
import pandas as pd | ||
|
@@ -22,8 +23,11 @@ | |
import torch | ||
from torch import Tensor, nn | ||
from bigdl.ppml.fl.estimator import Estimator | ||
from bigdl.ppml.fl.psi.psi import PSI | ||
from bigdl.ppml.fl.nn.fl_context import init_fl_context | ||
from bigdl.ppml.fl.psi.psi_client import PSI | ||
|
||
fmt = '%(asctime)s %(levelname)s {%(module)s:%(lineno)d} - %(message)s' | ||
logging.basicConfig(format=fmt, level=logging.INFO) | ||
|
||
class LocalModel(nn.Module): | ||
def __init__(self, num_feature) -> None: | ||
|
@@ -50,8 +54,10 @@ def forward(self, x: List[Tensor]): | |
|
||
@click.command() | ||
@click.option('--load_model', default=False) | ||
def run_client(load_model): | ||
df_train = pd.read_csv('./data/diabetes-vfl-1.csv') | ||
@click.option('--data_path', default="./data/diabetes-vfl-1.csv") | ||
def run_client(load_model, data_path): | ||
init_fl_context('1') | ||
df_train = pd.read_csv(data_path) | ||
|
||
df_train['ID'] = df_train['ID'].astype(str) | ||
psi = PSI() | ||
|
@@ -69,11 +75,9 @@ def run_client(load_model): | |
if load_model: | ||
model = torch.load('/tmp/pytorch_client_model_1.pt') | ||
ppl = Estimator.from_torch(client_model=model, | ||
client_id='1', | ||
loss_fn=loss_fn, | ||
optimizer_cls=torch.optim.SGD, | ||
optimizer_args={'lr':1e-5}, | ||
target='localhost:8980', | ||
optimizer_args={'lr':1e-4}, | ||
server_model_path='/tmp/pytorch_server_model', | ||
client_model_path='/tmp/pytorch_client_model_1.pt') | ||
ppl.load_server_model('/tmp/pytorch_server_model') | ||
|
@@ -83,11 +87,9 @@ def run_client(load_model): | |
|
||
server_model = ServerModel() | ||
ppl = Estimator.from_torch(client_model=model, | ||
client_id='1', | ||
loss_fn=loss_fn, | ||
optimizer_cls=torch.optim.SGD, | ||
optimizer_args={'lr':1e-5}, | ||
target='localhost:8980', | ||
optimizer_args={'lr':1e-4}, | ||
server_model=server_model, | ||
server_model_path='/tmp/pytorch_server_model', | ||
client_model_path='/tmp/pytorch_client_model_1.pt') | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. save after There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Estimator would automatically save server model and client model to respective path after each epoch of |
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,6 +17,8 @@ | |
from ..nn.fl_client import FLClient | ||
|
||
def init_fl_context(client_id, target="localhost:8980"): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. add python doc |
||
FLClient.load_config() | ||
FLClient.set_client_id(client_id) | ||
# target can be set in config file, and also could be overwritten here | ||
FLClient.set_target(target) | ||
FLClient.ensure_initialized() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
should
id
will be an integer instead?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We can do this. And user must pass an integer when calling
init_fl_context
and the client ID could only be set in this method.All the protobuf types of ID are String now, I would open another PR to modify all of them.