# Reconstructed from the whitespace-collapsed patch
#   commit f293e443df212318f167faa587add7efbbe6a094
#   "ppml python psi hashing (#5531)", Song Jiaming, 2022-08-25
# which touches
#   python/ppml/src/bigdl/ppml/fl/psi/utils.py  (to_hex_string)
#   python/ppml/src/bigdl/ppml/fl/psi/psi.py    (class PSI)
# Both post-patch definitions are shown here together; in the repository
# psi.py obtains the helper via
#   from bigdl.ppml.fl.psi.utils import to_hex_string

import hashlib
import logging


def to_hex_string(ids, salt, padding_size=32):
    """Return the salted SHA-384 hex digest of every id, in input order.

    Each id is hashed on its own as ``sha384(id + salt)``, so the digest of
    an id depends only on that id and the salt, never on which other ids
    are in the list or on their order.

    Args:
        ids: iterable of ``str`` identifiers.
        salt: ``str`` appended to every id before hashing.
        padding_size: minimum length of each returned string; shorter
            digests are left-padded with ``'0'``. A SHA-384 hex digest is
            96 characters long, so the default of 32 never pads.

    Returns:
        A list of lowercase hex strings, one per input id.
    """
    # Encode the salt once; it is identical for every id.
    salt_bytes = salt.encode('utf-8')
    hex_strings = []
    for raw_id in ids:
        # A fresh hash object per id: hash.update() accumulates, so reusing
        # one object would fold all previous ids into this id's digest.
        digest = hashlib.sha384(raw_id.encode('utf-8') + salt_bytes).hexdigest()
        # str is immutable (no .insert); rjust performs the left padding.
        hex_strings.append(digest.rjust(padding_size, '0'))
    return hex_strings


class PSI(object):
    """Client for the PSI gRPC service: salt retrieval, upload of hashed ids,
    and download of the intersection mapped back to the original ids.

    Project imports are done at function scope so that importing the hashing
    helper above does not require the gRPC stack to be importable.
    """

    def __init__(self) -> None:
        from bigdl.ppml.fl.nn.fl_client import FLClient
        from bigdl.ppml.fl.nn.generated.psi_service_pb2_grpc import PSIServiceStub

        self.stub = PSIServiceStub(FLClient.channel)
        # Maps each uploaded hash back to the id it was computed from;
        # filled by upload_set and read by download_intersection.
        self.hashed_ids_to_ids = {}

    def get_salt(self, secure_code=""):
        """Request a salt from the service and return its ``salt_reply`` field."""
        from bigdl.ppml.fl.nn.generated.psi_service_pb2 import SaltRequest

        return self.stub.getSalt(SaltRequest(secure_code=secure_code)).salt_reply

    def upload_set(self, ids, salt=""):
        """Hash ``ids`` with ``salt`` and upload the hashes.

        The hash -> id mapping is stored on the instance (replacing any
        previous one) so the intersection can later be translated back.

        Returns:
            The response of the ``uploadSet`` RPC.
        """
        from bigdl.ppml.fl.nn.fl_client import FLClient
        from bigdl.ppml.fl.nn.generated.psi_service_pb2 import UploadSetRequest

        # Materialize once: ids is consumed twice (hashing, then zip), which
        # would silently produce an empty mapping for a one-shot iterable.
        ids = list(ids)
        hashed_ids = to_hex_string(ids, salt)
        self.hashed_ids_to_ids = dict(zip(hashed_ids, ids))

        return self.stub.uploadSet(
            UploadSetRequest(client_id=FLClient.client_id, hashedID=hashed_ids))

    def download_intersection(self, max_try=100, retry=3):
        """Download the intersection and return it as a list of original ids.

        Args:
            max_try: maximum number of ``downloadIntersection`` calls.
            retry: accepted for interface compatibility; not used here.

        Returns:
            The ids (as passed to ``upload_set``) whose hashes are in the
            intersection returned by the service.

        Raises:
            KeyError: if the service returns a hash that was not uploaded
                through this instance.
        """
        from bigdl.dllib.utils.log4Error import invalidOperationError
        from bigdl.ppml.fl.nn.generated.psi_service_pb2 import DownloadIntersectionRequest

        for _ in range(max_try):
            intersection = self.stub.downloadIntersection(
                DownloadIntersectionRequest()).intersection
            if intersection is not None:
                hashed_intersection = list(intersection)
                logging.info(f"Intersection completed, size {len(hashed_intersection)}")
                return [self.hashed_ids_to_ids[hashed] for hashed in hashed_intersection]
        invalidOperationError(False, "Max retry reached, could not get intersection, exiting.")
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# Reconstructed from the patch hunk that creates
#   python/ppml/test/bigdl/ppml/fl/psi/test_hashing.py

import unittest
from uuid import uuid4

from bigdl.ppml.fl.psi.utils import to_hex_string
from bigdl.ppml.fl.utils import FLTest


class TestHashing(FLTest):
    """Checks the shape and determinism of the salted id hashes."""

    def test_hashing(self):
        ids = ['1', '2', '4', '5']
        salt = str(uuid4())
        hex_string = to_hex_string(ids, salt)

        # One hash per id, each a full SHA-384 hex digest (96 hex chars).
        self.assertEqual(len(hex_string), len(ids))
        for hashed in hex_string:
            self.assertRegex(hashed, r'^[0-9a-f]{96}$')
        # Distinct ids must not collide, and the same input must hash the
        # same way on a second call.
        self.assertEqual(len(set(hex_string)), len(ids))
        self.assertEqual(to_hex_string(ids, salt), hex_string)


# NOTE: the same patch also edits
#   python/ppml/test/bigdl/ppml/fl/psi/test_psi.py (test_psi_pipeline),
# replacing
#     self.assertEqual(len(intersection), 2)
# with
#     self.assertEqual(key, intersection)
# The rest of that test is not part of this excerpt.


if __name__ == '__main__':
    unittest.main()