Skip to content

Commit

Permalink
add psi support for fl nn server
Browse files Browse the repository at this point in the history
  • Loading branch information
Litchilitchy committed Aug 17, 2022
1 parent c5be04f commit 7297b08
Show file tree
Hide file tree
Showing 13 changed files with 313 additions and 5 deletions.
26 changes: 24 additions & 2 deletions python/ppml/src/bigdl/ppml/fl/nn/fl_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,20 +18,41 @@
import logging
import pickle
import grpc
from numpy import ndarray
from bigdl.ppml.fl.nn.generated.nn_service_pb2 import TrainRequest, PredictRequest, UploadMetaRequest
from bigdl.ppml.fl.nn.generated.nn_service_pb2_grpc import *
from bigdl.ppml.fl.nn.utils import ndarray_map_to_tensor_map
import yaml
import threading
from torch.utils.data import DataLoader
from bigdl.dllib.utils.log4Error import invalidInputError

from bigdl.ppml.fl.nn.utils import ClassAndArgsWrapper

class FLClient(object):
channel = None
_lock = threading.Lock()
client_id = None
target = "localhost:8980"
secure = False
creds = None

@staticmethod
def set_client_id(client_id):
FLClient.client_id = client_id

@staticmethod
def set_target(target):
FLClient.target = target

@staticmethod
def ensure_initialized():
with FLClient._lock:
if FLClient.channel == None:
if FLClient.secure:
FLClient.channel = grpc.secure_channel(FLClient.target, FLClient.creds)
else:
FLClient.channel = grpc.insecure_channel(FLClient.target)


def __init__(self, client_id, aggregator, target="localhost:8980") -> None:
self.secure = False
self.load_config()
Expand Down Expand Up @@ -91,3 +112,4 @@ def load_config(self):
logging.warn('Loading config failed, using default config ')
except Exception as e:
logging.warn('Failed to find config file "ppml-conf.yaml", using default config')

22 changes: 22 additions & 0 deletions python/ppml/src/bigdl/ppml/fl/nn/fl_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from ..nn.fl_client import FLClient

def init_fl_context(client_id, target="localhost:8980"):
FLClient.set_client_id(client_id)
FLClient.set_target(target)
FLClient.ensure_initialized()
9 changes: 6 additions & 3 deletions python/ppml/src/bigdl/ppml/fl/nn/fl_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@
from bigdl.ppml.fl.nn.nn_service import NNServiceImpl
import yaml

from ..psi.psi_service import PSIServiceImpl

from .generated.psi_service_pb2_grpc import add_PSIServiceServicer_to_server



class FLServer(object):
Expand All @@ -35,9 +39,8 @@ def set_port(self, port):
self.port = port

def build(self):
add_NNServiceServicer_to_server(
NNServiceImpl(client_num=self.client_num),
self.server)
add_NNServiceServicer_to_server(NNServiceImpl(conf=self.conf), self.server)
add_PSIServiceServicer_to_server(PSIServiceImpl(conf=self.conf), self.server)
if self.secure:
self.server.add_secure_port(f'[::]:{self.port}', self.server_credentials)
else:
Expand Down
49 changes: 49 additions & 0 deletions python/ppml/src/bigdl/ppml/fl/psi/psi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import logging

from bigdl.dllib.utils.log4Error import invalidOperationError
from ..nn.fl_client import FLClient
from bigdl.ppml.fl.nn.generated.psi_service_pb2_grpc import *
from ..nn.generated.psi_service_pb2 import DownloadIntersectionRequest, SaltRequest, UploadSetRequest

class PSI(object):
def __init__(self) -> None:
self.stub = PSIServiceStub(FLClient.channel)

def get_salt(self, secure_code=""):
return self.stub.getSalt(SaltRequest(secure_code=secure_code)).salt_reply

def upload_set(self, ids, salt=""):
# TODO: add hashing
return self.stub.uploadSet(
UploadSetRequest(client_id=FLClient.client_id, hashedID=ids))

def download_intersection(self, max_try=100, retry=3):
for i in range(max_try):
intersection = self.stub.downloadIntersection(
DownloadIntersectionRequest()).intersection
if intersection is not None:
intersection = list(intersection)
logging.info(f"Intersection completed, size {len(intersection)}")
return intersection
invalidOperationError(False, "Max retry reached, could not get intersection, exiting.")

def get_intersection(self, ids, secure_code="", max_try=100, retry=3):
salt = self.stub.getSalt(SaltRequest(secure_code=secure_code)).salt_reply
self.upload_set(ids, salt)
return self.download_intersection(max_try, retry)
52 changes: 52 additions & 0 deletions python/ppml/src/bigdl/ppml/fl/psi/psi_intersection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


import logging
import threading

from bigdl.ppml.utils.log4Error import invalidOperationError


class PsiIntersection(object):
def __init__(self, max_collection=1) -> None:
self.intersection = []
self._thread_intersection = []

self.max_collection = max_collection
self.condition = threading.Condition()
self._lock = threading.Lock()

self.collection = []

def find_intersection(self, a, b):
return list(set(a) & set(b))

def add_collection(self, collection):
with self._lock:
invalidOperationError(len(self.collection) < self.max_collection,
f"PSI collection is full, got: {len(self.collection)}/{self.max_collection}")
self.collection.append(collection)
if len(self.collection) == self.max_collection:
current_intersection = self.collection[0]
for i in range(1, len(self.collection)):
current_intersection = \
self.find_intersection(current_intersection, self.collection[i])
self.intersection = current_intersection

def get_intersection(self):
with self._lock:
return self.intersection
60 changes: 60 additions & 0 deletions python/ppml/src/bigdl/ppml/fl/psi/psi_service.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import logging
from random import randint
from uuid import uuid4
from bigdl.ppml.fl.psi.psi_intersection import PsiIntersection
from bigdl.ppml.fl.nn.generated.psi_service_pb2_grpc import *
from bigdl.ppml.fl.nn.generated.psi_service_pb2 import *



class PSIServiceImpl(PSIServiceServicer):
def __init__(self, conf) -> None:
self.client_salt = None
self.client_secret = None
self.client_shuffle_seed = 0
# self.psi_collections = {}
self.psi_intersection = PsiIntersection(conf['clientNum'])

def getSalt(self, request, context):
if self.client_salt is not None:
salt = self.client_salt
else:
salt = str(uuid4())
self.client_salt = salt

if self.client_secret is None:
self.client_secret = request.secure_code
elif self.client_secret != request.secure_code:
salt = ""

if self.client_shuffle_seed == 0:
self.client_shuffle_seed = randint(0, 100)
return SaltReply(salt_reply=salt)

def uploadSet(self, request, context):
client_id = request.client_id
ids = request.hashedID
self.psi_intersection.add_collection(ids)
logging.info(f"{len(self.psi_intersection.collection)}-th collection added")
return UploadSetResponse(status=1)


def downloadIntersection(self, request, context):
intersection = self.psi_intersection.get_intersection()
return DownloadIntersectionResponse(intersection=intersection)
17 changes: 17 additions & 0 deletions python/ppml/src/bigdl/ppml/fl/psi/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# TODO: add security utils here
15 changes: 15 additions & 0 deletions python/ppml/test/bigdl/ppml/fl/psi/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
53 changes: 53 additions & 0 deletions python/ppml/test/bigdl/ppml/fl/psi/test_psi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import unittest


from bigdl.ppml.fl.psi.psi import PSI
from bigdl.ppml.fl.nn.fl_server import FLServer
from bigdl.ppml.fl.nn.fl_context import init_fl_context
from bigdl.ppml.fl.utils import FLTest

class TestPSI(FLTest):
def setUp(self) -> None:
self.fl_server = FLServer(1)
self.fl_server.set_port(self.port)
self.fl_server.build()
self.fl_server.start()

def tearDown(self) -> None:
self.fl_server.stop()


def test_psi_get_salt(self):
init_fl_context("1", self.target)
psi = PSI()
salt = psi.get_salt()
assert (isinstance(salt, str))

def test_psi_pipeline(self):
init_fl_context("1", self.target)
psi = PSI()
salt = psi.get_salt()
key = ["k1", "k2"]
psi.upload_set(key, salt)
intersection = psi.download_intersection()
assert (isinstance(intersection, list))
self.assertEqual(len(intersection), 2)

if __name__ == '__main__':
unittest.main()
15 changes: 15 additions & 0 deletions python/ppml/test/bigdl/ppml/fl/py4j/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

0 comments on commit 7297b08

Please sign in to comment.