Pyppml textfile and init method #4995

Merged: 35 commits, Jul 4, 2022

Commits (35)
3fb9e10  add PPMLContextWrapper.scala (PatrickkZ, Jun 6, 2022)
217d9b4  add method load_keys (PatrickkZ, Jun 6, 2022)
846846c  add init context with conf (PatrickkZ, Jun 6, 2022)
29d0548  fix java map to scala map (PatrickkZ, Jun 6, 2022)
dafcae6  add read function (PatrickkZ, Jun 6, 2022)
f1c930a  update for debug (PatrickkZ, Jun 7, 2022)
fbadae1  update for debug (PatrickkZ, Jun 7, 2022)
8a203a7  py args to scala args (PatrickkZ, Jun 7, 2022)
9e6acc6  update (PatrickkZ, Jun 7, 2022)
ad9864f  add api package (PatrickkZ, Jun 7, 2022)
6b0af0b  update (PatrickkZ, Jun 7, 2022)
3992125  add import (PatrickkZ, Jun 7, 2022)
13bf876  Merge branch 'main' of https://github.com/intel-analytics/BigDL (PatrickkZ, Jun 10, 2022)
c7256e1  Merge branch 'main' of https://github.com/intel-analytics/BigDL (PatrickkZ, Jun 13, 2022)
7fffed2  Merge remote-tracking branch 'upstream/main' (PatrickkZ, Jun 22, 2022)
675c4d9  update (PatrickkZ, Jun 22, 2022)
856b5e8  Merge remote-tracking branch 'upstream/main' (PatrickkZ, Jun 23, 2022)
c911665  del old file (PatrickkZ, Jun 23, 2022)
d9b2591  Merge remote-tracking branch 'upstream/main' (PatrickkZ, Jun 24, 2022)
7fcb20b  del files for test (PatrickkZ, Jun 27, 2022)
c22c9d3  revert (PatrickkZ, Jun 27, 2022)
07fc1c5  Merge remote-tracking branch 'upstream/main' (PatrickkZ, Jun 27, 2022)
94a918c  Merge remote-tracking branch 'upstream/main' (PatrickkZ, Jun 28, 2022)
aa6a89e  Merge remote-tracking branch 'upstream/main' (PatrickkZ, Jun 29, 2022)
a82b965  Merge remote-tracking branch 'upstream/main' (PatrickkZ, Jul 1, 2022)
4eb11ab  add textFile & init method (PatrickkZ, Jul 1, 2022)
39f7cad  update (PatrickkZ, Jul 1, 2022)
28791ef  add py textfile method (PatrickkZ, Jul 1, 2022)
c46fde8  update (PatrickkZ, Jul 1, 2022)
df882f5  get DefaultMinPartitions (PatrickkZ, Jul 1, 2022)
5a285b1  update (PatrickkZ, Jul 1, 2022)
e5830da  update (PatrickkZ, Jul 1, 2022)
0553222  fix init context with SparkConf (PatrickkZ, Jul 1, 2022)
503c35e  fix java list to scala list (PatrickkZ, Jul 1, 2022)
a969656  update test (PatrickkZ, Jul 1, 2022)
16 changes: 13 additions & 3 deletions python/ppml/src/bigdl/ppml/ppml_context.py
@@ -20,12 +20,15 @@


class PPMLContext(JavaValue):
def __init__(self, app_name, conf=None):
def __init__(self, app_name, ppml_args=None, spark_conf=None):
self.bigdl_type = "float"
args = [app_name]
if conf:
args.append(conf)
if ppml_args:
args.append(ppml_args)
if spark_conf:
args.append(spark_conf.getAll())
super().__init__(None, self.bigdl_type, *args)
self.sparkSession = callBigDlFunc(self.bigdl_type, "getSparkSession", self.value)

def load_keys(self, primary_key_path, data_key_path):
callBigDlFunc(self.bigdl_type, "loadKeys", self.value, primary_key_path, data_key_path)
@@ -42,6 +45,13 @@ def write(self, dataframe, crypto_mode):
df_writer = callBigDlFunc(self.bigdl_type, "write", self.value, dataframe, crypto_mode)
return EncryptedDataFrameWriter(self.bigdl_type, df_writer)

def textfile(self, path, min_partitions=None, crypto_mode="plain_text"):
if min_partitions is None:
min_partitions = callBigDlFunc(self.bigdl_type, "getDefaultMinPartitions", self.sparkSession)
if isinstance(crypto_mode, CryptoMode):
crypto_mode = crypto_mode.value
return callBigDlFunc(self.bigdl_type, "textFile", self.value, path, min_partitions, crypto_mode)


class EncryptedDataFrameReader:
def __init__(self, bigdl_type, df_reader):
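To make the change above concrete, here is a minimal usage sketch of the updated Python API. The application name, file paths, and ppml_args entries are placeholders (the exact argument keys depend on the KMS configuration), and the import location of CryptoMode is assumed rather than taken from this diff.

    from pyspark import SparkConf
    from bigdl.ppml.ppml_context import PPMLContext, CryptoMode  # CryptoMode location assumed

    # Placeholder KMS arguments for a SimpleKeyManagementService setup
    ppml_args = {"kms_type": "SimpleKeyManagementService",
                 "primary_key_path": "/path/to/primaryKey",
                 "data_key_path": "/path/to/dataKey"}

    spark_conf = SparkConf().setMaster("local[4]")
    sc = PPMLContext("MyApp", ppml_args, spark_conf)

    # Plain text file; min_partitions falls back to the JVM's defaultMinPartitions
    plain_rdd = sc.textfile("/path/to/people.csv")

    # Encrypted text file; crypto_mode may be a CryptoMode enum or its string value
    encrypted_rdd = sc.textfile("/path/to/encrypted/people.csv",
                                crypto_mode=CryptoMode.AES_CBC_PKCS5PADDING)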
42 changes: 30 additions & 12 deletions python/ppml/test/bigdl/ppml/test_ppml_context.py
@@ -33,25 +33,27 @@ class TestPPMLContext(unittest.TestCase):

df = None
data_content = None
csv_content = None
sc = None

@classmethod
def setUpClass(cls) -> None:
if not os.path.exists(resource_path):
os.mkdir(resource_path)

cls.csv_content = "name,age,job\n" + \
"jack,18,Developer\n" + \
"alex,20,Researcher\n" + \
"xuoui,25,Developer\n" + \
"hlsgu,29,Researcher\n" + \
"xvehlbm,45,Developer\n" + \
"ehhxoni,23,Developer\n" + \
"capom,60,Developer\n" + \
"pjt,24,Developer"

# create a tmp csv file
with open(os.path.join(resource_path, "people.csv"), "w", encoding="utf-8", newline="") as f:
csv_writer = csv.writer(f)
csv_writer.writerow(["name", "age", "job"])
csv_writer.writerow(["jack", "18", "Developer"])
csv_writer.writerow(["alex", "20", "Researcher"])
csv_writer.writerow(["xuoui", "25", "Developer"])
csv_writer.writerow(["hlsgu", "29", "Researcher"])
csv_writer.writerow(["xvehlbm", "45", "Developer"])
csv_writer.writerow(["ehhxoni", "23", "Developer"])
csv_writer.writerow(["capom", "60", "Developer"])
csv_writer.writerow(["pjt", "24", "Developer"])
with open(os.path.join(resource_path, "people.csv"), "w") as file:
file.write(cls.csv_content)

# generate primaryKey and dataKey
primary_key_path = os.path.join(resource_path, "primaryKey")
@@ -78,7 +80,10 @@ def setUpClass(cls) -> None:
cls.data_content = '\n'.join([str(v['language']) + "," + str(v['user'])
for v in cls.df.orderBy('language').collect()])

cls.sc = PPMLContext("testApp", args)
from pyspark import SparkConf
spark_conf = SparkConf()
spark_conf.setMaster("local[4]")
cls.sc = PPMLContext("testApp", args, spark_conf)

@classmethod
def tearDownClass(cls) -> None:
@@ -141,6 +146,19 @@ def test_write_and_read_encrypted_parquet(self):
for v in df_from_parquet.orderBy('language').collect()])
self.assertEqual(content, self.data_content)

def test_plain_text_file(self):
path = os.path.join(resource_path, "people.csv")
rdd = self.sc.textfile(path)
rdd_content = '\n'.join([line for line in rdd.collect()])

self.assertEqual(rdd_content, self.csv_content)

def test_encrypted_text_file(self):
path = os.path.join(resource_path, "encrypted/people.csv")
rdd = self.sc.textfile(path=path, crypto_mode=CryptoMode.AES_CBC_PKCS5PADDING)
rdd_content = '\n'.join([line for line in rdd.collect()])
self.assertEqual(rdd_content, self.csv_content)


if __name__ == "__main__":
unittest.main()
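Both new test cases join the collected RDD lines back into a single string and compare it with cls.csv_content; the encrypted copy under encrypted/people.csv is assumed to be produced earlier in setUpClass, outside the lines shown in this hunk. Because textfile unwraps a CryptoMode enum to its string value, the following two calls are expected to be equivalent (a sketch, with sc and path standing in for an existing PPMLContext and an encrypted file):

    rdd_from_enum = sc.textfile(path, crypto_mode=CryptoMode.AES_CBC_PKCS5PADDING)
    rdd_from_str = sc.textfile(path, crypto_mode="AES/CBC/PKCS5Padding")  # string form used by the Scala spec below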
@@ -20,8 +20,11 @@ import com.intel.analytics.bigdl.ppml.PPMLContext
import com.intel.analytics.bigdl.ppml.crypto.{AES_CBC_PKCS5PADDING, BigDLEncrypt, CryptoMode, ENCRYPT, EncryptRuntimeException}
import com.intel.analytics.bigdl.ppml.crypto.dataframe.{EncryptedDataFrameReader, EncryptedDataFrameWriter}
import com.intel.analytics.bigdl.ppml.kms.{KMS_CONVENTION, SimpleKeyManagementService}
import org.apache.spark.sql.{DataFrame, DataFrameWriter, Row}
import org.apache.spark.SparkConf
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, DataFrameWriter, Row, SparkSession}
import org.slf4j.{Logger, LoggerFactory}
import scala.collection.JavaConverters._

import java.io.File
import java.util
@@ -42,7 +45,53 @@ class PPMLContextPython[T]() {

def createPPMLContext(appName: String, ppmlArgs: util.Map[String, String]): PPMLContext = {
logger.debug("create PPMLContextWrapper with appName & ppmlArgs")
val args = parseArgs(ppmlArgs)
PPMLContext.initPPMLContext(appName, args)
}

def createPPMLContext(appName: String, ppmlArgs: util.Map[String, String],
confs: util.List[util.List[String]]): PPMLContext = {
logger.debug("create PPMLContextWrapper with appName & ppmlArgs & sparkConf")
val args = parseArgs(ppmlArgs)
val sparkConf = new SparkConf()
confs.asScala.foreach(conf => sparkConf.set(conf.get(0), conf.get(1)))

PPMLContext.initPPMLContext(sparkConf, appName, args)
}

def read(sc: PPMLContext, cryptoModeStr: String): EncryptedDataFrameReader = {
logger.debug("read file with crypto mode " + cryptoModeStr)
val cryptoMode = CryptoMode.parse(cryptoModeStr)
sc.read(cryptoMode)
}

def write(sc: PPMLContext, dataFrame: DataFrame,
cryptoModeStr: String): EncryptedDataFrameWriter = {
logger.debug("write file with crypt mode " + cryptoModeStr)
val cryptoMode = CryptoMode.parse(cryptoModeStr)
sc.write(dataFrame, cryptoMode)
}

def loadKeys(sc: PPMLContext,
primaryKeyPath: String, dataKeyPath: String): Unit = {
sc.loadKeys(primaryKeyPath, dataKeyPath)
}

def textFile(sc: PPMLContext, path: String,
minPartitions: Int, cryptoModeStr: String): RDD[String] = {
val cryptoMode = CryptoMode.parse(cryptoModeStr)
sc.textFile(path, minPartitions, cryptoMode)
}

def getSparkSession(sc: PPMLContext): SparkSession = {
sc.getSparkSession()
}

def getDefaultMinPartitions(sparkSession: SparkSession): Int = {
sparkSession.sparkContext.defaultMinPartitions
}

private def parseArgs(ppmlArgs: util.Map[String, String]): Map[String, String] = {
val kmsArgs = scala.collection.mutable.Map[String, String]()
val kmsType = ppmlArgs.get("kms_type")
kmsArgs("spark.bigdl.kms.type") = kmsType
@@ -61,26 +110,9 @@
if (new File(ppmlArgs.get("data_key_path")).exists()) {
kmsArgs("spark.bigdl.kms.key.data") = ppmlArgs.get("data_key_path")
}
PPMLContext.initPPMLContext(appName, kmsArgs.toMap)
}

def read(sc: PPMLContext, cryptoModeStr: String): EncryptedDataFrameReader = {
logger.debug("read file with crypto mode " + cryptoModeStr)
val cryptoMode = CryptoMode.parse(cryptoModeStr)
sc.read(cryptoMode)
kmsArgs.toMap
}

def write(sc: PPMLContext, dataFrame: DataFrame,
cryptoModeStr: String): EncryptedDataFrameWriter = {
logger.debug("write file with crypt mode " + cryptoModeStr)
val cryptoMode = CryptoMode.parse(cryptoModeStr)
sc.write(dataFrame, cryptoMode)
}

def loadKeys(sc: PPMLContext,
primaryKeyPath: String, dataKeyPath: String): Unit = {
sc.loadKeys(primaryKeyPath, dataKeyPath)
}

/**
* EncryptedDataFrameReader method
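The new createPPMLContext overload above is what backs the spark_conf argument of the Python constructor: spark_conf.getAll() yields (key, value) pairs on the Python side, the bridge delivers them as a java.util.List of two-element lists, and the wrapper folds them back into a SparkConf before calling PPMLContext.initPPMLContext. A rough Python-side sketch of the shape being passed (the Java conversion itself is handled by the existing callBigDlFunc machinery):

    from pyspark import SparkConf

    spark_conf = SparkConf().setMaster("local[4]")
    pairs = spark_conf.getAll()  # e.g. [("spark.master", "local[4]")]
    # On the JVM side each pair is read back as conf.get(0) / conf.get(1):
    #   confs.asScala.foreach(conf => sparkConf.set(conf.get(0), conf.get(1)))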
@@ -22,6 +22,7 @@ import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession

import java.io.File
import java.util
import scala.collection.JavaConverters._

class PPMLContextPythonSpec extends DataFrameHelper{
@@ -56,6 +57,17 @@ class PPMLContextPythonSpec extends DataFrameHelper{
ppmlContextPython.createPPMLContext("testApp", pyPPMLArgs.asJava)
}

"init PPMLContext with app name & args & sparkConf" should "work" in {
val confs: util.List[util.List[String]] = new util.ArrayList[util.List[String]]()
conf.getAll.foreach(tuple => {
val conf = new util.ArrayList[String]()
conf.add(tuple._1)
conf.add(tuple._2)
confs.add(conf)
})
ppmlContextPython.createPPMLContext("testApp", pyPPMLArgs.asJava, confs)
}

"read plain csv file" should "work" in {
val encryptedDataFrameReader = ppmlContextPython.read(sc, "plain_text")
ppmlContextPython.option(encryptedDataFrameReader, "header", "true")
@@ -182,4 +194,22 @@
parquetContent should be (dataContent)
}

"textFile method with plain csv file" should "work" in {
val minPartitions = sc.getSparkSession().sparkContext.defaultMinPartitions
val cryptoMode = "plain_text"
val rdd = ppmlContextPython.textFile(sc, plainFileName, minPartitions, cryptoMode)
val rddContent = rdd.collect().mkString("\n")

rddContent + "\n" should be (data)
}

"textFile method with encrypted csv file" should "work" in {
val minPartitions = sc.getSparkSession().sparkContext.defaultMinPartitions
val cryptoMode = "AES/CBC/PKCS5Padding"
val rdd = ppmlContextPython.textFile(sc, encryptFileName, minPartitions, cryptoMode)
val rddContent = rdd.collect().mkString("\n")

rddContent + "\n" should be (data)
}

}
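The two textFile specs pass minPartitions explicitly because the Scala wrapper takes it as a required parameter; on the Python side the same value is filled in automatically through getDefaultMinPartitions when the caller omits it, so the calls below should read the same content and differ only in the partition hint (a sketch, assuming an existing PPMLContext sc and a plain-text path):

    rdd_default = sc.textfile(path)                     # min_partitions resolved on the JVM
    rdd_explicit = sc.textfile(path, min_partitions=2)  # explicit partition hint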