-
Notifications
You must be signed in to change notification settings - Fork 1.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
c5be04f
commit 7297b08
Showing
13 changed files
with
313 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
# | ||
# Copyright 2016 The BigDL Authors. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
|
||
from ..nn.fl_client import FLClient | ||
|
||
def init_fl_context(client_id, target="localhost:8980"): | ||
FLClient.set_client_id(client_id) | ||
FLClient.set_target(target) | ||
FLClient.ensure_initialized() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
# | ||
# Copyright 2016 The BigDL Authors. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
|
||
import logging | ||
|
||
from bigdl.dllib.utils.log4Error import invalidOperationError | ||
from ..nn.fl_client import FLClient | ||
from bigdl.ppml.fl.nn.generated.psi_service_pb2_grpc import * | ||
from ..nn.generated.psi_service_pb2 import DownloadIntersectionRequest, SaltRequest, UploadSetRequest | ||
|
||
class PSI(object): | ||
def __init__(self) -> None: | ||
self.stub = PSIServiceStub(FLClient.channel) | ||
|
||
def get_salt(self, secure_code=""): | ||
return self.stub.getSalt(SaltRequest(secure_code=secure_code)).salt_reply | ||
|
||
def upload_set(self, ids, salt=""): | ||
# TODO: add hashing | ||
return self.stub.uploadSet( | ||
UploadSetRequest(client_id=FLClient.client_id, hashedID=ids)) | ||
|
||
def download_intersection(self, max_try=100, retry=3): | ||
for i in range(max_try): | ||
intersection = self.stub.downloadIntersection( | ||
DownloadIntersectionRequest()).intersection | ||
if intersection is not None: | ||
intersection = list(intersection) | ||
logging.info(f"Intersection completed, size {len(intersection)}") | ||
return intersection | ||
invalidOperationError(False, "Max retry reached, could not get intersection, exiting.") | ||
|
||
def get_intersection(self, ids, secure_code="", max_try=100, retry=3): | ||
salt = self.stub.getSalt(SaltRequest(secure_code=secure_code)).salt_reply | ||
self.upload_set(ids, salt) | ||
return self.download_intersection(max_try, retry) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
# | ||
# Copyright 2016 The BigDL Authors. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
|
||
|
||
import logging | ||
import threading | ||
|
||
from bigdl.ppml.utils.log4Error import invalidOperationError | ||
|
||
|
||
class PsiIntersection(object): | ||
def __init__(self, max_collection=1) -> None: | ||
self.intersection = [] | ||
self._thread_intersection = [] | ||
|
||
self.max_collection = max_collection | ||
self.condition = threading.Condition() | ||
self._lock = threading.Lock() | ||
|
||
self.collection = [] | ||
|
||
def find_intersection(self, a, b): | ||
return list(set(a) & set(b)) | ||
|
||
def add_collection(self, collection): | ||
with self._lock: | ||
invalidOperationError(len(self.collection) < self.max_collection, | ||
f"PSI collection is full, got: {len(self.collection)}/{self.max_collection}") | ||
self.collection.append(collection) | ||
if len(self.collection) == self.max_collection: | ||
current_intersection = self.collection[0] | ||
for i in range(1, len(self.collection)): | ||
current_intersection = \ | ||
self.find_intersection(current_intersection, self.collection[i]) | ||
self.intersection = current_intersection | ||
|
||
def get_intersection(self): | ||
with self._lock: | ||
return self.intersection |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
# | ||
# Copyright 2016 The BigDL Authors. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
|
||
import logging | ||
from random import randint | ||
from uuid import uuid4 | ||
from bigdl.ppml.fl.psi.psi_intersection import PsiIntersection | ||
from bigdl.ppml.fl.nn.generated.psi_service_pb2_grpc import * | ||
from bigdl.ppml.fl.nn.generated.psi_service_pb2 import * | ||
|
||
|
||
|
||
class PSIServiceImpl(PSIServiceServicer): | ||
def __init__(self, conf) -> None: | ||
self.client_salt = None | ||
self.client_secret = None | ||
self.client_shuffle_seed = 0 | ||
# self.psi_collections = {} | ||
self.psi_intersection = PsiIntersection(conf['clientNum']) | ||
|
||
def getSalt(self, request, context): | ||
if self.client_salt is not None: | ||
salt = self.client_salt | ||
else: | ||
salt = str(uuid4()) | ||
self.client_salt = salt | ||
|
||
if self.client_secret is None: | ||
self.client_secret = request.secure_code | ||
elif self.client_secret != request.secure_code: | ||
salt = "" | ||
|
||
if self.client_shuffle_seed == 0: | ||
self.client_shuffle_seed = randint(0, 100) | ||
return SaltReply(salt_reply=salt) | ||
|
||
def uploadSet(self, request, context): | ||
client_id = request.client_id | ||
ids = request.hashedID | ||
self.psi_intersection.add_collection(ids) | ||
logging.info(f"{len(self.psi_intersection.collection)}-th collection added") | ||
return UploadSetResponse(status=1) | ||
|
||
|
||
def downloadIntersection(self, request, context): | ||
intersection = self.psi_intersection.get_intersection() | ||
return DownloadIntersectionResponse(intersection=intersection) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
# | ||
# Copyright 2016 The BigDL Authors. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
|
||
# TODO: add security utils here |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
# | ||
# Copyright 2016 The BigDL Authors. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
# | ||
# Copyright 2016 The BigDL Authors. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
|
||
import unittest | ||
|
||
|
||
from bigdl.ppml.fl.psi.psi import PSI | ||
from bigdl.ppml.fl.nn.fl_server import FLServer | ||
from bigdl.ppml.fl.nn.fl_context import init_fl_context | ||
from bigdl.ppml.fl.utils import FLTest | ||
|
||
class TestPSI(FLTest): | ||
def setUp(self) -> None: | ||
self.fl_server = FLServer(1) | ||
self.fl_server.set_port(self.port) | ||
self.fl_server.build() | ||
self.fl_server.start() | ||
|
||
def tearDown(self) -> None: | ||
self.fl_server.stop() | ||
|
||
|
||
def test_psi_get_salt(self): | ||
init_fl_context("1", self.target) | ||
psi = PSI() | ||
salt = psi.get_salt() | ||
assert (isinstance(salt, str)) | ||
|
||
def test_psi_pipeline(self): | ||
init_fl_context("1", self.target) | ||
psi = PSI() | ||
salt = psi.get_salt() | ||
key = ["k1", "k2"] | ||
psi.upload_set(key, salt) | ||
intersection = psi.download_intersection() | ||
assert (isinstance(intersection, list)) | ||
self.assertEqual(len(intersection), 2) | ||
|
||
if __name__ == '__main__': | ||
unittest.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
# | ||
# Copyright 2016 The BigDL Authors. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# |
File renamed without changes.
File renamed without changes.