diff --git a/scala/ppml/pom.xml b/scala/ppml/pom.xml
index f536533497e..33dc01cc901 100644
--- a/scala/ppml/pom.xml
+++ b/scala/ppml/pom.xml
@@ -12,13 +12,10 @@
jar
- provided
1.37.0
${project.parent.basedir}
-
-
org.apache.spark
diff --git a/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/PPMLContext.scala b/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/PPMLContext.scala
index cbca5b6b5bd..17e83ea32a2 100644
--- a/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/PPMLContext.scala
+++ b/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/PPMLContext.scala
@@ -18,8 +18,7 @@ package com.intel.analytics.bigdl.ppml
import com.intel.analytics.bigdl.dllib.NNContext.{checkScalaVersion, checkSparkVersion, createSparkConf, initConf, initNNContext}
import com.intel.analytics.bigdl.dllib.utils.Log4Error
-import com.intel.analytics.bigdl.ppml.crypto.CryptoMode.CryptoMode
-import com.intel.analytics.bigdl.ppml.crypto.{CryptoMode, EncryptRuntimeException, FernetEncrypt}
+import com.intel.analytics.bigdl.ppml.crypto.{AES_CBC_PKCS5PADDING, Crypto, CryptoMode, DECRYPT, ENCRYPT, EncryptRuntimeException, BigDLEncrypt, PLAIN_TEXT}
import com.intel.analytics.bigdl.ppml.utils.Supportive
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.input.PortableDataStream
@@ -62,14 +61,13 @@ class PPMLContext protected(kms: KeyManagementService, sparkSession: SparkSessio
*/
def textFile(path: String,
minPartitions: Int = sparkSession.sparkContext.defaultMinPartitions,
- cryptoMode: CryptoMode = CryptoMode.PLAIN_TEXT): RDD[String] = {
+ cryptoMode: CryptoMode = PLAIN_TEXT): RDD[String] = {
cryptoMode match {
- case CryptoMode.PLAIN_TEXT =>
+ case PLAIN_TEXT =>
sparkSession.sparkContext.textFile(path, minPartitions)
- case CryptoMode.AES_CBC_PKCS5PADDING =>
- PPMLContext.textFile(sparkSession.sparkContext, path, dataKeyPlainText, minPartitions)
case _ =>
- throw new IllegalArgumentException("unknown EncryptMode " + cryptoMode.toString)
+ PPMLContext.textFile(sparkSession.sparkContext, path, dataKeyPlainText,
+ cryptoMode, minPartitions)
}
}
@@ -90,10 +88,10 @@ class PPMLContext protected(kms: KeyManagementService, sparkSession: SparkSessio
*/
def write(dataFrame: DataFrame, cryptoMode: CryptoMode): DataFrameWriter[Row] = {
cryptoMode match {
- case CryptoMode.PLAIN_TEXT =>
+ case PLAIN_TEXT =>
dataFrame.write
- case CryptoMode.AES_CBC_PKCS5PADDING =>
- PPMLContext.write(sparkSession, dataKeyPlainText, dataFrame)
+ case AES_CBC_PKCS5PADDING =>
+ PPMLContext.write(sparkSession, cryptoMode, dataKeyPlainText, dataFrame)
case _ =>
throw new IllegalArgumentException("unknown EncryptMode " + cryptoMode.toString)
}
@@ -101,12 +99,15 @@ class PPMLContext protected(kms: KeyManagementService, sparkSession: SparkSessio
}
object PPMLContext{
- private[bigdl] def registerUDF(spark: SparkSession,
- dataKeyPlaintext: String) = {
+ private[bigdl] def registerUDF(
+ spark: SparkSession,
+ cryptoMode: CryptoMode,
+ dataKeyPlaintext: String) = {
val bcKey = spark.sparkContext.broadcast(dataKeyPlaintext)
val convertCase = (x: String) => {
- val fernetCryptos = new FernetEncrypt()
- new String(fernetCryptos.encryptBytes(x.getBytes, bcKey.value))
+ val crypto = Crypto(cryptoMode)
+ crypto.init(cryptoMode, ENCRYPT, dataKeyPlaintext)
+ new String(crypto.doFinal(x.getBytes)._1)
}
spark.udf.register("convertUDF", convertCase)
}
@@ -114,6 +115,7 @@ object PPMLContext{
private[bigdl] def textFile(sc: SparkContext,
path: String,
dataKeyPlaintext: String,
+ cryptoMode: CryptoMode,
minPartitions: Int = -1): RDD[String] = {
Log4Error.invalidInputError(dataKeyPlaintext != "",
"dataKeyPlainText should not be empty, please loadKeys first.")
@@ -122,19 +124,22 @@ object PPMLContext{
} else {
sc.binaryFiles(path)
}
- val fernetCryptos = new FernetEncrypt
data.mapPartitions { iterator => {
Supportive.logger.info("Decrypting bytes with JavaAESCBC...")
- fernetCryptos.decryptBigContent(iterator, dataKeyPlaintext)
+ val crypto = Crypto(cryptoMode)
+ crypto.init(cryptoMode, DECRYPT, dataKeyPlaintext)
+ crypto.decryptBigContent(iterator)
}}.flatMap(_.split("\n"))
}
- private[bigdl] def write(sparkSession: SparkSession,
- dataKeyPlaintext: String,
- dataFrame: DataFrame): DataFrameWriter[Row] = {
+ private[bigdl] def write(
+ sparkSession: SparkSession,
+ cryptoMode: CryptoMode,
+ dataKeyPlaintext: String,
+ dataFrame: DataFrame): DataFrameWriter[Row] = {
val tableName = "ppml_save_table"
dataFrame.createOrReplaceTempView(tableName)
- PPMLContext.registerUDF(sparkSession, dataKeyPlaintext)
+ PPMLContext.registerUDF(sparkSession, cryptoMode, dataKeyPlaintext)
// Select all and encrypt columns.
val convertSql = "select " + dataFrame.schema.map(column =>
"convertUDF(" + column.name + ") as " + column.name).mkString(", ") +
diff --git a/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/crypto/BigDLEncrypt.scala b/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/crypto/BigDLEncrypt.scala
new file mode 100644
index 00000000000..3ce60bdaebc
--- /dev/null
+++ b/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/crypto/BigDLEncrypt.scala
@@ -0,0 +1,210 @@
+/*
+ * Copyright 2016 The BigDL Authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.intel.analytics.bigdl.ppml.crypto
+
+import com.intel.analytics.bigdl.dllib.utils.{File, Log4Error}
+import com.intel.analytics.bigdl.ppml.crypto.CryptoMode
+import org.apache.hadoop.fs.Path
+
+import java.io._
+import java.security.SecureRandom
+import java.time.Instant
+import java.util.Arrays
+import javax.crypto.spec.{IvParameterSpec, SecretKeySpec}
+import javax.crypto.{Cipher, Mac}
+import org.apache.spark.input.PortableDataStream
+
+import java.nio.ByteBuffer
+import scala.util.Random
+
+class BigDLEncrypt extends Crypto {
+ protected var cipher: Cipher = null
+ protected var mac: Mac = null
+ protected var ivParameterSpec: IvParameterSpec = null
+ protected var opMode: OperationMode = null
+ protected var initializationVector: Array[Byte] = null
+ override def init(cryptoMode: CryptoMode, mode: OperationMode, dataKeyPlaintext: String): Unit = {
+ opMode = mode
+ val secret = dataKeyPlaintext.getBytes()
+ // key encrypt
+ val signingKey = Arrays.copyOfRange(secret, 0, 16)
+ val encryptKey = Arrays.copyOfRange(secret, 16, 32)
+// initializationVector = Arrays.copyOfRange(secret, 0, 16)
+ val r = new Random(signingKey.sum)
+ initializationVector = Array.tabulate(16)(_ => (r.nextInt(256) - 128).toByte)
+ ivParameterSpec = new IvParameterSpec(initializationVector)
+ val encryptionKeySpec = new SecretKeySpec(encryptKey, cryptoMode.secretKeyAlgorithm)
+ cipher = Cipher.getInstance(cryptoMode.encryptionAlgorithm)
+ cipher.init(mode.opmode, encryptionKeySpec, ivParameterSpec)
+ mac = Mac.getInstance(cryptoMode.signingAlgorithm)
+ val signingKeySpec = new SecretKeySpec(signingKey, cryptoMode.signingAlgorithm)
+ mac.init(signingKeySpec)
+ }
+
+ protected var signingDataStream: DataOutputStream = null
+
+ override def genFileHeader(): Array[Byte] = {
+ Log4Error.invalidOperationError(cipher != null,
+ s"you should init BigDLEncrypt first.")
+ val timestamp: Instant = Instant.now()
+ val signingByteBuffer = ByteBuffer.allocate(1 + 8 + ivParameterSpec.getIV.length)
+ val version: Byte = (0x80).toByte
+ signingByteBuffer.put(version)
+ signingByteBuffer.putLong(timestamp.getEpochSecond())
+ signingByteBuffer.put(ivParameterSpec.getIV())
+ signingByteBuffer.array()
+ }
+
+ override def verifyFileHeader(header: Array[Byte]): Unit = {
+ val headerBuffer = ByteBuffer.wrap(header)
+ val version: Byte = headerBuffer.get()
+ if (version.compare((0x80).toByte) != 0) {
+ throw new EncryptRuntimeException("File header version error!")
+ }
+ val timestampSeconds: Long = headerBuffer.getLong
+ val initializationVector: Array[Byte] = header.slice(1 + 8, header.length)
+ if (!initializationVector.sameElements(this.initializationVector)) {
+ throw new EncryptRuntimeException("File header not match!" +
+ "expected: " + this.initializationVector.mkString(",") +
+ ", but got: " + initializationVector.mkString(", "))
+ }
+ }
+
+ override def update(content: Array[Byte]): Array[Byte] = {
+ val cipherText: Array[Byte] = cipher.update(content)
+ mac.update(cipherText)
+ cipherText
+ }
+
+ override def update(content: Array[Byte], offset: Int, len: Int): Array[Byte] = {
+ val cipherText: Array[Byte] = cipher.update(content, offset, len)
+ mac.update(cipherText, offset, len)
+ cipherText
+ }
+
+ override def doFinal(content: Array[Byte]): (Array[Byte], Array[Byte]) = {
+ val cipherText: Array[Byte] = cipher.doFinal(content)
+ val hmac: Array[Byte] = mac.doFinal(cipherText)
+ (cipherText, hmac)
+ }
+
+ override def doFinal(content: Array[Byte], offset: Int, len: Int): (Array[Byte], Array[Byte]) = {
+ val cipherText: Array[Byte] = cipher.doFinal(content, offset, len)
+ val hmac: Array[Byte] = mac.doFinal(cipherText.slice(offset, offset + len))
+ (cipherText, hmac)
+ }
+
+ val blockSize = 1024 * 1024 // 1m per update
+ val byteBuffer = new Array[Byte](blockSize)
+ override def encryptStream(inputStream: DataInputStream, outputStream: DataOutputStream): Unit = {
+ val header = genFileHeader()
+ outputStream.write(header)
+ while (inputStream.available() > blockSize) {
+ val readLen = inputStream.read(byteBuffer)
+ outputStream.write(update(byteBuffer, 0, readLen))
+ }
+ val last = inputStream.read(byteBuffer)
+ val (lastSlice, hmac) = doFinal(byteBuffer, 0, last)
+ outputStream.write(lastSlice)
+ outputStream.write(hmac)
+ outputStream.flush()
+ }
+
+ val hmacSize = 32
+ override def decryptStream(inputStream: DataInputStream, outputStream: DataOutputStream): Unit = {
+ val header = read(inputStream, 25)
+ verifyFileHeader(header)
+ while (inputStream.available() > blockSize) {
+ val readLen = inputStream.read(byteBuffer)
+ outputStream.write(update(byteBuffer, 0, readLen))
+ }
+ val last = inputStream.read(byteBuffer)
+ val inputHmac = byteBuffer.slice(last - hmacSize, last)
+ val (lastSlice, streamHmac) = doFinal(byteBuffer, 0, last - hmacSize)
+ if(inputHmac.sameElements(streamHmac)) {
+ throw new EncryptRuntimeException("hmac not match")
+ }
+ outputStream.write(lastSlice)
+ outputStream.flush()
+ }
+
+ override def decryptFile(binaryFilePath: String, savePath: String): Unit = {
+ Log4Error.invalidInputError(savePath != null && savePath != "",
+ "decrypted file save path should be specified")
+ val fs = File.getFileSystem(binaryFilePath)
+ val bis = fs.open(new Path(binaryFilePath))
+ val outs = fs.create(new Path(savePath))
+ encryptStream(bis, outs)
+ bis.close()
+ outs.close()
+ }
+
+ override def encryptFile(binaryFilePath: String, savePath: String): Unit = {
+ Log4Error.invalidInputError(savePath != null && savePath != "",
+ "decrypted file save path should be specified")
+ val fs = File.getFileSystem(binaryFilePath)
+ val bis = fs.open(new Path(binaryFilePath))
+ val outs = fs.create(new Path(savePath))
+ decryptStream(bis, outs)
+ bis.close()
+ outs.close()
+ }
+
+ private def read(stream: DataInputStream, numBytes: Int): Array[Byte] = {
+ val retval = new Array[Byte](numBytes)
+ val bytesRead: Int = stream.read(retval)
+ if (bytesRead < numBytes) {
+ throw new EncryptRuntimeException("Not enough bits to read!")
+ }
+ retval
+ }
+
+ override def decryptBigContent(
+ ite: Iterator[(String, PortableDataStream)]): Iterator[String] = {
+ var result: Iterator[String] = Iterator[String]()
+
+ while (ite.hasNext == true) {
+ val inputStream: DataInputStream = ite.next._2.open()
+ verifyFileHeader(read(inputStream, 25))
+
+ // do first
+ var lastString = ""
+ while (inputStream.available() > blockSize) {
+ val readLen = inputStream.read(byteBuffer)
+ Log4Error.unKnowExceptionError(readLen != blockSize)
+ val currentSplitDecryptString = new String(byteBuffer, 0, readLen)
+ val splitDecryptString = lastString + currentSplitDecryptString
+ val splitDecryptStringArray = splitDecryptString.split("\r").flatMap(_.split("\n"))
+ lastString = splitDecryptStringArray.last
+ result = result ++ splitDecryptStringArray.dropRight(1)
+ }
+ // do last
+ val last = inputStream.read(byteBuffer)
+ val inputHmac = byteBuffer.slice(last - hmacSize, last)
+ val (lastSlice, streamHmac) = doFinal(byteBuffer, 0, last - hmacSize)
+ if (inputHmac.sameElements(streamHmac)) {
+ throw new EncryptRuntimeException("hmac not match")
+ }
+ val lastDecryptString = lastString + new String(lastSlice)
+ val splitDecryptStringArray = lastDecryptString.split("\r").flatMap(_.split("\n"))
+ result = result ++ splitDecryptStringArray
+ }
+ result
+
+ }
+
+}
diff --git a/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/crypto/Crypto.scala b/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/crypto/Crypto.scala
index 40baadb69e6..764ed6e3d84 100644
--- a/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/crypto/Crypto.scala
+++ b/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/crypto/Crypto.scala
@@ -16,25 +16,108 @@
package com.intel.analytics.bigdl.ppml.crypto
+import com.intel.analytics.bigdl.dllib.nn.NormMode.Value
import com.intel.analytics.bigdl.ppml.utils.Supportive
+import org.apache.spark.input.PortableDataStream
+
+import java.io.{DataInputStream, DataOutputStream}
+import javax.crypto.Cipher
trait Crypto extends Supportive with Serializable {
- def encryptFile(sourceFilePath: String, saveFilePath: String, dataKeyPlaintext: String)
- def decryptFile(sourceFilePath: String, saveFilePath: String, dataKeyPlaintext: String)
- def encryptBytes(sourceBytes: Array[Byte], dataKeyPlaintext: String): Array[Byte]
- def decryptBytes(sourceBytes: Array[Byte], dataKeyPlaintext: String): Array[Byte]
-}
-
-object CryptoMode extends Enumeration {
- type CryptoMode = Value
- val PLAIN_TEXT = value("plain_text", "plain_text")
- val AES_CBC_PKCS5PADDING = value("AES/CBC/PKCS5Padding", "AES/CBC/PKCS5Padding")
- val UNKNOWN = value("UNKNOWN", "UNKNOWN")
- class EncryptModeEnumVal(name: String, val value: String) extends Val(nextId, name)
- protected final def value(name: String, value: String): EncryptModeEnumVal = {
- new EncryptModeEnumVal(name, value)
- }
- def parse(s: String): Value = {
- values.find(_.toString.toLowerCase() == s.toLowerCase).getOrElse(CryptoMode.UNKNOWN)
+ def init(cryptoMode: CryptoMode, mode: OperationMode, dataKeyPlaintext: String): Unit
+
+ def decryptBigContent(ite: Iterator[(String, PortableDataStream)]): Iterator[String]
+
+ def genFileHeader(): Array[Byte]
+
+ def verifyFileHeader(header: Array[Byte]): Unit
+
+ def update(content: Array[Byte]): Array[Byte]
+
+ def update(content: Array[Byte], offset: Int, len: Int): Array[Byte]
+
+ def doFinal(content: Array[Byte]): (Array[Byte], Array[Byte])
+
+ def doFinal(content: Array[Byte], offset: Int, len: Int): (Array[Byte], Array[Byte])
+
+ def encryptStream(inputStream: DataInputStream, outputStream: DataOutputStream): Unit
+
+ def decryptStream(inputStream: DataInputStream, outputStream: DataOutputStream): Unit
+
+ def decryptFile(binaryFilePath: String, savePath: String): Unit
+
+ def encryptFile(binaryFilePath: String, savePath: String): Unit
+}
+
+object Crypto {
+ def apply(cryptoMode: CryptoMode): Crypto = {
+ cryptoMode match {
+ case AES_CBC_PKCS5PADDING =>
+ new BigDLEncrypt()
+ case default =>
+ throw new EncryptRuntimeException("No such crypto mode!")
+ }
}
}
+
+// object OperationMode extends Enumeration {
+// type OperationMode = Value
+// val ENCRYPT, DECRYPT = Value
+// }
+sealed trait OperationMode extends Serializable {
+ def opmode: Int
+}
+
+case object ENCRYPT extends OperationMode {
+ override def opmode: Int = Cipher.ENCRYPT_MODE
+}
+case object DECRYPT extends OperationMode {
+ override def opmode: Int = Cipher.DECRYPT_MODE
+}
+
+trait CryptoMode extends Serializable {
+ def encryptionAlgorithm: String
+ def signingAlgorithm: String
+ def secretKeyAlgorithm: String
+}
+
+object CryptoMode {
+ def parse(s: String): CryptoMode = {
+ s.toLowerCase() match {
+ case "aes/cbc/pkcs5padding" =>
+ AES_CBC_PKCS5PADDING
+ case "plain_text" =>
+ PLAIN_TEXT
+ }
+ }
+}
+
+case object AES_CBC_PKCS5PADDING extends CryptoMode {
+ override def encryptionAlgorithm: String = "AES/CBC/PKCS5Padding"
+
+ override def signingAlgorithm: String = "HmacSHA256"
+
+ override def secretKeyAlgorithm: String = "AES"
+}
+
+case object PLAIN_TEXT extends CryptoMode {
+ override def encryptionAlgorithm: String = "plain_text"
+
+ override def signingAlgorithm: String = ""
+
+ override def secretKeyAlgorithm: String = ""
+}
+
+// object CryptoMode extends Enumeration {
+// type CryptoMode = Value
+// val PLAIN_TEXT = value("plain_text", "plain_text")
+// val AES_CBC_PKCS5PADDING = value("AES/CBC/PKCS5Padding", "AES/CBC/PKCS5Padding")
+// val UNKNOWN = value("UNKNOWN", "UNKNOWN")
+// class EncryptModeEnumVal(name: String, val value: String) extends Val(nextId, name)
+// protected final def value(name: String, value: String): EncryptModeEnumVal = {
+// new EncryptModeEnumVal(name, value)
+// }
+// def parse(s: String): Value = {
+// values.find(_.toString.toLowerCase() == s.toLowerCase).getOrElse(CryptoMode.UNKNOWN)
+// }
+// }
diff --git a/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/crypto/FernetEncrypt.scala b/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/crypto/FernetEncrypt.scala
deleted file mode 100644
index f10c261abed..00000000000
--- a/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/crypto/FernetEncrypt.scala
+++ /dev/null
@@ -1,239 +0,0 @@
-/*
- * Copyright 2016 The BigDL Authors.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package com.intel.analytics.bigdl.ppml.crypto
-
-import com.intel.analytics.bigdl.dllib.utils.Log4Error
-
-import java.io._
-import java.nio.file.{Files, Path, Paths}
-import java.security.SecureRandom
-import java.time.Instant
-import java.util.Arrays
-import javax.crypto.spec.{IvParameterSpec, SecretKeySpec}
-import javax.crypto.{Cipher, Mac}
-import org.apache.spark.input.PortableDataStream
-
-class FernetEncrypt extends Crypto {
-
- def encryptFile(binaryFilePath: String, savePath: String, dataKeyPlaintext: String): Unit = {
- Log4Error.invalidInputError(savePath != null && savePath != "",
- "encrypted file save path should be specified")
- // Plaintext original file is read as binary
- val content: Array[Byte] = readBinaryFile(binaryFilePath)
- val encryptedBytes = timing("FernetCryptos encrypting a single file") {
- encryptContent(content, dataKeyPlaintext)
- }
- timing("FernetCryptos save a encrypted file") {
- writeBinaryFile(savePath, encryptedBytes)
- }
- }
-
- def decryptFile(binaryFilePath: String, savePath: String, dataKeyPlaintext: String): Unit = {
- Log4Error.invalidInputError(savePath != null && savePath != "",
- "decrypted file save path should be specified")
- val content: Array[Byte] = readBinaryFile(binaryFilePath) // Ciphertext file is read into Bytes
- val decryptedBytes = timing("FernetCryptos decrypt a single file...") {
- decryptContent(content, dataKeyPlaintext)
- }
- timing("FernetCryptos save a decrypted file") {
- writeBinaryFile(savePath, decryptedBytes)
- }
- }
-
- def encryptBytes(sourceBytes: Array[Byte], dataKeyPlaintext: String): Array[Byte] = {
- encryptContent(sourceBytes, dataKeyPlaintext)
- }
-
- def decryptBytes(sourceBytes: Array[Byte], dataKeyPlaintext: String): Array[Byte] = {
- timing("FernetCryptos decrypting bytes") {
- decryptContent(sourceBytes, dataKeyPlaintext)
- }
- }
-
- private def readBinaryFile(binaryFilePath: String): Array[Byte] = {
- Files.readAllBytes(Paths.get(binaryFilePath))
- }
-
- private def writeBinaryFile(savePath: String, content: Array[Byte]): Path = {
- Files.write(Paths.get(savePath), content)
- }
-
- private def writeStringToFile(savePath: String, content: String): Unit = {
- val bw = new BufferedWriter(new FileWriter(new File(savePath)))
- bw.write(content)
- }
-
- private def read(stream: DataInputStream, numBytes: Int): Array[Byte] = {
- val retval = new Array[Byte](numBytes)
- val bytesRead: Int = stream.read(retval)
- if (bytesRead < numBytes) {
- throw new EncryptRuntimeException("Not enough bits to read!")
- }
- retval
- }
-
- private def encryptContent(content: Array[Byte], dataKeyPlaintext: String): Array[Byte] = {
-
- val secret = dataKeyPlaintext.getBytes()
-
- // get IV
- val random = new SecureRandom()
- val initializationVector: Array[Byte] = new Array[Byte](16)
- random.nextBytes(initializationVector)
- val ivParameterSpec: IvParameterSpec = new IvParameterSpec(initializationVector)
-
- // key encrypt
- val signingKey: Array[Byte] = Arrays.copyOfRange(secret, 0, 16)
- val encryptKey: Array[Byte] = Arrays.copyOfRange(secret, 16, 32)
- val encryptionKeySpec: SecretKeySpec = new SecretKeySpec(encryptKey, "AES")
-
- val cipher: Cipher = Cipher.getInstance(CryptoMode.AES_CBC_PKCS5PADDING.value)
- cipher.init(Cipher.ENCRYPT_MODE, encryptionKeySpec, ivParameterSpec)
-
- val cipherText: Array[Byte] = cipher.doFinal(content)
- val timestamp: Instant = Instant.now()
-
- // sign
- val byteStream: ByteArrayOutputStream = new ByteArrayOutputStream(25 + cipherText.length)
- val dataStream: DataOutputStream = new DataOutputStream(byteStream)
-
- val version: Byte = (0x80).toByte
- dataStream.writeByte(version)
- dataStream.writeLong(timestamp.getEpochSecond())
- dataStream.write(ivParameterSpec.getIV())
- dataStream.write(cipherText)
-
- val mac: Mac = Mac.getInstance("HmacSHA256")
- val signingKeySpec = new SecretKeySpec(signingKey, "HmacSHA256")
- mac.init(signingKeySpec)
- val hmac: Array[Byte] = mac.doFinal(byteStream.toByteArray())
-
- // to bytes
- val outByteStream: ByteArrayOutputStream = new ByteArrayOutputStream(57 + cipherText.length)
- val dataOutStream: DataOutputStream = new DataOutputStream(outByteStream)
- dataOutStream.writeByte(version)
- dataOutStream.writeLong(timestamp.getEpochSecond())
- dataOutStream.write(ivParameterSpec.getIV())
- dataOutStream.write(cipherText)
- dataOutStream.write(hmac)
-
- if (timestamp == null) {
- throw new EncryptRuntimeException("Timestamp cannot be null")
- }
- if (ivParameterSpec == null || ivParameterSpec.getIV().length != 16) {
- throw new EncryptRuntimeException("Initialization Vector must be 128 bits")
- }
- if (cipherText == null || cipherText.length % 16 != 0) {
- throw new EncryptRuntimeException("Ciphertext must be a multkmsServerIPle of 128 bits")
- }
- if (hmac == null || hmac.length != 32) {
- throw new EncryptRuntimeException("Hmac must be 256 bits")
- }
-
- outByteStream.toByteArray()
- }
-
- private def decryptContent(content: Array[Byte], dataKeyPlaintext: String): Array[Byte] = {
-
- val secret: Array[Byte] = dataKeyPlaintext.getBytes()
-
- val inputStream: ByteArrayInputStream = new ByteArrayInputStream(content)
- val dataStream: DataInputStream = new DataInputStream(inputStream)
- val version: Byte = dataStream.readByte()
- if (version.compare((0x80).toByte) != 0) {
- throw new EncryptRuntimeException("Version error!")
- }
- val encryptKey: Array[Byte] = Arrays.copyOfRange(secret, 16, 32)
-
- val timestampSeconds: Long = dataStream.readLong()
-
- val initializationVector: Array[Byte] = read(dataStream, 16)
- val ivParameterSpec = new IvParameterSpec(initializationVector)
-
- val cipherText: Array[Byte] = read(dataStream, content.length - 57)
-
- val hmac: Array[Byte] = read(dataStream, 32)
- if (initializationVector.length != 16) {
- throw new EncryptRuntimeException("Initialization Vector must be 128 bits")
- }
- if (cipherText == null || cipherText.length % 16 != 0) {
- throw new EncryptRuntimeException("Ciphertext must be a multkmsServerIPle of 128 bits")
- }
- if (hmac == null || hmac.length != 32) {
- throw new EncryptRuntimeException("hmac must be 256 bits")
- }
-
- val secretKeySpec = new SecretKeySpec(encryptKey, "AES")
- val cipher = Cipher.getInstance("AES/CBC/PKCS5Padding")
- cipher.init(Cipher.DECRYPT_MODE, secretKeySpec, ivParameterSpec)
- cipher.doFinal(cipherText)
- }
-
- def decryptBigContent(
- ite: Iterator[(String, PortableDataStream)],
- dataKeyPlaintext: String): Iterator[String] = {
- val secret: Array[Byte] = dataKeyPlaintext.getBytes()
- var result: Iterator[String] = Iterator[String]()
-
- while (ite.hasNext == true) {
- val inputStream: DataInputStream = ite.next._2.open()
- val version: Byte = inputStream.readByte()
- if (version.compare((0x80).toByte) != 0) {
- throw new EncryptRuntimeException("Version error!")
- }
- val encryptKey: Array[Byte] = Arrays.copyOfRange(secret, 16, 32)
-
- val timestampSeconds: Long = inputStream.readLong()
- val initializationVector: Array[Byte] = read(inputStream, 16)
- val ivParameterSpec = new IvParameterSpec(initializationVector)
-
- val secretKeySpec = new SecretKeySpec(encryptKey, "AES")
- val cipher = Cipher.getInstance("AES/CBC/PKCS5Padding")
- cipher.init(Cipher.DECRYPT_MODE, secretKeySpec, ivParameterSpec)
-
- val blockSize = 102400000 // 100m per update
- var lastString = ""
- while (inputStream.available() > blockSize + 32) {
- val splitEncryptedBytes = read(inputStream, blockSize)
- val currentSplitDecryptString = new String(cipher.update(splitEncryptedBytes))
- val splitDecryptString = lastString + currentSplitDecryptString
- val splitDecryptStringArray = splitDecryptString.split("\r").flatMap(_.split("\n"))
- lastString = splitDecryptStringArray.last
- result = result ++ splitDecryptStringArray.dropRight(1)
- }
-
- val lastCipherText: Array[Byte] = read(inputStream, inputStream.available() - 32)
- val lastDecryptString = lastString + (new String(cipher.doFinal(lastCipherText)))
- val splitDecryptStringArray = lastDecryptString.split("\r").flatMap(_.split("\n"))
- result = result ++ splitDecryptStringArray
-
- val hmac: Array[Byte] = read(inputStream, 32)
- if (initializationVector.length != 16) {
- throw new EncryptRuntimeException("Initialization Vector must be 128 bits")
- }
- if (hmac == null || hmac.length != 32) {
- throw new EncryptRuntimeException("hmac must be 256 bits")
- }
- if (inputStream.available > 0) {
- throw new EncryptRuntimeException("inputStream still has contents")
- }
- }
- result
-
- }
-
-}
diff --git a/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/crypto/dataframe/EncryptedDataFrameReader.scala b/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/crypto/dataframe/EncryptedDataFrameReader.scala
index 90b995008f7..29ac00bda87 100644
--- a/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/crypto/dataframe/EncryptedDataFrameReader.scala
+++ b/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/crypto/dataframe/EncryptedDataFrameReader.scala
@@ -17,8 +17,7 @@
package com.intel.analytics.bigdl.ppml.crypto.dataframe
import com.intel.analytics.bigdl.ppml.PPMLContext
-import com.intel.analytics.bigdl.ppml.crypto.CryptoMode
-import com.intel.analytics.bigdl.ppml.crypto.CryptoMode.CryptoMode
+import com.intel.analytics.bigdl.ppml.crypto.{AES_CBC_PKCS5PADDING, CryptoMode, PLAIN_TEXT}
import com.intel.analytics.bigdl.ppml.crypto.dataframe.EncryptedDataFrameReader.toDataFrame
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.types.{StringType, StructField, StructType}
@@ -42,11 +41,11 @@ class EncryptedDataFrameReader(
}
def csv(path: String): DataFrame = {
encryptMode match {
- case CryptoMode.PLAIN_TEXT =>
+ case PLAIN_TEXT =>
sparkSession.read.options(extraOptions).csv(path)
- case CryptoMode.AES_CBC_PKCS5PADDING =>
+ case AES_CBC_PKCS5PADDING =>
val rdd = PPMLContext.textFile(sparkSession.sparkContext, path,
- dataKeyPlainText)
+ dataKeyPlainText, encryptMode)
// TODO: support more options
if (extraOptions.contains("header") &&
extraOptions("header").toLowerCase() == "true") {
diff --git a/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/crypto/dataframe/EncryptedDataFrameWriter.scala b/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/crypto/dataframe/EncryptedDataFrameWriter.scala
new file mode 100644
index 00000000000..b2df831dbe1
--- /dev/null
+++ b/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/crypto/dataframe/EncryptedDataFrameWriter.scala
@@ -0,0 +1,29 @@
+/*
+ * Copyright 2016 The BigDL Authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.intel.analytics.bigdl.ppml.crypto.dataframe
+
+import com.intel.analytics.bigdl.ppml.crypto.CryptoMode
+import org.apache.spark.sql.SparkSession
+
+class EncryptedDataFrameWriter(
+ sparkSession: SparkSession,
+ encryptMode: CryptoMode,
+ dataKeyPlainText: String) {
+
+
+
+}
diff --git a/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/utils/EncryptIOArguments.scala b/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/utils/EncryptIOArguments.scala
index b35c49224e0..a6703096d5d 100644
--- a/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/utils/EncryptIOArguments.scala
+++ b/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/utils/EncryptIOArguments.scala
@@ -17,8 +17,7 @@
package com.intel.analytics.bigdl.ppml.utils
import com.intel.analytics.bigdl.ppml.PPMLContext
-import com.intel.analytics.bigdl.ppml.crypto.{CryptoMode, EncryptRuntimeException}
-import com.intel.analytics.bigdl.ppml.crypto.CryptoMode.CryptoMode
+import com.intel.analytics.bigdl.ppml.crypto.{CryptoMode, EncryptRuntimeException, PLAIN_TEXT}
import com.intel.analytics.bigdl.ppml.kms.{EHSMKeyManagementService, KMS_CONVENTION, SimpleKeyManagementService}
import java.io.File
@@ -26,8 +25,8 @@ import java.io.File
case class EncryptIOArguments(
inputPath: String = "./input",
outputPath: String = "./output",
- inputEncryptMode: CryptoMode = CryptoMode.PLAIN_TEXT,
- outputEncryptMode: CryptoMode = CryptoMode.PLAIN_TEXT,
+ inputEncryptMode: CryptoMode = PLAIN_TEXT,
+ outputEncryptMode: CryptoMode = PLAIN_TEXT,
inputPartitionNum: Int = 4,
outputPartitionNum: Int = 4,
primaryKeyPath: String = "./primaryKeyPath",
diff --git a/scala/ppml/src/test/scala/com/intel/analytics/bigdl/ppml/crypto/EncryptSpec.scala b/scala/ppml/src/test/scala/com/intel/analytics/bigdl/ppml/crypto/EncryptSpec.scala
new file mode 100644
index 00000000000..eb16ffe9e88
--- /dev/null
+++ b/scala/ppml/src/test/scala/com/intel/analytics/bigdl/ppml/crypto/EncryptSpec.scala
@@ -0,0 +1,89 @@
+/*
+ * Copyright 2016 The BigDL Authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.intel.analytics.bigdl.ppml.crypto
+
+import com.intel.analytics.bigdl.dllib.common.zooUtils
+import com.intel.analytics.bigdl.dllib.utils.File
+import com.intel.analytics.bigdl.ppml.kms.SimpleKeyManagementService
+import org.apache.hadoop.fs.Path
+import org.scalatest.{BeforeAndAfter, FlatSpec, Matchers}
+
+import java.io.{BufferedReader, BufferedWriter, FileReader, FileWriter}
+import java.nio.file.{Files, Paths, StandardOpenOption}
+import scala.util.Random
+
+class EncryptSpec extends FlatSpec with Matchers with BeforeAndAfter {
+ val (appid, appkey) = generateKeys()
+ val simpleKms = SimpleKeyManagementService(appid, appkey)
+ val dir = zooUtils.createTmpDir("PPMLUT", "rwx------").toFile()
+ val primaryKeyPath = dir + "/primary.key"
+ val dataKeyPath = dir + "/data.key"
+ simpleKms.retrievePrimaryKey(primaryKeyPath)
+ simpleKms.retrieveDataKey(primaryKeyPath, dataKeyPath)
+ var dataKeyPlaintext: String = null
+ val (plainFileName, encryptFileName, data) = generateCsvData()
+ val fs = File.getFileSystem(plainFileName)
+
+ def generateKeys(): (String, String) = {
+ val appid: String = (1 to 12).map(x => Random.nextInt(10)).mkString
+ val appkey: String = (1 to 12).map(x => Random.nextInt(10)).mkString
+ (appid, appkey)
+ }
+ def generateCsvData(): (String, String, String) = {
+ val fileName = dir + "/people.csv"
+ val encryptFileName = dir + "/en_people.csv"
+ val fw = new FileWriter(fileName)
+ val data = new StringBuilder()
+ data.append(s"name,age,job\n")
+ data.append(s"yvomq,59,Developer\ngdni,40,Engineer\npglyal,33,Engineer")
+ fw.append(data)
+ fw.close()
+
+ val crypto = new BigDLEncrypt()
+ dataKeyPlaintext = simpleKms.retrieveDataKeyPlainText(primaryKeyPath, dataKeyPath)
+ crypto.init(AES_CBC_PKCS5PADDING, ENCRYPT, dataKeyPlaintext)
+ Files.write(Paths.get(encryptFileName), crypto.genFileHeader())
+ val encryptedBytes = crypto.doFinal(data.toString().getBytes)
+ Files.write(Paths.get(encryptFileName), encryptedBytes._1, StandardOpenOption.APPEND)
+ Files.write(Paths.get(encryptFileName), encryptedBytes._2, StandardOpenOption.APPEND)
+ (fileName, encryptFileName, data.toString())
+ }
+
+ "encrypt stream" should "work" in {
+ val encrypt = new BigDLEncrypt()
+ encrypt.init(AES_CBC_PKCS5PADDING, ENCRYPT, dataKeyPlaintext)
+ val bis = fs.open(new Path(plainFileName))
+ val outs = fs.create(new Path(dir + "/en_o.csv"))
+ encrypt.encryptStream(bis, outs)
+ bis.close()
+ outs.flush()
+ outs.close()
+ Thread.sleep(1000)
+
+ val decrypt = new BigDLEncrypt()
+ decrypt.init(AES_CBC_PKCS5PADDING, DECRYPT, dataKeyPlaintext)
+ val bis2 = fs.open(new Path(dir + "/en_o.csv"))
+ val outs2 = fs.create(new Path(dir + "/de_o.csv"))
+ decrypt.decryptStream(bis2, outs2)
+ outs2.close()
+ outs2.flush()
+ bis2.close()
+ val originFile = Files.readAllBytes(Paths.get(plainFileName))
+ val deFile = Files.readAllBytes(Paths.get(dir.toString, "/de_o.csv"))
+ originFile.sameElements(deFile) should be (true)
+ }
+}
diff --git a/scala/ppml/src/test/scala/com/intel/analytics/bigdl/ppml/crypto/dataframe/EncryptDataFrameSpec.scala b/scala/ppml/src/test/scala/com/intel/analytics/bigdl/ppml/crypto/dataframe/EncryptDataFrameSpec.scala
index 39cb57809e6..5e352c11912 100644
--- a/scala/ppml/src/test/scala/com/intel/analytics/bigdl/ppml/crypto/dataframe/EncryptDataFrameSpec.scala
+++ b/scala/ppml/src/test/scala/com/intel/analytics/bigdl/ppml/crypto/dataframe/EncryptDataFrameSpec.scala
@@ -18,14 +18,14 @@ package com.intel.analytics.bigdl.ppml.crypto.dataframe
import com.intel.analytics.bigdl.dllib.common.zooUtils
import com.intel.analytics.bigdl.ppml.PPMLContext
-import com.intel.analytics.bigdl.ppml.crypto.{CryptoMode, FernetEncrypt}
+import com.intel.analytics.bigdl.ppml.crypto.{AES_CBC_PKCS5PADDING, CryptoMode, ENCRYPT, BigDLEncrypt, PLAIN_TEXT}
import com.intel.analytics.bigdl.ppml.kms.SimpleKeyManagementService
import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession
import org.scalatest.{BeforeAndAfter, FlatSpec, Matchers}
import java.io.FileWriter
-import java.nio.file.{Files, Paths}
+import java.nio.file.{Files, Paths, StandardOpenOption}
import scala.util.Random
class EncryptDataFrameSpec extends FlatSpec with Matchers with BeforeAndAfter{
@@ -55,10 +55,13 @@ class EncryptDataFrameSpec extends FlatSpec with Matchers with BeforeAndAfter{
fw.append(data)
fw.close()
- val fernetCryptos = new FernetEncrypt()
+ val crypto = new BigDLEncrypt()
val dataKeyPlaintext = simpleKms.retrieveDataKeyPlainText(primaryKeyPath, dataKeyPath)
- val encryptedBytes = fernetCryptos.encryptBytes(data.toString().getBytes, dataKeyPlaintext)
- Files.write(Paths.get(encryptFileName), encryptedBytes)
+ crypto.init(AES_CBC_PKCS5PADDING, ENCRYPT, dataKeyPlaintext)
+ Files.write(Paths.get(encryptFileName), crypto.genFileHeader())
+ val encryptedBytes = crypto.doFinal(data.toString().getBytes)
+ Files.write(Paths.get(encryptFileName), encryptedBytes._1, StandardOpenOption.APPEND)
+ Files.write(Paths.get(encryptFileName), encryptedBytes._2, StandardOpenOption.APPEND)
(fileName, encryptFileName, data.toString())
}
val ppmlArgs = Map(
@@ -73,7 +76,7 @@ class EncryptDataFrameSpec extends FlatSpec with Matchers with BeforeAndAfter{
"textfile read from plaint text file" should "work" in {
val file = sc.textFile(plainFileName).collect()
file.mkString("\n") should be (data)
- val file2 = sc.textFile(encryptFileName, cryptoMode = CryptoMode.AES_CBC_PKCS5PADDING).collect()
+ val file2 = sc.textFile(encryptFileName, cryptoMode = AES_CBC_PKCS5PADDING).collect()
file2.mkString("\n") should be (data)
}
@@ -89,7 +92,7 @@ class EncryptDataFrameSpec extends FlatSpec with Matchers with BeforeAndAfter{
}
"read from plain csv with header" should "work" in {
- val df = sc.read(cryptoMode = CryptoMode.PLAIN_TEXT)
+ val df = sc.read(cryptoMode = PLAIN_TEXT)
.option("header", "true").csv(plainFileName)
val d = df.schema.map(_.name).mkString(",") + "\n" +
df.collect().map(v => s"${v.get(0)},${v.get(1)},${v.get(2)}").mkString("\n")
@@ -97,7 +100,7 @@ class EncryptDataFrameSpec extends FlatSpec with Matchers with BeforeAndAfter{
}
"read from encrypted csv with header" should "work" in {
- val df = sc.read(cryptoMode = CryptoMode.AES_CBC_PKCS5PADDING)
+ val df = sc.read(cryptoMode = AES_CBC_PKCS5PADDING)
.option("header", "true").csv(encryptFileName)
val d = df.schema.map(_.name).mkString(",") + "\n" +
df.collect().map(v => s"${v.get(0)},${v.get(1)},${v.get(2)}").mkString("\n")
@@ -105,13 +108,13 @@ class EncryptDataFrameSpec extends FlatSpec with Matchers with BeforeAndAfter{
}
"read from plain csv without header" should "work" in {
- val df = sc.read(cryptoMode = CryptoMode.PLAIN_TEXT).csv(plainFileName)
+ val df = sc.read(cryptoMode = PLAIN_TEXT).csv(plainFileName)
val d = df.collect().map(v => s"${v.get(0)},${v.get(1)},${v.get(2)}").mkString("\n")
d should be (data)
}
"read from encrypted csv without header" should "work" in {
- val df = sc.read(cryptoMode = CryptoMode.AES_CBC_PKCS5PADDING).csv(encryptFileName)
+ val df = sc.read(cryptoMode = AES_CBC_PKCS5PADDING).csv(encryptFileName)
val d = df.collect().map(v => s"${v.get(0)},${v.get(1)},${v.get(2)}").mkString("\n")
d should be (data)
}