diff --git a/python/ppml/src/bigdl/ppml/fl/nn/fl_client.py b/python/ppml/src/bigdl/ppml/fl/nn/fl_client.py index 5061260c5ca..24ffd387d2f 100644 --- a/python/ppml/src/bigdl/ppml/fl/nn/fl_client.py +++ b/python/ppml/src/bigdl/ppml/fl/nn/fl_client.py @@ -18,13 +18,11 @@ import logging import pickle import grpc -from numpy import ndarray from bigdl.ppml.fl.nn.generated.nn_service_pb2 import TrainRequest, PredictRequest, UploadMetaRequest from bigdl.ppml.fl.nn.generated.nn_service_pb2_grpc import * from bigdl.ppml.fl.nn.utils import ndarray_map_to_tensor_map import yaml import threading -from torch.utils.data import DataLoader from bigdl.dllib.utils.log4Error import invalidInputError from bigdl.ppml.fl.nn.utils import ClassAndArgsWrapper @@ -32,6 +30,29 @@ class FLClient(object): channel = None _lock = threading.Lock() + client_id = None + target = "localhost:8980" + secure = False + creds = None + + @staticmethod + def set_client_id(client_id): + FLClient.client_id = client_id + + @staticmethod + def set_target(target): + FLClient.target = target + + @staticmethod + def ensure_initialized(): + with FLClient._lock: + if FLClient.channel == None: + if FLClient.secure: + FLClient.channel = grpc.secure_channel(FLClient.target, FLClient.creds) + else: + FLClient.channel = grpc.insecure_channel(FLClient.target) + + def __init__(self, client_id, aggregator, target="localhost:8980") -> None: self.secure = False self.load_config() @@ -91,3 +112,4 @@ def load_config(self): logging.warn('Loading config failed, using default config ') except Exception as e: logging.warn('Failed to find config file "ppml-conf.yaml", using default config') + diff --git a/python/ppml/src/bigdl/ppml/fl/nn/fl_context.py b/python/ppml/src/bigdl/ppml/fl/nn/fl_context.py new file mode 100644 index 00000000000..815d297cb4b --- /dev/null +++ b/python/ppml/src/bigdl/ppml/fl/nn/fl_context.py @@ -0,0 +1,22 @@ +# +# Copyright 2016 The BigDL Authors. 
def init_fl_context(client_id, target="localhost:8980"):
    """Configure the process-wide FL client state and open its gRPC channel.

    Stores ``target`` and ``client_id`` on the ``FLClient`` class and then asks
    ``FLClient`` to make sure the shared channel exists.

    :param client_id: identifier this process reports to the FL server.
    :param target: ``host:port`` address of the FL server.

    NOTE(review): ``FLClient.ensure_initialized`` only creates a channel when
    none exists yet, so calling this again with a different ``target`` appears
    to keep the channel that was opened first -- confirm whether that is
    intended.
    """
    # The two setters are independent of each other; both must run before the
    # channel is created because the channel is built from FLClient.target.
    FLClient.set_target(target)
    FLClient.set_client_id(client_id)
    FLClient.ensure_initialized()
class PSI(object):
    """Client-side wrapper around the PSI (private set intersection) gRPC service.

    Every call goes through a ``PSIServiceStub`` bound to the channel shared
    via ``FLClient``.
    """

    def __init__(self) -> None:
        # FLClient.channel stays None until someone initialises it; make sure
        # it exists (a no-op when it already does) so the stub is never built
        # on top of None.
        FLClient.ensure_initialized()
        self.stub = PSIServiceStub(FLClient.channel)

    def get_salt(self, secure_code=""):
        """Ask the server for the salt.

        :param secure_code: code sent to the server in the ``SaltRequest``.
        :return: the ``salt_reply`` field of the server's reply.
        """
        return self.stub.getSalt(SaltRequest(secure_code=secure_code)).salt_reply

    def upload_set(self, ids, salt=""):
        """Upload this client's id set to the server.

        :param ids: ids to upload; sent as the ``hashedID`` field.
        :param salt: currently unused -- the ids are sent as given.
        :return: the server's reply to ``uploadSet``.
        """
        # TODO: add hashing
        return self.stub.uploadSet(
            UploadSetRequest(client_id=FLClient.client_id, hashedID=ids))

    def download_intersection(self, max_try=100, retry=3):
        """Fetch the intersection from the server.

        :param max_try: maximum number of download attempts.
        :param retry: currently unused; kept for interface compatibility.
        :return: the intersection as a ``list``.
        """
        for _ in range(max_try):
            intersection = self.stub.downloadIntersection(
                DownloadIntersectionRequest()).intersection
            # NOTE(review): a protobuf repeated field is never None, so this
            # branch is taken on the first reply even if the server has not
            # yet received every client's set. The reply carries no readiness
            # flag here, so "not ready" and "empty intersection" cannot be
            # told apart on the client side -- confirm the intended protocol.
            if intersection is not None:
                intersection = list(intersection)
                logging.info(f"Intersection completed, size {len(intersection)}")
                return intersection
        invalidOperationError(False, "Max retry reached, could not get intersection, exiting.")

    def get_intersection(self, ids, secure_code="", max_try=100, retry=3):
        """Run the whole pipeline: get salt, upload ``ids``, download the result.

        :param ids: ids of this client.
        :param secure_code: forwarded to :meth:`get_salt`.
        :param max_try: forwarded to :meth:`download_intersection`.
        :param retry: forwarded to :meth:`download_intersection`.
        :return: the intersection as a ``list``.
        """
        # Reuse get_salt instead of repeating the RPC inline.
        salt = self.get_salt(secure_code)
        self.upload_set(ids, salt)
        return self.download_intersection(max_try, retry)
class PsiIntersection(object):
    """Collects id sets and computes their intersection once all have arrived.

    Collections are appended through :meth:`add_collection`; when
    ``max_collection`` of them have been received the intersection of all of
    them is stored in ``self.intersection``.
    """

    def __init__(self, max_collection=1) -> None:
        """
        :param max_collection: number of collections to wait for before the
            intersection is computed.
        """
        self.intersection = []
        self._thread_intersection = []

        self.max_collection = max_collection
        self.condition = threading.Condition()
        # Guards self.collection and self.intersection.
        self._lock = threading.Lock()

        self.collection = []

    @staticmethod
    def _intersect_all(collections):
        """Return the elements common to every iterable in ``collections``.

        The result is always a fresh, de-duplicated ``list`` (order is not
        defined); an empty ``collections`` gives ``[]``.
        """
        if not collections:
            return []
        common = set(collections[0])
        for other in collections[1:]:
            common.intersection_update(other)
        return list(common)

    def find_intersection(self, a, b):
        """Return the elements present in both ``a`` and ``b`` as a list."""
        return self._intersect_all([a, b])

    def add_collection(self, collection):
        """Register one collection; compute the intersection on the last one.

        Fails through ``invalidOperationError`` when ``max_collection``
        collections have already been received.

        :param collection: iterable of hashable ids.
        """
        with self._lock:
            invalidOperationError(len(self.collection) < self.max_collection,
                                  f"PSI collection is full, got: {len(self.collection)}/{self.max_collection}")
            self.collection.append(collection)
            if len(self.collection) == self.max_collection:
                # Always go through the set-based helper so that the stored
                # result is a plain de-duplicated list, also when there is
                # only a single collection (previously the uploaded object
                # itself was stored in that case).
                self.intersection = self._intersect_all(self.collection)

    def get_intersection(self):
        """Return the stored intersection (``[]`` until it has been computed)."""
        with self._lock:
            return self.intersection
class PSIServiceImpl(PSIServiceServicer):
    """gRPC servicer for PSI: hands out a salt, collects id sets and serves
    the resulting intersection."""

    def __init__(self, conf) -> None:
        """
        :param conf: server configuration; ``conf['clientNum']`` is the number
            of collections to wait for before the intersection is computed.
        """
        self.client_salt = None
        self.client_secret = None
        # 0 means "no shuffle seed chosen yet".
        self.client_shuffle_seed = 0
        self.psi_intersection = PsiIntersection(conf['clientNum'])

    def getSalt(self, request, context):
        """Return the salt, creating it on the first call.

        The first ``secure_code`` received is remembered; a later request
        carrying a different ``secure_code`` gets an empty salt.
        """
        if self.client_salt is None:
            self.client_salt = str(uuid4())
        salt = self.client_salt

        if self.client_secret is None:
            self.client_secret = request.secure_code
        elif self.client_secret != request.secure_code:
            salt = ""

        if self.client_shuffle_seed == 0:
            # randint is inclusive on both ends; start at 1 so the chosen seed
            # can never equal the 0 "unset" sentinel and get re-drawn later.
            self.client_shuffle_seed = randint(1, 100)
        return SaltReply(salt_reply=salt)

    def uploadSet(self, request, context):
        """Store the ids uploaded by one client and acknowledge with status 1."""
        self.psi_intersection.add_collection(request.hashedID)
        logging.info(f"{len(self.psi_intersection.collection)}-th collection added")
        return UploadSetResponse(status=1)

    def downloadIntersection(self, request, context):
        """Return whatever intersection is currently stored."""
        intersection = self.psi_intersection.get_intersection()
        return DownloadIntersectionResponse(intersection=intersection)
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# TODO: add security utils here diff --git a/python/ppml/test/bigdl/ppml/fl/psi/__init__.py b/python/ppml/test/bigdl/ppml/fl/psi/__init__.py new file mode 100644 index 00000000000..2151a805423 --- /dev/null +++ b/python/ppml/test/bigdl/ppml/fl/psi/__init__.py @@ -0,0 +1,15 @@ +# +# Copyright 2016 The BigDL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/python/ppml/test/bigdl/ppml/fl/psi/test_psi.py b/python/ppml/test/bigdl/ppml/fl/psi/test_psi.py new file mode 100644 index 00000000000..d31cce4eb7d --- /dev/null +++ b/python/ppml/test/bigdl/ppml/fl/psi/test_psi.py @@ -0,0 +1,53 @@ +# +# Copyright 2016 The BigDL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
class TestPSI(FLTest):
    """End-to-end PSI tests against an in-process ``FLServer`` with one client."""

    def setUp(self) -> None:
        # Start a fresh server for every test.
        self.fl_server = FLServer(1)
        self.fl_server.set_port(self.port)
        self.fl_server.build()
        self.fl_server.start()

    def tearDown(self) -> None:
        self.fl_server.stop()

    def test_psi_get_salt(self):
        """The server hands out a string salt."""
        init_fl_context("1", self.target)
        psi = PSI()
        salt = psi.get_salt()
        # unittest assertions are not stripped under ``python -O`` and give
        # a readable failure message.
        self.assertIsInstance(salt, str)

    def test_psi_pipeline(self):
        """With a single client the intersection is the uploaded set itself."""
        init_fl_context("1", self.target)
        psi = PSI()
        salt = psi.get_salt()
        key = ["k1", "k2"]
        psi.upload_set(key, salt)
        intersection = psi.download_intersection()
        self.assertIsInstance(intersection, list)
        self.assertEqual(len(intersection), 2)
        # Check the contents as well as the size; order is not significant.
        self.assertCountEqual(intersection, key)


if __name__ == '__main__':
    unittest.main()
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/python/ppml/test/bigdl/ppml/fl/algorithms/test_fgboost_regression.py b/python/ppml/test/bigdl/ppml/fl/py4j/test_fgboost_regression.py similarity index 100% rename from python/ppml/test/bigdl/ppml/fl/algorithms/test_fgboost_regression.py rename to python/ppml/test/bigdl/ppml/fl/py4j/test_fgboost_regression.py diff --git a/python/ppml/test/bigdl/ppml/fl/algorithms/test_psi.py b/python/ppml/test/bigdl/ppml/fl/py4j/test_psi.py similarity index 100% rename from python/ppml/test/bigdl/ppml/fl/algorithms/test_psi.py rename to python/ppml/test/bigdl/ppml/fl/py4j/test_psi.py