Skip to content

Commit

Permalink
throw exception instead of ignoring non-literals input
Browse files Browse the repository at this point in the history
  • Loading branch information
mgaido91 committed May 23, 2018
1 parent a1f3a5b commit ca9caa0
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -132,14 +132,14 @@ object MaskLike {
if (i == null) defaultCharCount else i.asInstanceOf[Int]
case Literal(_, dt) => throw new AnalysisException("Expected literal expression of type " +
s"${IntegerType.simpleString}, but got literal of ${dt.simpleString}")
case _ => defaultCharCount
case other => throw new AnalysisException(s"Expected literal expression, but got ${other.sql}")
}

def extractReplacement(e: Expression): String = e match {
case Literal(s, StringType | NullType) => if (s == null) null else s.toString
case Literal(_, dt) => throw new AnalysisException("Expected literal expression of type " +
s"${StringType.simpleString}, but got literal of ${dt.simpleString}")
case _ => null
case other => throw new AnalysisException(s"Expected literal expression, but got ${other.sql}")
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
Row(null, null, null, null)))
checkAnswer(sql("select mask(null)"), Row(null))
checkAnswer(sql("select mask('AAaa11', null, null, null)"), Row("XXxxnn"))
checkAnswer(df.selectExpr("mask(a, a)"), Seq(Row("XxxxXxxxxx-nnn"), Row(""), Row(null)))
intercept[AnalysisException] {
checkAnswer(df.selectExpr("mask(a, a)"), Seq(Row("XxxxXxxxxx-nnn"), Row(""), Row(null)))
}

checkAnswer(
df.selectExpr(
Expand All @@ -323,7 +325,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
Row(null, null, null, null, null)))
checkAnswer(sql("select mask_first_n(null)"), Row(null))
checkAnswer(sql("select mask_first_n('A1aA1a', null, null, null, null)"), Row("XnxX1a"))
checkAnswer(spark.range(1).selectExpr("mask_first_n('A1aA1a', id)"), Row("XnxX1a"))
intercept[AnalysisException] {
checkAnswer(spark.range(1).selectExpr("mask_first_n('A1aA1a', id)"), Row("XnxX1a"))
}

checkAnswer(
df.selectExpr(
Expand All @@ -338,7 +342,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
Row(null, null, null, null, null)))
checkAnswer(sql("select mask_last_n(null)"), Row(null))
checkAnswer(sql("select mask_last_n('A1aA1a', null, null, null, null)"), Row("A1xXnx"))
checkAnswer(spark.range(1).selectExpr("mask_last_n('A1aA1a', id)"), Row("A1xXnx"))
intercept[AnalysisException] {
checkAnswer(spark.range(1).selectExpr("mask_last_n('A1aA1a', id)"), Row("A1xXnx"))
}

checkAnswer(
df.selectExpr(
Expand All @@ -353,7 +359,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
Row(null, null, null, null, null)))
checkAnswer(sql("select mask_show_first_n(null)"), Row(null))
checkAnswer(sql("select mask_show_first_n('A1aA1a', null, null, null, null)"), Row("A1aAnx"))
checkAnswer(spark.range(1).selectExpr("mask_show_first_n('A1aA1a', id)"), Row("A1aAnx"))
intercept[AnalysisException] {
checkAnswer(spark.range(1).selectExpr("mask_show_first_n('A1aA1a', id)"), Row("A1aAnx"))
}

checkAnswer(
df.selectExpr(
Expand All @@ -368,7 +376,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
Row(null, null, null, null, null)))
checkAnswer(sql("select mask_show_last_n(null)"), Row(null))
checkAnswer(sql("select mask_show_last_n('A1aA1a', null, null, null, null)"), Row("XnaA1a"))
checkAnswer(spark.range(1).selectExpr("mask_show_last_n('A1aA1a', id)"), Row("XnaA1a"))
intercept[AnalysisException] {
checkAnswer(spark.range(1).selectExpr("mask_show_last_n('A1aA1a', id)"), Row("XnaA1a"))
}

checkAnswer(sql("select mask_hash(null)"), Row(null))
}
Expand Down

0 comments on commit ca9caa0

Please sign in to comment.