fix data conversion to use StringArray param
imatiach-msft committed Aug 17, 2018
1 parent 0ca02ac commit aa925fe
Showing 4 changed files with 28 additions and 40 deletions.
10 changes: 5 additions & 5 deletions notebooks/samples/105 - Regression with DataConversion.ipynb
@@ -97,9 +97,9 @@
    "outputs": [],
    "source": [
    "from mmlspark import DataConversion\n",
-   "flightDelay = DataConversion(col=\"Quarter,Month,DayofMonth,DayOfWeek,\"\n",
-   "                                 + \"OriginAirportID,DestAirportID,\"\n",
-   "                                 + \"CRSDepTime,CRSArrTime\",\n",
+   "flightDelay = DataConversion(cols=[\"Quarter\",\"Month\",\"DayofMonth\",\"DayOfWeek\",\n",
+   "                                   \"OriginAirportID\",\"DestAirportID\",\n",
+   "                                   \"CRSDepTime\",\"CRSArrTime\"],\n",
    "                             convertTo=\"double\") \\\n",
    "    .transform(flightDelay)\n",
    "flightDelay.printSchema()\n",
@@ -151,10 +151,10 @@
    "from mmlspark import TrainRegressor, TrainedRegressorModel\n",
    "from pyspark.ml.regression import LinearRegression\n",
    "\n",
-   "trainCat = DataConversion(col=\"Carrier,DepTimeBlk,ArrTimeBlk\",\n",
+   "trainCat = DataConversion(cols=[\"Carrier\",\"DepTimeBlk\",\"ArrTimeBlk\"],\n",
    "                          convertTo=\"toCategorical\") \\\n",
    "    .transform(train)\n",
-   "testCat = DataConversion(col=\"Carrier,DepTimeBlk,ArrTimeBlk\",\n",
+   "testCat = DataConversion(cols=[\"Carrier\",\"DepTimeBlk\",\"ArrTimeBlk\"],\n",
    "                         convertTo=\"toCategorical\") \\\n",
    "    .transform(test)\n",
    "lr = LinearRegression().setSolver(\"l-bfgs\").setRegParam(0.1) \\\n",
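For reference, the notebook change boils down to the call-pattern shift below. A minimal sketch, assuming an active SparkSession, the mmlspark package on the Python path, and a `flightDelay` DataFrame with the sample's column names:

```python
from mmlspark import DataConversion

# Before this commit: one comma-separated string passed to `col`
# flightDelay = DataConversion(col="Quarter,Month,DayofMonth,DayOfWeek",
#                              convertTo="double").transform(flightDelay)

# After this commit: a list of column names passed to `cols` (StringArrayParam)
flightDelay = DataConversion(cols=["Quarter", "Month", "DayofMonth", "DayOfWeek"],
                             convertTo="double").transform(flightDelay)
flightDelay.printSchema()
```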
2 changes: 1 addition & 1 deletion src/codegen/src/main/scala/PySparkWrapperTest.scala
@@ -175,7 +175,7 @@ abstract class PySparkWrapperTest(entryPoint: PipelineStage,
      entryPointName match {
        case "_CNTKModel" | "MultiTokenizer" | "NltTokenizeTransform" | "TextTransform"
           | "TextNormalizerTransform" | "WordTokenizeTransform" => "inputCol=\"col5\""
-       case "DataConversion" => "col=\"col1\", convertTo=\"double\""
+       case "DataConversion" => "cols=[\"col1\"], convertTo=\"double\""
        case "DropColumns" => "cols=[\"col1\"]"
        case "EnsembleByKey" => "keys=[\"col1\"], cols=[\"col3\"]"
        case "FastVectorAssembler" => "inputCols=\"col1\""
23 changes: 5 additions & 18 deletions src/data-conversion/src/main/scala/DataConversion.scala
@@ -26,14 +26,14 @@ class DataConversion(override val uid: String) extends Transformer with MMLParam
  /** Comma separated list of columns whose type will be converted
    * @group param
    */
- val col: Param[String] = StringParam(this, "col",
-   "Comma separated list of columns whose type will be converted", "")
+ val cols: StringArrayParam = new StringArrayParam(this, "cols",
+   "Comma separated list of columns whose type will be converted")

  /** @group getParam */
- final def getCol: String = $(col)
+ final def getCols: Array[String] = $(cols)

  /** @group setParam */
- def setCol(value: String): this.type = set(col, value)
+ def setCols(value: Array[String]): this.type = set(cols, value)

  /** The result type
    * @group param
@@ -64,10 +64,9 @@ class DataConversion(override val uid: String) extends Transformer with MMLParam
    * @return The transformed dataset
    */
  override def transform(dataset: Dataset[_]): DataFrame = {
-   require($(col) != null, "No column name specified")
    require(dataset != null, "No dataset supplied")
    require(dataset.columns.length != 0, "Dataset with no columns cannot be converted")
-   val colsList = $(col).split(",").map(_.trim)
+   val colsList = $(cols).map(_.trim)
    val errorList = verifyCols(dataset.toDF(), colsList)
    if (errorList.nonEmpty) {
      throw new NoSuchElementException
@@ -98,18 +97,6 @@ class DataConversion(override val uid: String) extends Transformer with MMLParam
    res
  }

- /** Transforms the dataset
-   * @param dataset The input dataset, to be transformed
-   * @param paramMap ParamMap which contains parameter value to override the default value
-   * @return the DataFrame that results from data conversion
-   */
- override def transform(dataset: Dataset[_], paramMap: ParamMap): DataFrame = {
-   setCol(paramMap.getOrElse(new Param("col", "col", "Name of column whose type will be converted"), ""))
-   setConvertTo(paramMap.getOrElse(new Param("convertTo", "convertTo", "Result type"), ""))
-   setDateTimeFormat(paramMap.getOrElse(new Param("dateTimeFormat", "dateTimeFormat", "Time string format"), ""))
-   transform(dataset)
- }
-
/** Transform the schema
* @param schema
* @return modified schema
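On the Python side, the regenerated wrapper exposes the renamed param through the usual generated setters. A minimal sketch of the migration, assuming a wrapper regenerated against this commit (`setCols`/`setConvertTo` mirror the Scala setters) and a `train` DataFrame as in the notebook:

```python
from mmlspark import DataConversion

# Old: .setCol("Carrier,DepTimeBlk,ArrTimeBlk") -- one comma-separated string
# New: an explicit list of column names
dc = (DataConversion()
      .setCols(["Carrier", "DepTimeBlk", "ArrTimeBlk"])
      .setConvertTo("toCategorical"))
trainCat = dc.transform(train)
```

With the custom `transform(dataset, paramMap)` override deleted, per-call parameter overrides go through Spark's standard path: the base `Transformer` implements that overload as `copy(paramMap).transform(dataset)`, so no hand-rolled param extraction is needed.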
33 changes: 17 additions & 16 deletions src/data-conversion/src/test/scala/VerifyDataConversion.scala
@@ -52,12 +52,12 @@ class VerifyDataConversions extends TestBase {
    Types tested are boolean, Byte, Short, Int, Long, Float, Double, and string
   */
  test("Test convert all types to Boolean") {
-   val r1 = new DataConversion().setCol("byte").setConvertTo("boolean").transform(masterInDF)
-   val r2 = new DataConversion().setCol("short").setConvertTo("boolean").transform(r1)
-   val r3 = new DataConversion().setCol("int").setConvertTo("boolean").transform(r2)
-   val r4 = new DataConversion().setCol("long").setConvertTo("boolean").transform(r3)
-   val r5 = new DataConversion().setCol("float").setConvertTo("boolean").transform(r4)
-   val r6 = new DataConversion().setCol("double").setConvertTo("boolean").transform(r5)
+   val r1 = new DataConversion().setCols(Array("byte")).setConvertTo("boolean").transform(masterInDF)
+   val r2 = new DataConversion().setCols(Array("short")).setConvertTo("boolean").transform(r1)
+   val r3 = new DataConversion().setCols(Array("int")).setConvertTo("boolean").transform(r2)
+   val r4 = new DataConversion().setCols(Array("long")).setConvertTo("boolean").transform(r3)
+   val r5 = new DataConversion().setCols(Array("float")).setConvertTo("boolean").transform(r4)
+   val r6 = new DataConversion().setCols(Array("double")).setConvertTo("boolean").transform(r5)
    val expectedRes = Seq(( true, true, true, true, true, true, true, "7", "8.0"),
                          (false, true, true, true, true, true, true, "16", "17.456"),
                          (true, true, true, true, true, true, true, "100", "200.12345"))
@@ -75,7 +75,7 @@
   */
  test("Test convert string to boolean throws an exception") {
    assertThrows[Exception] {
-     new DataConversion().setCol("intstring").setConvertTo("boolean").transform(masterInDF)
+     new DataConversion().setCols(Array("intstring")).setConvertTo("boolean").transform(masterInDF)
    }
  }

@@ -177,54 +177,55 @@
  // Test convert to categorical:
  test("Test convert to categorical") {
    val inDF = Seq(("piano", 1, 2), ("drum", 3, 4), ("guitar", 5, 6)).toDF("instruments", "c1", "c2")
-   val res = new DataConversion().setCol("instruments").setConvertTo("toCategorical").transform(inDF)
+   val res = new DataConversion().setCols(Array("instruments")).setConvertTo("toCategorical").transform(inDF)
    assert(SparkSchema.isCategorical(res, "instruments"))
  }

  // Test clearing categorical
  test("Test that categorical features will be cleared") {
    val inDF = Seq(("piano", 1, 2), ("drum", 3, 4), ("guitar", 5, 6)).toDF("instruments", "c1", "c2")
-   val res = new DataConversion().setCol("instruments").setConvertTo("toCategorical").transform(inDF)
+   val res = new DataConversion().setCols(Array("instruments")).setConvertTo("toCategorical").transform(inDF)
    assert(SparkSchema.isCategorical(res, "instruments"))
-   val res2 = new DataConversion().setCol("instruments").setConvertTo("clearCategorical").transform(res)
+   val res2 = new DataConversion().setCols(Array("instruments")).setConvertTo("clearCategorical").transform(res)
    assert(!SparkSchema.isCategorical(res2, "instruments"))
    assert(inDF.except(res2).count == 0)
  }

  // Verify that a TimestampType is converted to a LongType
  test("Test timestamp to long conversion") {
-   val res = new DataConversion().setCol("Col0").setConvertTo("long")
+   val res = new DataConversion().setCols(Array("Col0")).setConvertTo("long")
      .setDateTimeFormat("yyyy-MM-dd HH:mm:ss.SSS").transform(tsDF)
    assert(res.schema("Col0").dataType == LongType)
    assert(lDF.except(res).count == 0)
  }

  // Test the reverse - long to timestamp
  test("Test long to timestamp conversion") {
-   val res = new DataConversion().setCol("Col0").setConvertTo("date")
+   val res = new DataConversion().setCols(Array("Col0")).setConvertTo("date")
      .setDateTimeFormat("yyyy-MM-dd HH:mm:ss.SSS").transform(lDF)
    assert(res.schema("Col0").dataType == TimestampType)
    assert(tsDF.except(res).count == 0)
  }

  test("Test timestamp to string conversion") {
-   val res = new DataConversion().setCol("Col0").setConvertTo("string")
+   val res = new DataConversion().setCols(Array("Col0")).setConvertTo("string")
      .setDateTimeFormat("yyyy-MM-dd HH:mm:ss.SSS").transform(tsDF)
    assert(res.schema("Col0").dataType == StringType)
    assert(sDF.except(res).count == 0)
  }

  test("Test date string to timestamp conversion") {
-   val res = new DataConversion().setCol("Col0").setConvertTo("date")
+   val res = new DataConversion().setCols(Array("Col0")).setConvertTo("date")
      .setDateTimeFormat("yyyy-MM-dd HH:mm:ss.SSS").transform(sDF)
-   val res2 = new DataConversion().setCol("Col0").setConvertTo("long")
+   val res2 = new DataConversion().setCols(Array("Col0")).setConvertTo("long")
      .setDateTimeFormat("yyyy-MM-dd HH:mm:ss.SSS").transform(res)
    assert(res.schema("Col0").dataType == TimestampType)
    assert(tsDF.except(res).count == 0)
  }

  def generateRes(convTo: String, inDF: DataFrame): DataFrame = {
-   val result = new DataConversion().setCol("bool, byte, short, int, long, float, double, intstring, doublestring")
+   val result = new DataConversion()
+     .setCols(Array("bool", "byte", "short", "int", "long", "float", "double", "intstring", "doublestring"))
      .setConvertTo(convTo).transform(masterInDF)
    result
  }
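The date/time round-trips exercised by these tests look like the following from PySpark. A hedged sketch, assuming a `spark` session and a wrapper regenerated against this commit (the `dateTimeFormat` keyword mirrors the Scala `setDateTimeFormat` param):

```python
from mmlspark import DataConversion

ts_df = spark.createDataFrame([("2018-08-17 12:34:56.789",)], ["Col0"])

# String -> timestamp, then timestamp -> epoch long, mirroring the Scala tests
as_date = DataConversion(cols=["Col0"], convertTo="date",
                         dateTimeFormat="yyyy-MM-dd HH:mm:ss.SSS").transform(ts_df)
as_long = DataConversion(cols=["Col0"], convertTo="long",
                         dateTimeFormat="yyyy-MM-dd HH:mm:ss.SSS").transform(as_date)
as_long.printSchema()  # Col0: long
```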
