From c364d88ecbbbec29de6be0d089448b90c80b1eac Mon Sep 17 00:00:00 2001 From: Xin Qiu Date: Mon, 30 May 2022 21:10:29 +0800 Subject: [PATCH] Refine ppml crypto (#4719) * update * update * update scala * update API * rename * fix typo * fix ut * fix style --- scala/ppml/pom.xml | 3 - .../analytics/bigdl/ppml/PPMLContext.scala | 45 ++-- .../bigdl/ppml/crypto/BigDLEncrypt.scala | 210 +++++++++++++++ .../analytics/bigdl/ppml/crypto/Crypto.scala | 117 +++++++-- .../bigdl/ppml/crypto/FernetEncrypt.scala | 239 ------------------ .../dataframe/EncryptedDataFrameReader.scala | 9 +- .../dataframe/EncryptedDataFrameWriter.scala | 29 +++ .../bigdl/ppml/utils/EncryptIOArguments.scala | 7 +- .../bigdl/ppml/crypto/EncryptSpec.scala | 89 +++++++ .../dataframe/EncryptDataFrameSpec.scala | 23 +- 10 files changed, 473 insertions(+), 298 deletions(-) create mode 100644 scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/crypto/BigDLEncrypt.scala delete mode 100644 scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/crypto/FernetEncrypt.scala create mode 100644 scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/crypto/dataframe/EncryptedDataFrameWriter.scala create mode 100644 scala/ppml/src/test/scala/com/intel/analytics/bigdl/ppml/crypto/EncryptSpec.scala diff --git a/scala/ppml/pom.xml b/scala/ppml/pom.xml index f536533497e..33dc01cc901 100644 --- a/scala/ppml/pom.xml +++ b/scala/ppml/pom.xml @@ -12,13 +12,10 @@ jar - provided 1.37.0 ${project.parent.basedir} - - org.apache.spark diff --git a/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/PPMLContext.scala b/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/PPMLContext.scala index cbca5b6b5bd..17e83ea32a2 100644 --- a/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/PPMLContext.scala +++ b/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/PPMLContext.scala @@ -18,8 +18,7 @@ package com.intel.analytics.bigdl.ppml import com.intel.analytics.bigdl.dllib.NNContext.{checkScalaVersion, 
checkSparkVersion, createSparkConf, initConf, initNNContext} import com.intel.analytics.bigdl.dllib.utils.Log4Error -import com.intel.analytics.bigdl.ppml.crypto.CryptoMode.CryptoMode -import com.intel.analytics.bigdl.ppml.crypto.{CryptoMode, EncryptRuntimeException, FernetEncrypt} +import com.intel.analytics.bigdl.ppml.crypto.{AES_CBC_PKCS5PADDING, Crypto, CryptoMode, DECRYPT, ENCRYPT, EncryptRuntimeException, BigDLEncrypt, PLAIN_TEXT} import com.intel.analytics.bigdl.ppml.utils.Supportive import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.input.PortableDataStream @@ -62,14 +61,13 @@ class PPMLContext protected(kms: KeyManagementService, sparkSession: SparkSessio */ def textFile(path: String, minPartitions: Int = sparkSession.sparkContext.defaultMinPartitions, - cryptoMode: CryptoMode = CryptoMode.PLAIN_TEXT): RDD[String] = { + cryptoMode: CryptoMode = PLAIN_TEXT): RDD[String] = { cryptoMode match { - case CryptoMode.PLAIN_TEXT => + case PLAIN_TEXT => sparkSession.sparkContext.textFile(path, minPartitions) - case CryptoMode.AES_CBC_PKCS5PADDING => - PPMLContext.textFile(sparkSession.sparkContext, path, dataKeyPlainText, minPartitions) case _ => - throw new IllegalArgumentException("unknown EncryptMode " + cryptoMode.toString) + PPMLContext.textFile(sparkSession.sparkContext, path, dataKeyPlainText, + cryptoMode, minPartitions) } } @@ -90,10 +88,10 @@ class PPMLContext protected(kms: KeyManagementService, sparkSession: SparkSessio */ def write(dataFrame: DataFrame, cryptoMode: CryptoMode): DataFrameWriter[Row] = { cryptoMode match { - case CryptoMode.PLAIN_TEXT => + case PLAIN_TEXT => dataFrame.write - case CryptoMode.AES_CBC_PKCS5PADDING => - PPMLContext.write(sparkSession, dataKeyPlainText, dataFrame) + case AES_CBC_PKCS5PADDING => + PPMLContext.write(sparkSession, cryptoMode, dataKeyPlainText, dataFrame) case _ => throw new IllegalArgumentException("unknown EncryptMode " + cryptoMode.toString) } @@ -101,12 +99,15 @@ class PPMLContext 
protected(kms: KeyManagementService, sparkSession: SparkSessio
 }
 
 object PPMLContext{
-  private[bigdl] def registerUDF(spark: SparkSession,
-                                 dataKeyPlaintext: String) = {
+  /**
+   * Registers the Spark SQL UDF `convertUDF`, which encrypts a single string value with
+   * the given crypto mode and data key.
+   *
+   * @param spark            session on which the UDF is registered
+   * @param cryptoMode       crypto mode used to build and initialise the [[Crypto]] instance
+   * @param dataKeyPlaintext plaintext data key; shipped to executors as a broadcast variable
+   */
+  private[bigdl] def registerUDF(
+        spark: SparkSession,
+        cryptoMode: CryptoMode,
+        dataKeyPlaintext: String) = {
     val bcKey = spark.sparkContext.broadcast(dataKeyPlaintext)
     val convertCase = (x: String) => {
-      val fernetCryptos = new FernetEncrypt()
-      new String(fernetCryptos.encryptBytes(x.getBytes, bcKey.value))
+      // A Crypto instance is created per value; read the key from the broadcast variable
+      // instead of capturing the driver-side string in the closure.
+      val crypto = Crypto(cryptoMode)
+      crypto.init(cryptoMode, ENCRYPT, bcKey.value)
+      // NOTE(review): only the ciphertext (._1) is kept, the HMAC is dropped, and raw cipher
+      // bytes are decoded with the default charset, which looks lossy -- confirm against the
+      // code that reads these columns back.
+      new String(crypto.doFinal(x.getBytes)._1)
     }
     spark.udf.register("convertUDF", convertCase)
   }
@@ -114,6 +115,7 @@ object PPMLContext{
   private[bigdl] def textFile(sc: SparkContext,
                path: String,
                dataKeyPlaintext: String,
+               cryptoMode: CryptoMode,
                minPartitions: Int = -1): RDD[String] = {
     Log4Error.invalidInputError(dataKeyPlaintext != "",
       "dataKeyPlainText should not be empty, please loadKeys first.")
@@ -122,19 +124,22 @@ object PPMLContext{
     } else {
       sc.binaryFiles(path)
     }
-    val fernetCryptos = new FernetEncrypt
     data.mapPartitions { iterator => {
       Supportive.logger.info("Decrypting bytes with JavaAESCBC...")
-      fernetCryptos.decryptBigContent(iterator, dataKeyPlaintext)
+      val crypto = Crypto(cryptoMode)
+      crypto.init(cryptoMode, DECRYPT, dataKeyPlaintext)
+      crypto.decryptBigContent(iterator)
     }}.flatMap(_.split("\n"))
   }
 
-  private[bigdl] def write(sparkSession: SparkSession,
-                           dataKeyPlaintext: String,
-                           dataFrame: DataFrame): DataFrameWriter[Row] = {
+  private[bigdl] def write(
+        sparkSession: SparkSession,
+        cryptoMode: CryptoMode,
+        dataKeyPlaintext: String,
+        dataFrame: DataFrame): DataFrameWriter[Row] = {
     val tableName = "ppml_save_table"
     dataFrame.createOrReplaceTempView(tableName)
-    PPMLContext.registerUDF(sparkSession, dataKeyPlaintext)
+    PPMLContext.registerUDF(sparkSession, cryptoMode, dataKeyPlaintext)
     // Select all and encrypt columns.
val convertSql = "select " + dataFrame.schema.map(column => "convertUDF(" + column.name + ") as " + column.name).mkString(", ") + diff --git a/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/crypto/BigDLEncrypt.scala b/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/crypto/BigDLEncrypt.scala new file mode 100644 index 00000000000..3ce60bdaebc --- /dev/null +++ b/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/crypto/BigDLEncrypt.scala @@ -0,0 +1,210 @@ +/* + * Copyright 2016 The BigDL Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package com.intel.analytics.bigdl.ppml.crypto
+
+import com.intel.analytics.bigdl.dllib.utils.{File, Log4Error}
+import com.intel.analytics.bigdl.ppml.crypto.CryptoMode
+import org.apache.hadoop.fs.Path
+
+import java.io._
+import java.security.SecureRandom
+import java.time.Instant
+import java.util.Arrays
+import javax.crypto.spec.{IvParameterSpec, SecretKeySpec}
+import javax.crypto.{Cipher, Mac}
+import org.apache.spark.input.PortableDataStream
+
+import java.nio.ByteBuffer
+import scala.util.Random
+
+/**
+ * [[Crypto]] implementation that encrypts with the cipher named by the [[CryptoMode]] and
+ * authenticates the ciphertext with the mode's signing (MAC) algorithm.
+ *
+ * Layout written by [[encryptStream]] and expected by [[decryptStream]]:
+ * a 25-byte header (1 version byte, 8-byte epoch seconds, 16-byte IV), the ciphertext,
+ * and a trailing MAC of `hmacSize` bytes. The MAC covers the ciphertext only; the header
+ * is not authenticated.
+ *
+ * An instance is stateful: call [[init]] before any other method.
+ */
+class BigDLEncrypt extends Crypto {
+  protected var cipher: Cipher = null
+  protected var mac: Mac = null
+  protected var ivParameterSpec: IvParameterSpec = null
+  protected var opMode: OperationMode = null
+  protected var initializationVector: Array[Byte] = null
+
+  /**
+   * Initialises the cipher and the MAC from the plaintext data key.
+   *
+   * Bytes 0-15 of the key are the signing key, bytes 16-31 the encryption key.
+   *
+   * @param cryptoMode       supplies the cipher, secret-key and signing algorithm names
+   * @param mode             ENCRYPT or DECRYPT
+   * @param dataKeyPlaintext plaintext data key
+   */
+  override def init(cryptoMode: CryptoMode, mode: OperationMode, dataKeyPlaintext: String): Unit = {
+    opMode = mode
+    val secret = dataKeyPlaintext.getBytes()
+    // key encrypt
+    // NOTE(review): Arrays.copyOfRange zero-pads when the key is shorter than 32 bytes, so a
+    // short key silently yields weak key material -- consider rejecting it.
+    val signingKey = Arrays.copyOfRange(secret, 0, 16)
+    val encryptKey = Arrays.copyOfRange(secret, 16, 32)
+//    initializationVector = Arrays.copyOfRange(secret, 0, 16)
+    // NOTE(review): the IV is derived deterministically from the signing key, so every
+    // message encrypted under one key shares the same IV. Left unchanged because callers
+    // that keep only the ciphertext (no header) and verifyFileHeader depend on it.
+    val r = new Random(signingKey.sum)
+    initializationVector = Array.tabulate(16)(_ => (r.nextInt(256) - 128).toByte)
+    ivParameterSpec = new IvParameterSpec(initializationVector)
+    val encryptionKeySpec = new SecretKeySpec(encryptKey, cryptoMode.secretKeyAlgorithm)
+    cipher = Cipher.getInstance(cryptoMode.encryptionAlgorithm)
+    cipher.init(mode.opmode, encryptionKeySpec, ivParameterSpec)
+    mac = Mac.getInstance(cryptoMode.signingAlgorithm)
+    val signingKeySpec = new SecretKeySpec(signingKey, cryptoMode.signingAlgorithm)
+    mac.init(signingKeySpec)
+  }
+
+  protected var signingDataStream: DataOutputStream = null
+
+  /**
+   * Builds the file header: version byte 0x80, current epoch seconds, then the IV.
+   *
+   * @return the header bytes (1 + 8 + IV length)
+   */
+  override def genFileHeader(): Array[Byte] = {
+    Log4Error.invalidOperationError(cipher != null,
+      s"you should init BigDLEncrypt first.")
+    val timestamp: Instant = Instant.now()
+    val signingByteBuffer = ByteBuffer.allocate(1 + 8 + ivParameterSpec.getIV.length)
+    val version: Byte = (0x80).toByte
+    signingByteBuffer.put(version)
+    signingByteBuffer.putLong(timestamp.getEpochSecond())
+    signingByteBuffer.put(ivParameterSpec.getIV())
+    signingByteBuffer.array()
+  }
+
+  /**
+   * Checks the version byte and that the IV stored in the header equals the IV derived
+   * from the key in [[init]]. The timestamp is read but not validated.
+   *
+   * @throws EncryptRuntimeException if the version or the IV does not match
+   */
+  override def verifyFileHeader(header: Array[Byte]): Unit = {
+    val headerBuffer = ByteBuffer.wrap(header)
+    val version: Byte = headerBuffer.get()
+    if (version.compare((0x80).toByte) != 0) {
+      throw new EncryptRuntimeException("File header version error!")
+    }
+    val timestampSeconds: Long = headerBuffer.getLong
+    val initializationVector: Array[Byte] = header.slice(1 + 8, header.length)
+    if (!initializationVector.sameElements(this.initializationVector)) {
+      throw new EncryptRuntimeException("File header not match!" +
+        "expected: " + this.initializationVector.mkString(",") +
+        ", but got: " + initializationVector.mkString(", "))
+    }
+  }
+
+  /** Processes all of `content`; see the three-argument overload. */
+  override def update(content: Array[Byte]): Array[Byte] = {
+    update(content, 0, content.length)
+  }
+
+  /**
+   * Feeds `len` bytes of `content` starting at `offset` through the cipher and the MAC.
+   *
+   * @return the cipher output for this call; empty when the cipher produced nothing yet
+   */
+  override def update(content: Array[Byte], offset: Int, len: Int): Array[Byte] = {
+    // Cipher.update returns null when the input is too short to complete a block.
+    val output = Option(cipher.update(content, offset, len)).getOrElse(Array.emptyByteArray)
+    authenticate(content, offset, len, output)
+    output
+  }
+
+  /** Finishes with all of `content`; see the three-argument overload. */
+  override def doFinal(content: Array[Byte]): (Array[Byte], Array[Byte]) = {
+    doFinal(content, 0, content.length)
+  }
+
+  /**
+   * Finishes the cipher operation with the last `len` bytes of `content` at `offset`.
+   *
+   * @return (final cipher output, MAC over all ciphertext seen since init or the last doFinal)
+   */
+  override def doFinal(content: Array[Byte], offset: Int, len: Int): (Array[Byte], Array[Byte]) = {
+    val output: Array[Byte] = cipher.doFinal(content, offset, len)
+    authenticate(content, offset, len, output)
+    (output, mac.doFinal())
+  }
+
+  /**
+   * Adds the ciphertext of one step to the MAC: the cipher output when encrypting, the
+   * cipher input when decrypting, so both directions authenticate the same bytes.
+   */
+  private def authenticate(
+        input: Array[Byte],
+        offset: Int,
+        len: Int,
+        output: Array[Byte]): Unit = {
+    if (opMode == DECRYPT) {
+      mac.update(input, offset, len)
+    } else {
+      mac.update(output)
+    }
+  }
+
+  val blockSize = 1024 * 1024 // 1m per update
+  val byteBuffer = new Array[Byte](blockSize)
+
+  /**
+   * Writes header, ciphertext and MAC for everything readable from `inputStream`.
+   * Neither stream is closed; `outputStream` is flushed.
+   */
+  override def encryptStream(inputStream: DataInputStream, outputStream: DataOutputStream): Unit = {
+    outputStream.write(genFileHeader())
+    var readLen = fill(inputStream, byteBuffer, 0, blockSize)
+    // A completely filled buffer means more input may follow; a short fill is end of stream.
+    while (readLen == blockSize) {
+      outputStream.write(update(byteBuffer, 0, readLen))
+      readLen = fill(inputStream, byteBuffer, 0, blockSize)
+    }
+    val (lastSlice, hmac) = doFinal(byteBuffer, 0, readLen)
+    outputStream.write(lastSlice)
+    outputStream.write(hmac)
+    outputStream.flush()
+  }
+
+  val hmacSize = 32
+  // 1 version byte + 8 timestamp bytes + 16 IV bytes, as written by genFileHeader.
+  private val headerSize = 25
+
+  /**
+   * Verifies the header, decrypts the payload into `outputStream` and checks the trailing
+   * MAC. Neither stream is closed; `outputStream` is flushed.
+   *
+   * @throws EncryptRuntimeException on a bad header, truncated input or MAC mismatch
+   */
+  override def decryptStream(inputStream: DataInputStream, outputStream: DataOutputStream): Unit = {
+    val header = read(inputStream, headerSize)
+    verifyFileHeader(header)
+    decryptPayload(inputStream)(bytes => outputStream.write(bytes))
+    outputStream.flush()
+  }
+
+  /**
+   * Decrypts the encrypted file at `binaryFilePath` and writes the plaintext to `savePath`.
+   */
+  override def decryptFile(binaryFilePath: String, savePath: String): Unit = {
+    Log4Error.invalidInputError(savePath != null && savePath != "",
+      "decrypted file save path should be specified")
+    transformFile(binaryFilePath, savePath)(decryptStream)
+  }
+
+  /**
+   * Encrypts the plain file at `binaryFilePath` and writes the result to `savePath`.
+   */
+  override def encryptFile(binaryFilePath: String, savePath: String): Unit = {
+    Log4Error.invalidInputError(savePath != null && savePath != "",
+      "encrypted file save path should be specified")
+    transformFile(binaryFilePath, savePath)(encryptStream)
+  }
+
+  /** Opens both paths on the source path's file system, applies `transform`, closes both. */
+  private def transformFile(binaryFilePath: String, savePath: String)(
+        transform: (DataInputStream, DataOutputStream) => Unit): Unit = {
+    val fs = File.getFileSystem(binaryFilePath)
+    val bis = fs.open(new Path(binaryFilePath))
+    try {
+      val outs = fs.create(new Path(savePath))
+      try {
+        transform(bis, outs)
+      } finally {
+        outs.close()
+      }
+    } finally {
+      bis.close()
+    }
+  }
+
+  /**
+   * Reads into `buffer` until `len` bytes are stored or the stream ends, because a single
+   * InputStream.read call may return fewer bytes than requested.
+   *
+   * @return the number of bytes stored; less than `len` only at end of stream
+   */
+  private def fill(stream: InputStream, buffer: Array[Byte], offset: Int, len: Int): Int = {
+    var total = 0
+    var n = 0
+    while (total < len && n >= 0) {
+      n = stream.read(buffer, offset + total, len - total)
+      if (n > 0) {
+        total += n
+      }
+    }
+    total
+  }
+
+  /** Reads exactly `numBytes` bytes or throws EncryptRuntimeException. */
+  private def read(stream: DataInputStream, numBytes: Int): Array[Byte] = {
+    val retval = new Array[Byte](numBytes)
+    val bytesRead: Int = fill(stream, retval, 0, numBytes)
+    if (bytesRead < numBytes) {
+      throw new EncryptRuntimeException("Not enough bits to read!")
+    }
+    retval
+  }
+
+  /**
+   * Decrypts everything after the header and passes each plaintext chunk to `sink`.
+   *
+   * The last `hmacSize` bytes of the stream are the MAC, so the window always keeps that
+   * many bytes back until end of stream is seen. The MAC is verified before the final
+   * chunk is handed to `sink`; earlier chunks have already been emitted by then.
+   */
+  private def decryptPayload(inputStream: InputStream)(sink: Array[Byte] => Unit): Unit = {
+    val window = new Array[Byte](blockSize + hmacSize)
+    var filled = fill(inputStream, window, 0, window.length)
+    while (filled == window.length) {
+      // The trailing hmacSize bytes may be the MAC; everything before them is ciphertext.
+      sink(update(window, 0, blockSize))
+      System.arraycopy(window, blockSize, window, 0, hmacSize)
+      filled = hmacSize + fill(inputStream, window, hmacSize, blockSize)
+    }
+    if (filled < hmacSize) {
+      throw new EncryptRuntimeException("Not enough bits to read!")
+    }
+    val inputHmac = window.slice(filled - hmacSize, filled)
+    val (lastSlice, streamHmac) = doFinal(window, 0, filled - hmacSize)
+    // Constant-time comparison of the stored and the recomputed MAC.
+    if (!java.security.MessageDigest.isEqual(inputHmac, streamHmac)) {
+      throw new EncryptRuntimeException("hmac not match")
+    }
+    sink(lastSlice)
+  }
+
+  /**
+   * Decrypts every stream of the iterator and returns the decrypted text split into lines
+   * on "\r" and "\n". Each file is decrypted completely and decoded once, so multi-byte
+   * characters and lines are never cut at a block boundary.
+   *
+   * @param ite (path, stream) pairs, e.g. one partition of SparkContext.binaryFiles
+   * @throws EncryptRuntimeException on a bad header, truncated input or MAC mismatch
+   */
+  override def decryptBigContent(
+        ite: Iterator[(String, PortableDataStream)]): Iterator[String] = {
+    var result: Iterator[String] = Iterator[String]()
+
+    while (ite.hasNext) {
+      val inputStream: DataInputStream = ite.next._2.open()
+      try {
+        verifyFileHeader(read(inputStream, headerSize))
+        val plain = new ByteArrayOutputStream()
+        decryptPayload(inputStream)(bytes => plain.write(bytes, 0, bytes.length))
+        val decryptString = new String(plain.toByteArray)
+        val splitDecryptStringArray = decryptString.split("\r").flatMap(_.split("\n"))
+        result = result ++ splitDecryptStringArray
+      } finally {
+        inputStream.close()
+      }
+    }
+    result
+
+  }
+
+}
diff --git a/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/crypto/Crypto.scala b/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/crypto/Crypto.scala
index 40baadb69e6..764ed6e3d84 100644
--- a/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/crypto/Crypto.scala
+++ b/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/crypto/Crypto.scala
@@ -16,25 +16,108 @@
 
 package com.intel.analytics.bigdl.ppml.crypto
 
+import com.intel.analytics.bigdl.dllib.nn.NormMode.Value
 import com.intel.analytics.bigdl.ppml.utils.Supportive
+import org.apache.spark.input.PortableDataStream
+
+import 
java.io.{DataInputStream, DataOutputStream}
+import javax.crypto.Cipher
 
+/**
+ * Stateful encrypt/decrypt engine. Call `init` first; the remaining methods then operate
+ * in the direction (ENCRYPT or DECRYPT) chosen at initialisation.
+ */
 trait Crypto extends Supportive with Serializable {
-  def encryptFile(sourceFilePath: String, saveFilePath: String, dataKeyPlaintext: String)
-  def decryptFile(sourceFilePath: String, saveFilePath: String, dataKeyPlaintext: String)
-  def encryptBytes(sourceBytes: Array[Byte], dataKeyPlaintext: String): Array[Byte]
-  def decryptBytes(sourceBytes: Array[Byte], dataKeyPlaintext: String): Array[Byte]
-}
-
-object CryptoMode extends Enumeration {
-  type CryptoMode = Value
-  val PLAIN_TEXT = value("plain_text", "plain_text")
-  val AES_CBC_PKCS5PADDING = value("AES/CBC/PKCS5Padding", "AES/CBC/PKCS5Padding")
-  val UNKNOWN = value("UNKNOWN", "UNKNOWN")
-  class EncryptModeEnumVal(name: String, val value: String) extends Val(nextId, name)
-  protected final def value(name: String, value: String): EncryptModeEnumVal = {
-    new EncryptModeEnumVal(name, value)
-  }
-  def parse(s: String): Value = {
-    values.find(_.toString.toLowerCase() == s.toLowerCase).getOrElse(CryptoMode.UNKNOWN)
+  /** Prepares the instance for `mode` using the algorithms of `cryptoMode` and the data key. */
+  def init(cryptoMode: CryptoMode, mode: OperationMode, dataKeyPlaintext: String): Unit
+
+  /** Decrypts every (path, stream) pair and returns the decrypted text as strings. */
+  def decryptBigContent(ite: Iterator[(String, PortableDataStream)]): Iterator[String]
+
+  /** Produces the header bytes that precede the encrypted payload. */
+  def genFileHeader(): Array[Byte]
+
+  /** Validates a header previously produced by `genFileHeader`. */
+  def verifyFileHeader(header: Array[Byte]): Unit
+
+  /** Processes one intermediate chunk and returns the bytes produced for it. */
+  def update(content: Array[Byte]): Array[Byte]
+
+  /** Same as `update(content)`, restricted to `len` bytes starting at `offset`. */
+  def update(content: Array[Byte], offset: Int, len: Int): Array[Byte]
+
+  /** Processes the final chunk; returns (output bytes, signature bytes). */
+  def doFinal(content: Array[Byte]): (Array[Byte], Array[Byte])
+
+  /** Same as `doFinal(content)`, restricted to `len` bytes starting at `offset`. */
+  def doFinal(content: Array[Byte], offset: Int, len: Int): (Array[Byte], Array[Byte])
+
+  /** Encrypts everything readable from `inputStream` into `outputStream`. */
+  def encryptStream(inputStream: DataInputStream, outputStream: DataOutputStream): Unit
+
+  /** Decrypts everything readable from `inputStream` into `outputStream`. */
+  def decryptStream(inputStream: DataInputStream, outputStream: DataOutputStream): Unit
+
+  /** Decrypts the file at `binaryFilePath` into `savePath`. */
+  def decryptFile(binaryFilePath: String, savePath: String): Unit
+
+  /** Encrypts the file at `binaryFilePath` into `savePath`. */
+  def encryptFile(binaryFilePath: String, savePath: String): Unit
+}
+
+object Crypto {
+  /**
+   * Creates the [[Crypto]] implementation for a crypto mode.
+   *
+   * @throws EncryptRuntimeException for modes without an implementation (e.g. PLAIN_TEXT)
+   */
+  def apply(cryptoMode: CryptoMode): Crypto = {
+    cryptoMode match {
+      case AES_CBC_PKCS5PADDING =>
+        new BigDLEncrypt()
+      case _ =>
+        throw new EncryptRuntimeException("No such crypto mode!")
+    }
   }
 }
+
+// object OperationMode extends Enumeration {
+//   type OperationMode = Value
+//   val ENCRYPT, DECRYPT = Value
+// }
+/** Direction of a crypto operation, mapped to the javax.crypto.Cipher opmode constant. */
+sealed trait OperationMode extends Serializable {
+  def opmode: Int
+}
+
+case object ENCRYPT extends OperationMode {
+  override def opmode: Int = Cipher.ENCRYPT_MODE
+}
+case object DECRYPT extends OperationMode {
+  override def opmode: Int = Cipher.DECRYPT_MODE
+}
+
+/** Names of the JCA algorithms that make up one encryption scheme. */
+trait CryptoMode extends Serializable {
+  def encryptionAlgorithm: String
+  def signingAlgorithm: String
+  def secretKeyAlgorithm: String
+}
+
+object CryptoMode {
+  /**
+   * Parses a crypto mode name, ignoring case.
+   *
+   * @throws EncryptRuntimeException if the name is not a known mode
+   */
+  def parse(s: String): CryptoMode = {
+    // Locale.ROOT keeps the lower-casing independent of the JVM default locale.
+    s.toLowerCase(java.util.Locale.ROOT) match {
+      case "aes/cbc/pkcs5padding" =>
+        AES_CBC_PKCS5PADDING
+      case "plain_text" =>
+        PLAIN_TEXT
+      case _ =>
+        throw new EncryptRuntimeException("Unknown crypto mode: " + s)
+    }
+  }
+}
+
+case object AES_CBC_PKCS5PADDING extends CryptoMode {
+  override def encryptionAlgorithm: String = "AES/CBC/PKCS5Padding"
+
+  override def signingAlgorithm: String = "HmacSHA256"
+
+  override def secretKeyAlgorithm: String = "AES"
+}
+
+case object PLAIN_TEXT extends CryptoMode {
+  override def encryptionAlgorithm: String = "plain_text"
+
+  override def signingAlgorithm: String = ""
+
+  override def secretKeyAlgorithm: String = ""
+}
+
+// object CryptoMode extends Enumeration {
+//   type CryptoMode = Value
+//   val PLAIN_TEXT = value("plain_text", "plain_text")
+//   val AES_CBC_PKCS5PADDING = value("AES/CBC/PKCS5Padding", "AES/CBC/PKCS5Padding")
+//   val UNKNOWN = value("UNKNOWN", "UNKNOWN")
+//   class EncryptModeEnumVal(name: String, val value: String) extends Val(nextId, name)
+//   protected final def value(name: String, value: String): EncryptModeEnumVal = {
+//     new EncryptModeEnumVal(name, value)
+//   }
+//   def parse(s: String): Value = {
+//     values.find(_.toString.toLowerCase() == s.toLowerCase).getOrElse(CryptoMode.UNKNOWN)
+//   }
+// }
diff --git 
a/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/crypto/FernetEncrypt.scala b/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/crypto/FernetEncrypt.scala deleted file mode 100644 index f10c261abed..00000000000 --- a/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/crypto/FernetEncrypt.scala +++ /dev/null @@ -1,239 +0,0 @@ -/* - * Copyright 2016 The BigDL Authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.intel.analytics.bigdl.ppml.crypto - -import com.intel.analytics.bigdl.dllib.utils.Log4Error - -import java.io._ -import java.nio.file.{Files, Path, Paths} -import java.security.SecureRandom -import java.time.Instant -import java.util.Arrays -import javax.crypto.spec.{IvParameterSpec, SecretKeySpec} -import javax.crypto.{Cipher, Mac} -import org.apache.spark.input.PortableDataStream - -class FernetEncrypt extends Crypto { - - def encryptFile(binaryFilePath: String, savePath: String, dataKeyPlaintext: String): Unit = { - Log4Error.invalidInputError(savePath != null && savePath != "", - "encrypted file save path should be specified") - // Plaintext original file is read as binary - val content: Array[Byte] = readBinaryFile(binaryFilePath) - val encryptedBytes = timing("FernetCryptos encrypting a single file") { - encryptContent(content, dataKeyPlaintext) - } - timing("FernetCryptos save a encrypted file") { - writeBinaryFile(savePath, encryptedBytes) - } - } - - def decryptFile(binaryFilePath: String, 
savePath: String, dataKeyPlaintext: String): Unit = { - Log4Error.invalidInputError(savePath != null && savePath != "", - "decrypted file save path should be specified") - val content: Array[Byte] = readBinaryFile(binaryFilePath) // Ciphertext file is read into Bytes - val decryptedBytes = timing("FernetCryptos decrypt a single file...") { - decryptContent(content, dataKeyPlaintext) - } - timing("FernetCryptos save a decrypted file") { - writeBinaryFile(savePath, decryptedBytes) - } - } - - def encryptBytes(sourceBytes: Array[Byte], dataKeyPlaintext: String): Array[Byte] = { - encryptContent(sourceBytes, dataKeyPlaintext) - } - - def decryptBytes(sourceBytes: Array[Byte], dataKeyPlaintext: String): Array[Byte] = { - timing("FernetCryptos decrypting bytes") { - decryptContent(sourceBytes, dataKeyPlaintext) - } - } - - private def readBinaryFile(binaryFilePath: String): Array[Byte] = { - Files.readAllBytes(Paths.get(binaryFilePath)) - } - - private def writeBinaryFile(savePath: String, content: Array[Byte]): Path = { - Files.write(Paths.get(savePath), content) - } - - private def writeStringToFile(savePath: String, content: String): Unit = { - val bw = new BufferedWriter(new FileWriter(new File(savePath))) - bw.write(content) - } - - private def read(stream: DataInputStream, numBytes: Int): Array[Byte] = { - val retval = new Array[Byte](numBytes) - val bytesRead: Int = stream.read(retval) - if (bytesRead < numBytes) { - throw new EncryptRuntimeException("Not enough bits to read!") - } - retval - } - - private def encryptContent(content: Array[Byte], dataKeyPlaintext: String): Array[Byte] = { - - val secret = dataKeyPlaintext.getBytes() - - // get IV - val random = new SecureRandom() - val initializationVector: Array[Byte] = new Array[Byte](16) - random.nextBytes(initializationVector) - val ivParameterSpec: IvParameterSpec = new IvParameterSpec(initializationVector) - - // key encrypt - val signingKey: Array[Byte] = Arrays.copyOfRange(secret, 0, 16) - val encryptKey: 
Array[Byte] = Arrays.copyOfRange(secret, 16, 32) - val encryptionKeySpec: SecretKeySpec = new SecretKeySpec(encryptKey, "AES") - - val cipher: Cipher = Cipher.getInstance(CryptoMode.AES_CBC_PKCS5PADDING.value) - cipher.init(Cipher.ENCRYPT_MODE, encryptionKeySpec, ivParameterSpec) - - val cipherText: Array[Byte] = cipher.doFinal(content) - val timestamp: Instant = Instant.now() - - // sign - val byteStream: ByteArrayOutputStream = new ByteArrayOutputStream(25 + cipherText.length) - val dataStream: DataOutputStream = new DataOutputStream(byteStream) - - val version: Byte = (0x80).toByte - dataStream.writeByte(version) - dataStream.writeLong(timestamp.getEpochSecond()) - dataStream.write(ivParameterSpec.getIV()) - dataStream.write(cipherText) - - val mac: Mac = Mac.getInstance("HmacSHA256") - val signingKeySpec = new SecretKeySpec(signingKey, "HmacSHA256") - mac.init(signingKeySpec) - val hmac: Array[Byte] = mac.doFinal(byteStream.toByteArray()) - - // to bytes - val outByteStream: ByteArrayOutputStream = new ByteArrayOutputStream(57 + cipherText.length) - val dataOutStream: DataOutputStream = new DataOutputStream(outByteStream) - dataOutStream.writeByte(version) - dataOutStream.writeLong(timestamp.getEpochSecond()) - dataOutStream.write(ivParameterSpec.getIV()) - dataOutStream.write(cipherText) - dataOutStream.write(hmac) - - if (timestamp == null) { - throw new EncryptRuntimeException("Timestamp cannot be null") - } - if (ivParameterSpec == null || ivParameterSpec.getIV().length != 16) { - throw new EncryptRuntimeException("Initialization Vector must be 128 bits") - } - if (cipherText == null || cipherText.length % 16 != 0) { - throw new EncryptRuntimeException("Ciphertext must be a multkmsServerIPle of 128 bits") - } - if (hmac == null || hmac.length != 32) { - throw new EncryptRuntimeException("Hmac must be 256 bits") - } - - outByteStream.toByteArray() - } - - private def decryptContent(content: Array[Byte], dataKeyPlaintext: String): Array[Byte] = { - - val 
secret: Array[Byte] = dataKeyPlaintext.getBytes() - - val inputStream: ByteArrayInputStream = new ByteArrayInputStream(content) - val dataStream: DataInputStream = new DataInputStream(inputStream) - val version: Byte = dataStream.readByte() - if (version.compare((0x80).toByte) != 0) { - throw new EncryptRuntimeException("Version error!") - } - val encryptKey: Array[Byte] = Arrays.copyOfRange(secret, 16, 32) - - val timestampSeconds: Long = dataStream.readLong() - - val initializationVector: Array[Byte] = read(dataStream, 16) - val ivParameterSpec = new IvParameterSpec(initializationVector) - - val cipherText: Array[Byte] = read(dataStream, content.length - 57) - - val hmac: Array[Byte] = read(dataStream, 32) - if (initializationVector.length != 16) { - throw new EncryptRuntimeException("Initialization Vector must be 128 bits") - } - if (cipherText == null || cipherText.length % 16 != 0) { - throw new EncryptRuntimeException("Ciphertext must be a multkmsServerIPle of 128 bits") - } - if (hmac == null || hmac.length != 32) { - throw new EncryptRuntimeException("hmac must be 256 bits") - } - - val secretKeySpec = new SecretKeySpec(encryptKey, "AES") - val cipher = Cipher.getInstance("AES/CBC/PKCS5Padding") - cipher.init(Cipher.DECRYPT_MODE, secretKeySpec, ivParameterSpec) - cipher.doFinal(cipherText) - } - - def decryptBigContent( - ite: Iterator[(String, PortableDataStream)], - dataKeyPlaintext: String): Iterator[String] = { - val secret: Array[Byte] = dataKeyPlaintext.getBytes() - var result: Iterator[String] = Iterator[String]() - - while (ite.hasNext == true) { - val inputStream: DataInputStream = ite.next._2.open() - val version: Byte = inputStream.readByte() - if (version.compare((0x80).toByte) != 0) { - throw new EncryptRuntimeException("Version error!") - } - val encryptKey: Array[Byte] = Arrays.copyOfRange(secret, 16, 32) - - val timestampSeconds: Long = inputStream.readLong() - val initializationVector: Array[Byte] = read(inputStream, 16) - val 
ivParameterSpec = new IvParameterSpec(initializationVector) - - val secretKeySpec = new SecretKeySpec(encryptKey, "AES") - val cipher = Cipher.getInstance("AES/CBC/PKCS5Padding") - cipher.init(Cipher.DECRYPT_MODE, secretKeySpec, ivParameterSpec) - - val blockSize = 102400000 // 100m per update - var lastString = "" - while (inputStream.available() > blockSize + 32) { - val splitEncryptedBytes = read(inputStream, blockSize) - val currentSplitDecryptString = new String(cipher.update(splitEncryptedBytes)) - val splitDecryptString = lastString + currentSplitDecryptString - val splitDecryptStringArray = splitDecryptString.split("\r").flatMap(_.split("\n")) - lastString = splitDecryptStringArray.last - result = result ++ splitDecryptStringArray.dropRight(1) - } - - val lastCipherText: Array[Byte] = read(inputStream, inputStream.available() - 32) - val lastDecryptString = lastString + (new String(cipher.doFinal(lastCipherText))) - val splitDecryptStringArray = lastDecryptString.split("\r").flatMap(_.split("\n")) - result = result ++ splitDecryptStringArray - - val hmac: Array[Byte] = read(inputStream, 32) - if (initializationVector.length != 16) { - throw new EncryptRuntimeException("Initialization Vector must be 128 bits") - } - if (hmac == null || hmac.length != 32) { - throw new EncryptRuntimeException("hmac must be 256 bits") - } - if (inputStream.available > 0) { - throw new EncryptRuntimeException("inputStream still has contents") - } - } - result - - } - -} diff --git a/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/crypto/dataframe/EncryptedDataFrameReader.scala b/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/crypto/dataframe/EncryptedDataFrameReader.scala index 90b995008f7..29ac00bda87 100644 --- a/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/crypto/dataframe/EncryptedDataFrameReader.scala +++ b/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/crypto/dataframe/EncryptedDataFrameReader.scala @@ -17,8 +17,7 @@ package 
com.intel.analytics.bigdl.ppml.crypto.dataframe import com.intel.analytics.bigdl.ppml.PPMLContext -import com.intel.analytics.bigdl.ppml.crypto.CryptoMode -import com.intel.analytics.bigdl.ppml.crypto.CryptoMode.CryptoMode +import com.intel.analytics.bigdl.ppml.crypto.{AES_CBC_PKCS5PADDING, CryptoMode, PLAIN_TEXT} import com.intel.analytics.bigdl.ppml.crypto.dataframe.EncryptedDataFrameReader.toDataFrame import org.apache.spark.rdd.RDD import org.apache.spark.sql.types.{StringType, StructField, StructType} @@ -42,11 +41,11 @@ class EncryptedDataFrameReader( } def csv(path: String): DataFrame = { encryptMode match { - case CryptoMode.PLAIN_TEXT => + case PLAIN_TEXT => sparkSession.read.options(extraOptions).csv(path) - case CryptoMode.AES_CBC_PKCS5PADDING => + case AES_CBC_PKCS5PADDING => val rdd = PPMLContext.textFile(sparkSession.sparkContext, path, - dataKeyPlainText) + dataKeyPlainText, encryptMode) // TODO: support more options if (extraOptions.contains("header") && extraOptions("header").toLowerCase() == "true") { diff --git a/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/crypto/dataframe/EncryptedDataFrameWriter.scala b/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/crypto/dataframe/EncryptedDataFrameWriter.scala new file mode 100644 index 00000000000..b2df831dbe1 --- /dev/null +++ b/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/crypto/dataframe/EncryptedDataFrameWriter.scala @@ -0,0 +1,29 @@ +/* + * Copyright 2016 The BigDL Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.intel.analytics.bigdl.ppml.crypto.dataframe + +import com.intel.analytics.bigdl.ppml.crypto.CryptoMode +import org.apache.spark.sql.SparkSession + +class EncryptedDataFrameWriter( + sparkSession: SparkSession, + encryptMode: CryptoMode, + dataKeyPlainText: String) { + + + +} diff --git a/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/utils/EncryptIOArguments.scala b/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/utils/EncryptIOArguments.scala index b35c49224e0..a6703096d5d 100644 --- a/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/utils/EncryptIOArguments.scala +++ b/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/utils/EncryptIOArguments.scala @@ -17,8 +17,7 @@ package com.intel.analytics.bigdl.ppml.utils import com.intel.analytics.bigdl.ppml.PPMLContext -import com.intel.analytics.bigdl.ppml.crypto.{CryptoMode, EncryptRuntimeException} -import com.intel.analytics.bigdl.ppml.crypto.CryptoMode.CryptoMode +import com.intel.analytics.bigdl.ppml.crypto.{CryptoMode, EncryptRuntimeException, PLAIN_TEXT} import com.intel.analytics.bigdl.ppml.kms.{EHSMKeyManagementService, KMS_CONVENTION, SimpleKeyManagementService} import java.io.File @@ -26,8 +25,8 @@ import java.io.File case class EncryptIOArguments( inputPath: String = "./input", outputPath: String = "./output", - inputEncryptMode: CryptoMode = CryptoMode.PLAIN_TEXT, - outputEncryptMode: CryptoMode = CryptoMode.PLAIN_TEXT, + inputEncryptMode: CryptoMode = PLAIN_TEXT, + outputEncryptMode: CryptoMode = PLAIN_TEXT, inputPartitionNum: Int = 4, outputPartitionNum: Int = 4, primaryKeyPath: String = "./primaryKeyPath", diff --git a/scala/ppml/src/test/scala/com/intel/analytics/bigdl/ppml/crypto/EncryptSpec.scala b/scala/ppml/src/test/scala/com/intel/analytics/bigdl/ppml/crypto/EncryptSpec.scala new file mode 100644 index 00000000000..eb16ffe9e88 --- /dev/null 
+++ b/scala/ppml/src/test/scala/com/intel/analytics/bigdl/ppml/crypto/EncryptSpec.scala
@@ -0,0 +1,103 @@
+/*
+ * Copyright 2016 The BigDL Authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.intel.analytics.bigdl.ppml.crypto
+
+import com.intel.analytics.bigdl.dllib.common.zooUtils
+import com.intel.analytics.bigdl.dllib.utils.File
+import com.intel.analytics.bigdl.ppml.kms.SimpleKeyManagementService
+import org.apache.hadoop.fs.Path
+import org.scalatest.{BeforeAndAfter, FlatSpec, Matchers}
+
+import java.io.{BufferedReader, BufferedWriter, FileReader, FileWriter}
+import java.nio.file.{Files, Paths, StandardOpenOption}
+import scala.util.Random
+
+/**
+ * Checks that a file encrypted with `BigDLEncrypt.encryptStream` is restored
+ * byte-for-byte by `BigDLEncrypt.decryptStream`.
+ */
+class EncryptSpec extends FlatSpec with Matchers with BeforeAndAfter {
+  val (appid, appkey) = generateKeys()
+  val simpleKms = SimpleKeyManagementService(appid, appkey)
+  val dir = zooUtils.createTmpDir("PPMLUT", "rwx------").toFile()
+  val primaryKeyPath = dir + "/primary.key"
+  val dataKeyPath = dir + "/data.key"
+  simpleKms.retrievePrimaryKey(primaryKeyPath)
+  simpleKms.retrieveDataKey(primaryKeyPath, dataKeyPath)
+  // Set by generateCsvData(); declared first so the null initializer runs before that call.
+  var dataKeyPlaintext: String = null
+  val (plainFileName, encryptFileName, data) = generateCsvData()
+  val fs = File.getFileSystem(plainFileName)
+
+  /** Returns a random (appid, appkey) pair, each made of 12 decimal digits. */
+  def generateKeys(): (String, String) = {
+    val appid: String = (1 to 12).map(x => Random.nextInt(10)).mkString
+    val appkey: String = (1 to 12).map(x => Random.nextInt(10)).mkString
+    (appid, appkey)
+  }
+  /**
+   * Writes a small CSV file and an encrypted copy of it (file header, then the
+   * two byte arrays returned by `doFinal`) into `dir`.
+   *
+   * @return (plain file name, encrypted file name, CSV content)
+   */
+  def generateCsvData(): (String, String, String) = {
+    val fileName = dir + "/people.csv"
+    val encryptFileName = dir + "/en_people.csv"
+    val fw = new FileWriter(fileName)
+    val data = new StringBuilder()
+    data.append(s"name,age,job\n")
+    data.append(s"yvomq,59,Developer\ngdni,40,Engineer\npglyal,33,Engineer")
+    fw.append(data)
+    fw.close()
+
+    val crypto = new BigDLEncrypt()
+    dataKeyPlaintext = simpleKms.retrieveDataKeyPlainText(primaryKeyPath, dataKeyPath)
+    crypto.init(AES_CBC_PKCS5PADDING, ENCRYPT, dataKeyPlaintext)
+    Files.write(Paths.get(encryptFileName), crypto.genFileHeader())
+    val encryptedBytes = crypto.doFinal(data.toString().getBytes)
+    Files.write(Paths.get(encryptFileName), encryptedBytes._1, StandardOpenOption.APPEND)
+    Files.write(Paths.get(encryptFileName), encryptedBytes._2, StandardOpenOption.APPEND)
+    (fileName, encryptFileName, data.toString())
+  }
+
+  "encrypt stream" should "work" in {
+    val encrypt = new BigDLEncrypt()
+    encrypt.init(AES_CBC_PKCS5PADDING, ENCRYPT, dataKeyPlaintext)
+    val bis = fs.open(new Path(plainFileName))
+    val outs = fs.create(new Path(dir + "/en_o.csv"))
+    encrypt.encryptStream(bis, outs)
+    bis.close()
+    outs.flush()
+    outs.close()
+    // NOTE(review): presumably lets the written file settle; confirm the sleep is still needed.
+    Thread.sleep(1000)
+
+    val decrypt = new BigDLEncrypt()
+    decrypt.init(AES_CBC_PKCS5PADDING, DECRYPT, dataKeyPlaintext)
+    val bis2 = fs.open(new Path(dir + "/en_o.csv"))
+    val outs2 = fs.create(new Path(dir + "/de_o.csv"))
+    decrypt.decryptStream(bis2, outs2)
+    // Flush before closing: flushing an already-closed stream is at best a no-op.
+    outs2.flush()
+    outs2.close()
+    bis2.close()
+    val originFile = Files.readAllBytes(Paths.get(plainFileName))
+    val deFile = Files.readAllBytes(Paths.get(dir.toString, "/de_o.csv"))
+    originFile.sameElements(deFile) should be (true)
+  }
+}
diff --git a/scala/ppml/src/test/scala/com/intel/analytics/bigdl/ppml/crypto/dataframe/EncryptDataFrameSpec.scala b/scala/ppml/src/test/scala/com/intel/analytics/bigdl/ppml/crypto/dataframe/EncryptDataFrameSpec.scala
index 39cb57809e6..5e352c11912 100644
---
a/scala/ppml/src/test/scala/com/intel/analytics/bigdl/ppml/crypto/dataframe/EncryptDataFrameSpec.scala +++ b/scala/ppml/src/test/scala/com/intel/analytics/bigdl/ppml/crypto/dataframe/EncryptDataFrameSpec.scala @@ -18,14 +18,14 @@ package com.intel.analytics.bigdl.ppml.crypto.dataframe import com.intel.analytics.bigdl.dllib.common.zooUtils import com.intel.analytics.bigdl.ppml.PPMLContext -import com.intel.analytics.bigdl.ppml.crypto.{CryptoMode, FernetEncrypt} +import com.intel.analytics.bigdl.ppml.crypto.{AES_CBC_PKCS5PADDING, CryptoMode, ENCRYPT, BigDLEncrypt, PLAIN_TEXT} import com.intel.analytics.bigdl.ppml.kms.SimpleKeyManagementService import org.apache.spark.SparkConf import org.apache.spark.sql.SparkSession import org.scalatest.{BeforeAndAfter, FlatSpec, Matchers} import java.io.FileWriter -import java.nio.file.{Files, Paths} +import java.nio.file.{Files, Paths, StandardOpenOption} import scala.util.Random class EncryptDataFrameSpec extends FlatSpec with Matchers with BeforeAndAfter{ @@ -55,10 +55,13 @@ class EncryptDataFrameSpec extends FlatSpec with Matchers with BeforeAndAfter{ fw.append(data) fw.close() - val fernetCryptos = new FernetEncrypt() + val crypto = new BigDLEncrypt() val dataKeyPlaintext = simpleKms.retrieveDataKeyPlainText(primaryKeyPath, dataKeyPath) - val encryptedBytes = fernetCryptos.encryptBytes(data.toString().getBytes, dataKeyPlaintext) - Files.write(Paths.get(encryptFileName), encryptedBytes) + crypto.init(AES_CBC_PKCS5PADDING, ENCRYPT, dataKeyPlaintext) + Files.write(Paths.get(encryptFileName), crypto.genFileHeader()) + val encryptedBytes = crypto.doFinal(data.toString().getBytes) + Files.write(Paths.get(encryptFileName), encryptedBytes._1, StandardOpenOption.APPEND) + Files.write(Paths.get(encryptFileName), encryptedBytes._2, StandardOpenOption.APPEND) (fileName, encryptFileName, data.toString()) } val ppmlArgs = Map( @@ -73,7 +76,7 @@ class EncryptDataFrameSpec extends FlatSpec with Matchers with BeforeAndAfter{ "textfile read 
from plaint text file" should "work" in { val file = sc.textFile(plainFileName).collect() file.mkString("\n") should be (data) - val file2 = sc.textFile(encryptFileName, cryptoMode = CryptoMode.AES_CBC_PKCS5PADDING).collect() + val file2 = sc.textFile(encryptFileName, cryptoMode = AES_CBC_PKCS5PADDING).collect() file2.mkString("\n") should be (data) } @@ -89,7 +92,7 @@ class EncryptDataFrameSpec extends FlatSpec with Matchers with BeforeAndAfter{ } "read from plain csv with header" should "work" in { - val df = sc.read(cryptoMode = CryptoMode.PLAIN_TEXT) + val df = sc.read(cryptoMode = PLAIN_TEXT) .option("header", "true").csv(plainFileName) val d = df.schema.map(_.name).mkString(",") + "\n" + df.collect().map(v => s"${v.get(0)},${v.get(1)},${v.get(2)}").mkString("\n") @@ -97,7 +100,7 @@ class EncryptDataFrameSpec extends FlatSpec with Matchers with BeforeAndAfter{ } "read from encrypted csv with header" should "work" in { - val df = sc.read(cryptoMode = CryptoMode.AES_CBC_PKCS5PADDING) + val df = sc.read(cryptoMode = AES_CBC_PKCS5PADDING) .option("header", "true").csv(encryptFileName) val d = df.schema.map(_.name).mkString(",") + "\n" + df.collect().map(v => s"${v.get(0)},${v.get(1)},${v.get(2)}").mkString("\n") @@ -105,13 +108,13 @@ class EncryptDataFrameSpec extends FlatSpec with Matchers with BeforeAndAfter{ } "read from plain csv without header" should "work" in { - val df = sc.read(cryptoMode = CryptoMode.PLAIN_TEXT).csv(plainFileName) + val df = sc.read(cryptoMode = PLAIN_TEXT).csv(plainFileName) val d = df.collect().map(v => s"${v.get(0)},${v.get(1)},${v.get(2)}").mkString("\n") d should be (data) } "read from encrypted csv without header" should "work" in { - val df = sc.read(cryptoMode = CryptoMode.AES_CBC_PKCS5PADDING).csv(encryptFileName) + val df = sc.read(cryptoMode = AES_CBC_PKCS5PADDING).csv(encryptFileName) val d = df.collect().map(v => s"${v.get(0)},${v.get(1)},${v.get(2)}").mkString("\n") d should be (data) }