From 9b23be2e95fec756066ca0ed3188c3db2602b757 Mon Sep 17 00:00:00 2001 From: schintap Date: Fri, 30 Nov 2018 12:48:56 -0600 Subject: [PATCH] [SPARK-26201] Fix python broadcast with encryption MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? Python with rpc and disk encryption enabled along with a python broadcast variable and just read the value back on the driver side the job failed with: Traceback (most recent call last): File "broadcast.py", line 37, in words_new.value File "/pyspark.zip/pyspark/broadcast.py", line 137, in value File "pyspark.zip/pyspark/broadcast.py", line 122, in load_from_path File "pyspark.zip/pyspark/broadcast.py", line 128, in load EOFError: Ran out of input To reproduce use configs: --conf spark.network.crypto.enabled=true --conf spark.io.encryption.enabled=true Code: words_new = sc.broadcast(["scala", "java", "hadoop", "spark", "akka"]) words_new.value print(words_new.value) ## How was this patch tested? words_new = sc.broadcast([“scala”, “java”, “hadoop”, “spark”, “akka”]) textFile = sc.textFile(“README.md”) wordCounts = textFile.flatMap(lambda line: line.split()).map(lambda word: (word + words_new.value[1], 1)).reduceByKey(lambda a, b: a+b) count = wordCounts.count() print(count) words_new.value print(words_new.value) Closes #23166 from redsanket/SPARK-26201. Authored-by: schintap Signed-off-by: Thomas Graves --- .../apache/spark/api/python/PythonRDD.scala | 29 ++++++++++++++++--- python/pyspark/broadcast.py | 21 ++++++++++---- python/pyspark/tests/test_broadcast.py | 15 ++++++++++ 3 files changed, 56 insertions(+), 9 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 8b5a7a9aefea5..5ed5070558af7 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -660,6 +660,7 @@ private[spark] class PythonBroadcast(@transient var path: String) extends Serial with Logging { private var encryptionServer: PythonServer[Unit] = null + private var decryptionServer: PythonServer[Unit] = null /** * Read data from disks, then copy it to `out` @@ -708,16 +709,36 @@ private[spark] class PythonBroadcast(@transient var path: String) extends Serial override def handleConnection(sock: Socket): Unit = { val env = SparkEnv.get val in = sock.getInputStream() - val dir = new File(Utils.getLocalDir(env.conf)) - val file = File.createTempFile("broadcast", "", dir) - path = file.getAbsolutePath - val out = env.serializerManager.wrapForEncryption(new FileOutputStream(path)) + val abspath = new File(path).getAbsolutePath + val out = env.serializerManager.wrapForEncryption(new FileOutputStream(abspath)) DechunkedInputStream.dechunkAndCopyToOutput(in, out) } } Array(encryptionServer.port, encryptionServer.secret) } + def setupDecryptionServer(): Array[Any] = { + decryptionServer = new PythonServer[Unit]("broadcast-decrypt-server-for-driver") { + override def handleConnection(sock: Socket): Unit = { + val out = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream())) + Utils.tryWithSafeFinally { + val in = SparkEnv.get.serializerManager.wrapForEncryption(new FileInputStream(path)) + Utils.tryWithSafeFinally { + Utils.copyStream(in, out, false) + } { + in.close() + } + out.flush() + } { + JavaUtils.closeQuietly(out) + } + } + } + Array(decryptionServer.port, decryptionServer.secret) + } + + def waitTillBroadcastDataSent(): Unit = decryptionServer.getResult() + def waitTillDataReceived(): Unit = encryptionServer.getResult() } // scalastyle:on no.finalize diff --git a/python/pyspark/broadcast.py b/python/pyspark/broadcast.py index 1c7f2a7418df0..29358b5740e51 100644 --- a/python/pyspark/broadcast.py +++ b/python/pyspark/broadcast.py @@ -77,11 +77,12 @@ def __init__(self, sc=None, value=None, pickle_registry=None, path=None, # we're on the driver. We want the pickled data to end up in a file (maybe encrypted) f = NamedTemporaryFile(delete=False, dir=sc._temp_dir) self._path = f.name - python_broadcast = sc._jvm.PythonRDD.setupBroadcast(self._path) + self._sc = sc + self._python_broadcast = sc._jvm.PythonRDD.setupBroadcast(self._path) if sc._encryption_enabled: # with encryption, we ask the jvm to do the encryption for us, we send it data # over a socket - port, auth_secret = python_broadcast.setupEncryptionServer() + port, auth_secret = self._python_broadcast.setupEncryptionServer() (encryption_sock_file, _) = local_connect_and_auth(port, auth_secret) broadcast_out = ChunkedStream(encryption_sock_file, 8192) else: @@ -89,12 +90,14 @@ def __init__(self, sc=None, value=None, pickle_registry=None, path=None, broadcast_out = f self.dump(value, broadcast_out) if sc._encryption_enabled: - python_broadcast.waitTillDataReceived() - self._jbroadcast = sc._jsc.broadcast(python_broadcast) + self._python_broadcast.waitTillDataReceived() + self._jbroadcast = sc._jsc.broadcast(self._python_broadcast) self._pickle_registry = pickle_registry else: # we're on an executor self._jbroadcast = None + self._sc = None + self._python_broadcast = None if sock_file is not None: # the jvm is doing decryption for us. Read the value # immediately from the sock_file @@ -134,7 +137,15 @@ def value(self): """ Return the broadcasted value """ if not hasattr(self, "_value") and self._path is not None: - self._value = self.load_from_path(self._path) + # we only need to decrypt it here when encryption is enabled and + # if its on the driver, since executor decryption is handled already + if self._sc is not None and self._sc._encryption_enabled: + port, auth_secret = self._python_broadcast.setupDecryptionServer() + (decrypted_sock_file, _) = local_connect_and_auth(port, auth_secret) + self._python_broadcast.waitTillBroadcastDataSent() + return self.load(decrypted_sock_file) + else: + self._value = self.load_from_path(self._path) return self._value def unpersist(self, blocking=False): diff --git a/python/pyspark/tests/test_broadcast.py b/python/pyspark/tests/test_broadcast.py index a98626e8f4bc9..11d31d24bb011 100644 --- a/python/pyspark/tests/test_broadcast.py +++ b/python/pyspark/tests/test_broadcast.py @@ -67,6 +67,21 @@ def test_broadcast_with_encryption(self): def test_broadcast_no_encryption(self): self._test_multiple_broadcasts() + def _test_broadcast_on_driver(self, *extra_confs): + conf = SparkConf() + for key, value in extra_confs: + conf.set(key, value) + conf.setMaster("local-cluster[2,1,1024]") + self.sc = SparkContext(conf=conf) + bs = self.sc.broadcast(value=5) + self.assertEqual(5, bs.value) + + def test_broadcast_value_driver_no_encryption(self): + self._test_broadcast_on_driver() + + def test_broadcast_value_driver_encryption(self): + self._test_broadcast_on_driver(("spark.io.encryption.enabled", "true")) + class BroadcastFrameProtocolTest(unittest.TestCase):