Pyppml textfile and init method (#4995)
* add PPMLContextWrapper.scala

* add method load_keys

* add init context with conf

* fix java map to scala map

* add read function

* update for debug

* update for debug

* py args to scala args

* update

* add api package

* update

* add import

* update

* del old file

* del files for test

* revert

* add textFile & init method

* update

* add py textfile method

* update

* get DefaultMinPartitions

* update

* update

* fix init context with SparkConf

* fix java list to scala list

* update test

Co-authored-by: Zhou <[email protected]>
PatrickkZ authored Jul 4, 2022
1 parent 9e28e5a commit 9630ee0
Showing 4 changed files with 124 additions and 34 deletions.
16 changes: 13 additions & 3 deletions python/ppml/src/bigdl/ppml/ppml_context.py
@@ -20,12 +20,15 @@


 class PPMLContext(JavaValue):
-    def __init__(self, app_name, conf=None):
+    def __init__(self, app_name, ppml_args=None, spark_conf=None):
         self.bigdl_type = "float"
         args = [app_name]
-        if conf:
-            args.append(conf)
+        if ppml_args:
+            args.append(ppml_args)
+        if spark_conf:
+            args.append(spark_conf.getAll())
         super().__init__(None, self.bigdl_type, *args)
+        self.sparkSession = callBigDlFunc(self.bigdl_type, "getSparkSession", self.value)

     def load_keys(self, primary_key_path, data_key_path):
         callBigDlFunc(self.bigdl_type, "loadKeys", self.value, primary_key_path, data_key_path)
@@ -42,6 +45,13 @@ def write(self, dataframe, crypto_mode):
         df_writer = callBigDlFunc(self.bigdl_type, "write", self.value, dataframe, crypto_mode)
         return EncryptedDataFrameWriter(self.bigdl_type, df_writer)

+    def textfile(self, path, min_partitions=None, crypto_mode="plain_text"):
+        if min_partitions is None:
+            min_partitions = callBigDlFunc(self.bigdl_type, "getDefaultMinPartitions", self.sparkSession)
+        if isinstance(crypto_mode, CryptoMode):
+            crypto_mode = crypto_mode.value
+        return callBigDlFunc(self.bigdl_type, "textFile", self.value, path, min_partitions, crypto_mode)
+

 class EncryptedDataFrameReader:
     def __init__(self, bigdl_type, df_reader):
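Taken together, the new constructor arguments and the textfile method give the Python API a complete entry point. A minimal usage sketch, not part of the commit: the paths are illustrative, and the kms_type value is assumed to match what parseArgs expects on the Scala side.

    from pyspark import SparkConf
    from bigdl.ppml.ppml_context import PPMLContext

    # assumed KMS setup; parseArgs reads kms_type, primary_key_path, data_key_path
    ppml_args = {
        "kms_type": "SimpleKeyManagementService",
        "primary_key_path": "/path/to/primaryKey",
        "data_key_path": "/path/to/dataKey",
    }

    spark_conf = SparkConf().setMaster("local[4]")

    sc = PPMLContext("MyApp", ppml_args, spark_conf)

    # plain text: min_partitions falls back to Spark's defaultMinPartitions
    plain_rdd = sc.textfile("/path/to/people.csv")

    # encrypted text: pass a CryptoMode member or its string value
    enc_rdd = sc.textfile("/path/to/encrypted/people.csv",
                          crypto_mode="AES/CBC/PKCS5Padding")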
42 changes: 30 additions & 12 deletions python/ppml/test/bigdl/ppml/test_ppml_context.py
@@ -33,25 +33,27 @@ class TestPPMLContext(unittest.TestCase):

     df = None
     data_content = None
+    csv_content = None
+    sc = None

     @classmethod
     def setUpClass(cls) -> None:
         if not os.path.exists(resource_path):
             os.mkdir(resource_path)

+        cls.csv_content = "name,age,job\n" + \
+                          "jack,18,Developer\n" + \
+                          "alex,20,Researcher\n" + \
+                          "xuoui,25,Developer\n" + \
+                          "hlsgu,29,Researcher\n" + \
+                          "xvehlbm,45,Developer\n" + \
+                          "ehhxoni,23,Developer\n" + \
+                          "capom,60,Developer\n" + \
+                          "pjt,24,Developer"
+
         # create a tmp csv file
-        with open(os.path.join(resource_path, "people.csv"), "w", encoding="utf-8", newline="") as f:
-            csv_writer = csv.writer(f)
-            csv_writer.writerow(["name", "age", "job"])
-            csv_writer.writerow(["jack", "18", "Developer"])
-            csv_writer.writerow(["alex", "20", "Researcher"])
-            csv_writer.writerow(["xuoui", "25", "Developer"])
-            csv_writer.writerow(["hlsgu", "29", "Researcher"])
-            csv_writer.writerow(["xvehlbm", "45", "Developer"])
-            csv_writer.writerow(["ehhxoni", "23", "Developer"])
-            csv_writer.writerow(["capom", "60", "Developer"])
-            csv_writer.writerow(["pjt", "24", "Developer"])
+        with open(os.path.join(resource_path, "people.csv"), "w") as file:
+            file.write(cls.csv_content)

         # generate primaryKey and dataKey
         primary_key_path = os.path.join(resource_path, "primaryKey")
@@ -78,7 +80,10 @@ def setUpClass(cls) -> None:
         cls.data_content = '\n'.join([str(v['language']) + "," + str(v['user'])
                                       for v in cls.df.orderBy('language').collect()])

-        cls.sc = PPMLContext("testApp", args)
+        from pyspark import SparkConf
+        spark_conf = SparkConf()
+        spark_conf.setMaster("local[4]")
+        cls.sc = PPMLContext("testApp", args, spark_conf)

     @classmethod
     def tearDownClass(cls) -> None:
@@ -141,6 +146,19 @@ def test_write_and_read_encrypted_parquet(self):
                              for v in df_from_parquet.orderBy('language').collect()])
         self.assertEqual(content, self.data_content)

+    def test_plain_text_file(self):
+        path = os.path.join(resource_path, "people.csv")
+        rdd = self.sc.textfile(path)
+        rdd_content = '\n'.join([line for line in rdd.collect()])
+
+        self.assertEqual(rdd_content, self.csv_content)
+
+    def test_encrypted_text_file(self):
+        path = os.path.join(resource_path, "encrypted/people.csv")
+        rdd = self.sc.textfile(path=path, crypto_mode=CryptoMode.AES_CBC_PKCS5PADDING)
+        rdd_content = '\n'.join([line for line in rdd.collect()])
+        self.assertEqual(rdd_content, self.csv_content)
+

 if __name__ == "__main__":
     unittest.main()
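The encrypted test passes CryptoMode.AES_CBC_PKCS5PADDING directly and relies on textfile unwrapping the enum through its .value attribute. A sketch of the shape that enum plausibly has, with the two values inferred from the strings CryptoMode.parse handles in the Scala diffs below:

    from enum import Enum

    # hypothetical reconstruction; only these two values appear in this commit
    class CryptoMode(Enum):
        PLAIN_TEXT = "plain_text"
        AES_CBC_PKCS5PADDING = "AES/CBC/PKCS5Padding"

    # textfile() reduces an enum member to its string before crossing to Scala
    assert CryptoMode.AES_CBC_PKCS5PADDING.value == "AES/CBC/PKCS5Padding"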
@@ -20,8 +20,11 @@ import com.intel.analytics.bigdl.ppml.PPMLContext
 import com.intel.analytics.bigdl.ppml.crypto.{AES_CBC_PKCS5PADDING, BigDLEncrypt, CryptoMode, ENCRYPT, EncryptRuntimeException}
 import com.intel.analytics.bigdl.ppml.crypto.dataframe.{EncryptedDataFrameReader, EncryptedDataFrameWriter}
 import com.intel.analytics.bigdl.ppml.kms.{KMS_CONVENTION, SimpleKeyManagementService}
-import org.apache.spark.sql.{DataFrame, DataFrameWriter, Row}
+import org.apache.spark.SparkConf
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{DataFrame, DataFrameWriter, Row, SparkSession}
 import org.slf4j.{Logger, LoggerFactory}
+import scala.collection.JavaConverters._

 import java.io.File
 import java.util
@@ -42,7 +45,53 @@ class PPMLContextPython[T]() {

   def createPPMLContext(appName: String, ppmlArgs: util.Map[String, String]): PPMLContext = {
     logger.debug("create PPMLContextWrapper with appName & ppmlArgs")
+    val args = parseArgs(ppmlArgs)
+    PPMLContext.initPPMLContext(appName, args)
+  }
+
+  def createPPMLContext(appName: String, ppmlArgs: util.Map[String, String],
+                        confs: util.List[util.List[String]]): PPMLContext = {
+    logger.debug("create PPMLContextWrapper with appName & ppmlArgs & sparkConf")
+    val args = parseArgs(ppmlArgs)
+    val sparkConf = new SparkConf()
+    confs.asScala.foreach(conf => sparkConf.set(conf.get(0), conf.get(1)))
+
+    PPMLContext.initPPMLContext(sparkConf, appName, args)
+  }
+
+  def read(sc: PPMLContext, cryptoModeStr: String): EncryptedDataFrameReader = {
+    logger.debug("read file with crypto mode " + cryptoModeStr)
+    val cryptoMode = CryptoMode.parse(cryptoModeStr)
+    sc.read(cryptoMode)
+  }
+
+  def write(sc: PPMLContext, dataFrame: DataFrame,
+            cryptoModeStr: String): EncryptedDataFrameWriter = {
+    logger.debug("write file with crypt mode " + cryptoModeStr)
+    val cryptoMode = CryptoMode.parse(cryptoModeStr)
+    sc.write(dataFrame, cryptoMode)
+  }
+
+  def loadKeys(sc: PPMLContext,
+               primaryKeyPath: String, dataKeyPath: String): Unit = {
+    sc.loadKeys(primaryKeyPath, dataKeyPath)
+  }
+
+  def textFile(sc: PPMLContext, path: String,
+               minPartitions: Int, cryptoModeStr: String): RDD[String] = {
+    val cryptoMode = CryptoMode.parse(cryptoModeStr)
+    sc.textFile(path, minPartitions, cryptoMode)
+  }
+
+  def getSparkSession(sc: PPMLContext): SparkSession = {
+    sc.getSparkSession()
+  }
+
+  def getDefaultMinPartitions(sparkSession: SparkSession): Int = {
+    sparkSession.sparkContext.defaultMinPartitions
+  }
+
+  private def parseArgs(ppmlArgs: util.Map[String, String]): Map[String, String] = {
     val kmsArgs = scala.collection.mutable.Map[String, String]()
     val kmsType = ppmlArgs.get("kms_type")
     kmsArgs("spark.bigdl.kms.type") = kmsType
@@ -61,26 +110,9 @@ class PPMLContextPython[T]() {
     if (new File(ppmlArgs.get("data_key_path")).exists()) {
       kmsArgs("spark.bigdl.kms.key.data") = ppmlArgs.get("data_key_path")
     }
-    PPMLContext.initPPMLContext(appName, kmsArgs.toMap)
-  }
-
-  def read(sc: PPMLContext, cryptoModeStr: String): EncryptedDataFrameReader = {
-    logger.debug("read file with crypto mode " + cryptoModeStr)
-    val cryptoMode = CryptoMode.parse(cryptoModeStr)
-    sc.read(cryptoMode)
-  }
-
-  def write(sc: PPMLContext, dataFrame: DataFrame,
-            cryptoModeStr: String): EncryptedDataFrameWriter = {
-    logger.debug("write file with crypt mode " + cryptoModeStr)
-    val cryptoMode = CryptoMode.parse(cryptoModeStr)
-    sc.write(dataFrame, cryptoMode)
-  }
-
-  def loadKeys(sc: PPMLContext,
-               primaryKeyPath: String, dataKeyPath: String): Unit = {
-    sc.loadKeys(primaryKeyPath, dataKeyPath)
-  }
+    kmsArgs.toMap
+  }

   /**
    * EncryptedDataFrameReader method
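On the wire, the new createPPMLContext overload of the PPMLContextPython bridge takes the SparkConf as a util.List of two-element util.Lists. The Python constructor produces that shape from SparkConf.getAll(), which returns key/value tuples; a small sketch of the hand-off (the tuple-to-Java-list conversion by the py4j bridge is assumed):

    from pyspark import SparkConf

    conf = SparkConf().setMaster("local[4]").setAppName("demo")

    # getAll() returns [(key, value), ...]; Scala rebuilds the SparkConf with
    # sparkConf.set(conf.get(0), conf.get(1)) for each pair
    for key, value in conf.getAll():
        print(key, "=", value)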
@@ -22,6 +22,7 @@ import org.apache.spark.SparkConf
 import org.apache.spark.sql.SparkSession

 import java.io.File
+import java.util
 import scala.collection.JavaConverters._

 class PPMLContextPythonSpec extends DataFrameHelper{
@@ -56,6 +57,17 @@
     ppmlContextPython.createPPMLContext("testApp", pyPPMLArgs.asJava)
   }

+  "init PPMLContext with app name & args & sparkConf" should "work" in {
+    val confs: util.List[util.List[String]] = new util.ArrayList[util.List[String]]()
+    conf.getAll.foreach(tuple => {
+      val conf = new util.ArrayList[String]()
+      conf.add(tuple._1)
+      conf.add(tuple._2)
+      confs.add(conf)
+    })
+    ppmlContextPython.createPPMLContext("testApp", pyPPMLArgs.asJava, confs)
+  }
+
   "read plain csv file" should "work" in {
     val encryptedDataFrameReader = ppmlContextPython.read(sc, "plain_text")
     ppmlContextPython.option(encryptedDataFrameReader, "header", "true")
@@ -182,4 +194,22 @@
     parquetContent should be (dataContent)
   }

+  "textFile method with plain csv file" should "work" in {
+    val minPartitions = sc.getSparkSession().sparkContext.defaultMinPartitions
+    val cryptoMode = "plain_text"
+    val rdd = ppmlContextPython.textFile(sc, plainFileName, minPartitions, cryptoMode)
+    val rddContent = rdd.collect().mkString("\n")
+
+    rddContent + "\n" should be (data)
+  }
+
+  "textFile method with encrypted csv file" should "work" in {
+    val minPartitions = sc.getSparkSession().sparkContext.defaultMinPartitions
+    val cryptoMode = "AES/CBC/PKCS5Padding"
+    val rdd = ppmlContextPython.textFile(sc, encryptFileName, minPartitions, cryptoMode)
+    val rddContent = rdd.collect().mkString("\n")
+
+    rddContent + "\n" should be (data)
+  }
+
 }
