Add support for isNaN and datetime related instructions in UDF compiler #593

Merged
merged 10 commits on Sep 5, 2020
20 changes: 13 additions & 7 deletions docs/compatibility.md
@@ -318,8 +318,8 @@ When translating UDFs to Catalyst expressions, the supported UDF functions are l
| | lhs >> rhs |
| | lhs >>> rhs |
| Conditional | if |
| | case |
| Math | abs(x) |
| | cos(x) |
| | acos(x) |
| | asin(x) |
@@ -358,8 +358,14 @@ When translating UDFs to Catalyst expressions, the supported UDF functions are l
| | x.contains(CharSequence s) |
| | x.indexOf(String str) |
| | x.indexOf(String str, int fromIndex) |
| | x.replaceAll(String regex, String replacement) |
| | x.split(String regex) |
| | x.split(String regex, int limit) |
| | x.getBytes() |
| | x.getBytes(String charsetName) |
| Date and Time | LocalDateTime.parse(x, DateTimeFormatter.ofPattern(pattern)).getYear |
| | LocalDateTime.parse(x, DateTimeFormatter.ofPattern(pattern)).getMonthValue |
| | LocalDateTime.parse(x, DateTimeFormatter.ofPattern(pattern)).getDayOfMonth |
| | LocalDateTime.parse(x, DateTimeFormatter.ofPattern(pattern)).getHour |
| | LocalDateTime.parse(x, DateTimeFormatter.ofPattern(pattern)).getMinute |
| | LocalDateTime.parse(x, DateTimeFormatter.ofPattern(pattern)).getSecond |
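
For illustration (a sketch of usage, not part of the shipped docs; the UDF names and the pattern are hypothetical), these entries correspond to Scala UDFs such as:

```scala
import java.time.LocalDateTime
import java.time.format.DateTimeFormatter
import org.apache.spark.sql.functions.udf

// Should translate to Year(ParseToTimestamp(...)) when the pattern is a
// timezone-agnostic string literal.
val yearOf = udf { (ts: String) =>
  LocalDateTime.parse(ts, DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss")).getYear
}

// Should translate to Catalyst's IsNaN expression.
val isNan = udf { (x: Double) => x.isNaN }
```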
210 changes: 176 additions & 34 deletions udf-compiler/src/main/scala/com/nvidia/spark/udf/Instruction.scala
@@ -73,6 +73,39 @@ private object Repr {
    var string: Expression = Literal.default(StringType)
  }

  case class DateTimeFormatter private (private[Repr] val pattern: Expression)
      extends CompilerInternal("java.time.format.DateTimeFormatter") {
    def invoke(methodName: String, args: List[Expression]): Expression = {
      methodName match {
        case _ =>
          throw new SparkException(s"Unsupported DateTimeFormatter op ${methodName}")
      }
    }
  }
  object DateTimeFormatter {
    private def apply(pattern: Expression): DateTimeFormatter = new DateTimeFormatter(pattern)
    def ofPattern(pattern: Expression): DateTimeFormatter = DateTimeFormatter(pattern)
  }

  case class LocalDateTime private (private val dateTime: Expression)
      extends CompilerInternal("java.time.LocalDateTime") {
    def invoke(methodName: String, args: List[Expression]): Expression = {
      methodName match {
        case "getYear" => Year(dateTime)
        case "getMonthValue" => Month(dateTime)
        case "getDayOfMonth" => DayOfMonth(dateTime)
        case "getHour" => Hour(dateTime)
        case "getMinute" => Minute(dateTime)
        case "getSecond" => Second(dateTime)
        case _ =>
          throw new SparkException(s"Unsupported LocalDateTime op ${methodName}")
      }
    }
  }
  object LocalDateTime {
    private def apply(dateTime: Expression): LocalDateTime = new LocalDateTime(dateTime)
    def parse(text: Expression, formatter: DateTimeFormatter): LocalDateTime = {
      LocalDateTime(new ParseToTimestamp(text, formatter.pattern))
    }
  }
}
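
Taken together, the two wrappers reduce a parse-then-accessor chain to plain Catalyst expressions. A minimal sketch of the resulting tree (my example values; `ParseToTimestamp` and `Year` are the Catalyst classes the diff already uses):

```scala
import org.apache.spark.sql.catalyst.expressions.{Literal, ParseToTimestamp, Year}

// LocalDateTime.parse(x, DateTimeFormatter.ofPattern("yyyy-MM-dd")).getYear
// is compiled down to:
val text = Literal("2020-09-05")
val parsed = new ParseToTimestamp(text, Literal("yyyy-MM-dd")) // Repr.LocalDateTime.parse
val year = Year(parsed)                                        // invoke("getYear")
```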

@@ -309,6 +342,15 @@ case class Instruction(opcode: Int, operand: Int, instructionStr: String) extend
              mathOp(lambdaReflection, method.getName, args) :: rest,
              cond,
              expr)
          } else if (declaringClassName.equals("scala.Predef$")) {
            State(locals,
              predefOp(lambdaReflection, method.getName, args) :: rest,
              cond,
              expr)
          } else if (declaringClassName.equals("java.lang.Double")) {
            State(locals, doubleOp(method.getName, args) :: rest, cond, expr)
          } else if (declaringClassName.equals("java.lang.Float")) {
            State(locals, floatOp(method.getName, args) :: rest, cond, expr)
          } else if (declaringClassName.equals("java.lang.String")) {
            State(locals, stringOp(method.getName, args) :: rest, cond, expr)
          } else if (declaringClassName.equals("java.lang.StringBuilder")) {
@@ -318,13 +360,35 @@ case class Instruction(opcode: Int, operand: Int, instructionStr: String) extend
            val retval = args.head.asInstanceOf[Repr.StringBuilder]
              .invoke(method.getName, args.tail)
            State(locals, retval :: rest, cond, expr)
          } else if (declaringClassName.equals("java.time.format.DateTimeFormatter")) {
            State(locals, dateTimeFormatterOp(method.getName, args) :: rest, cond, expr)
          } else if (declaringClassName.equals("java.time.LocalDateTime")) {
            State(locals, localDateTimeOp(method.getName, args) :: rest, cond, expr)
          } else {
            // Other functions
            throw new SparkException(s"Unsupported instruction: ${Opcode.INVOKEVIRTUAL} ${declaringClassName}")
          }
        }

  private def checkArgs(methodName: String,
      expectedTypes: List[DataType],
      args: List[Expression]): Unit = {
    if (args.length != expectedTypes.length) {
      throw new SparkException(
        s"${methodName} operation expects ${expectedTypes.length} " +
        s"argument(s), including an objref, but instead got ${args.length} " +
        s"argument(s)")
    }
    args.view.zip(expectedTypes.view).foreach { case (arg, expectedType) =>
      if (arg.dataType != expectedType) {
        throw new SparkException(s"${arg.dataType} argument found for " +
          s"${methodName} where " +
          s"${expectedType} argument is expected.")
      }
    }
  }

  private def mathOp(lambdaReflection: LambdaReflection,
      methodName: String, args: List[Expression]): Expression = {
    // Math unary functions
    if (args.length != 2) {
@@ -334,7 +398,7 @@
    }
    // Make sure that the objref is scala.math.package$.
    args.head match {
      case IntegerLiteral(index) =>
        if (!lambdaReflection.lookupField(index.asInstanceOf[Int])
            .getType.getName.equals("scala.math.package$")) {
          throw new SparkException("Unsupported math function objref: " + args.head)
@@ -364,56 +428,86 @@
    }
  }

  private def predefOp(lambdaReflection: LambdaReflection,
      methodName: String, args: List[Expression]): Expression = {
    // Make sure that the objref is scala.Predef$.
    args.head match {
      case IntegerLiteral(index) =>
        if (!lambdaReflection.lookupField(index.asInstanceOf[Int])
            .getType.getName.equals("scala.Predef$")) {
          throw new SparkException("Unsupported predef function objref: " + args.head)
        }
      case _ =>
        throw new SparkException("Unsupported predef function objref: " + args.head)
    }
    // Translate to Catalyst
    methodName match {
      case "double2Double" =>
        checkArgs(methodName, List(IntegerType, DoubleType), args)
        args.last
      case "float2Float" =>
        checkArgs(methodName, List(IntegerType, FloatType), args)
        args.last
      case _ => throw new SparkException("Unsupported predef function: " + methodName)
    }
  }

  private def doubleOp(methodName: String, args: List[Expression]): Expression = {
    methodName match {
      case "isNaN" =>
        checkArgs(methodName, List(DoubleType), args)
        IsNaN(args.head)
      case _ =>
        throw new SparkException(s"Unsupported Double function: " +
          s"Double.${methodName}")
    }
  }

  private def floatOp(methodName: String, args: List[Expression]): Expression = {
    methodName match {
      case "isNaN" =>
        checkArgs(methodName, List(FloatType), args)
        IsNaN(args.head)
      case _ =>
        throw new SparkException(s"Unsupported Float function: " +
          s"Float.${methodName}")
    }
  }
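
Why `isNaN` needs both the `Predef` hook above and these per-type hooks (my reading of typical scalac output; the diff does not spell this out): a primitive `isNaN` call is boxed before the virtual call.

```scala
// For a lambda such as:
val f = (x: Double) => x.isNaN
// scalac emits roughly:
//   invokestatic  scala/Predef$.double2Double  // handled by predefOp (pass-through)
//   invokevirtual java/lang/Double.isNaN       // handled by doubleOp -> IsNaN
```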

  private def stringOp(methodName: String, args: List[Expression]): Expression = {
    methodName match {
      case "concat" =>
        checkArgs(methodName, List(StringType, StringType), args)
        Concat(args)
      case "contains" =>
        checkArgs(methodName, List(StringType, StringType), args)
        Contains(args.head, args.last)
      case "endsWith" =>
        checkArgs(methodName, List(StringType, StringType), args)
        EndsWith(args.head, args.last)
      case "equals" =>
        checkArgs(methodName, List(StringType, StringType), args)
        Cast(EqualNullSafe(args.head, args.last), IntegerType)
      case "equalsIgnoreCase" =>
        checkArgs(methodName, List(StringType, StringType), args)
        Cast(EqualNullSafe(Upper(args.head), Upper(args.last)), IntegerType)
      case "isEmpty" =>
        checkArgs(methodName, List(StringType), args)
        Cast(EqualTo(Length(args.head), Literal(0)), IntegerType)
      case "length" =>
        checkArgs(methodName, List(StringType), args)
        Length(args.head)
      case "startsWith" =>
        checkArgs(methodName, List(StringType, StringType), args)
        StartsWith(args.head, args.last)
      case "toLowerCase" =>
        checkArgs(methodName, List(StringType), args)
        Lower(args.head)
      case "toUpperCase" =>
        checkArgs(methodName, List(StringType), args)
        Upper(args.head)
      case "trim" =>
        checkArgs(methodName, List(StringType), args)
        StringTrim(args.head)
      case "replace" =>
        if (args.length != 3) {
@@ -436,7 +530,7 @@ case class Instruction(opcode: Int, operand: Int, instructionStr: String) extend
s"${args(2).dataType}")
}
case "substring" =>
checkArgs(StringType :: List.fill(args.length - 1)(IntegerType))
checkArgs(methodName, StringType :: List.fill(args.length - 1)(IntegerType), args)
Substring(args(0),
Add(args(1), Literal(1)),
Subtract(if (args.length == 3) args(2) else Length(args(0)),
@@ -485,14 +579,14 @@ case class Instruction(opcode: Int, operand: Int, instructionStr: String) extend
s"argument(s)")
}
case "replaceAll" =>
checkArgs(List(StringType, StringType, StringType))
checkArgs(methodName, List(StringType, StringType, StringType), args)
RegExpReplace(args(0), args(1), args(2))
case "split" =>
if (args.length == 2) {
checkArgs(List(StringType, StringType))
checkArgs(methodName, List(StringType, StringType), args)
StringSplit(args(0), args(1), Literal(-1))
} else if (args.length == 3) {
checkArgs(List(StringType, StringType, IntegerType))
checkArgs(methodName, List(StringType, StringType, IntegerType), args)
StringSplit(args(0), args(1), args(2))
} else {
throw new SparkException(
@@ -502,10 +596,10 @@ case class Instruction(opcode: Int, operand: Int, instructionStr: String) extend
        }
      case "getBytes" =>
        if (args.length == 1) {
          checkArgs(methodName, List(StringType), args)
          Encode(args.head, Literal(Charset.defaultCharset.toString))
        } else if (args.length == 2) {
          checkArgs(methodName, List(StringType, StringType), args)
          Encode(args.head, args.last)
        } else {
          throw new SparkException(
@@ -518,6 +612,54 @@ case class Instruction(opcode: Int, operand: Int, instructionStr: String) extend
s"String.${methodName}")
}
}

  private def dateTimeFormatterOp(methodName: String, args: List[Expression]): Expression = {
    def checkPattern(pattern: String): Boolean = {
      pattern.foldLeft(false) {
        case (escapedText, '\'') => !escapedText
        case (false, c) if "VzOXxZ".exists(_ == c) =>
          // The pattern isn't timezone agnostic.
          throw new SparkException("Unsupported pattern: " +
            "only timezone agnostic patterns are supported")
        case (escapedText, _) => escapedText
      }
    }
    methodName match {
      case "ofPattern" =>
        checkArgs(methodName, List(StringType), args)
        // The pattern needs to be known at compile time as we need to check
        // whether the pattern is timezone agnostic. If it isn't, it needs
        // to fall back to JVM.
        args.head match {
          case StringLiteral(pattern) =>
            checkPattern(pattern)
            Repr.DateTimeFormatter.ofPattern(args.head)
          case _ =>
            // The pattern isn't known at compile time.
            throw new SparkException("Unsupported pattern: only string literals are supported")
        }
      case _ =>
        throw new SparkException(s"Unsupported function: " +
          s"DateTimeFormatter.${methodName}")
    }
  }
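
To restate the rule `checkPattern` enforces, a standalone sketch (the example patterns are mine, not from the PR): the letters `V z O X x Z` denote zones or offsets and are rejected unless they appear inside single-quoted literal text.

```scala
def timezoneAgnostic(pattern: String): Boolean = {
  var quoted = false
  pattern.forall { c =>
    if (c == '\'') { quoted = !quoted; true }
    else quoted || !"VzOXxZ".contains(c)
  }
}
assert(timezoneAgnostic("yyyy-MM-dd HH:mm")) // accepted
assert(timezoneAgnostic("yyyy-MM-dd'Z'"))    // quoted 'Z' is literal text
assert(!timezoneAgnostic("yyyy-MM-dd X"))    // X is a zone-offset letter
```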

  private def localDateTimeOp(methodName: String, args: List[Expression]): Expression = {
    methodName match {
      case "parse" =>
        checkArgs(methodName, List(StringType), List(args.head))
        if (!args.last.isInstanceOf[Repr.DateTimeFormatter]) {
          throw new SparkException("Unexpected argument for LocalDateTime.parse")
        }
        Repr.LocalDateTime.parse(args.head, args.last.asInstanceOf[Repr.DateTimeFormatter])
      case "getYear" | "getMonthValue" | "getDayOfMonth" |
           "getHour" | "getMinute" | "getSecond" =>
        args.head.asInstanceOf[Repr.LocalDateTime].invoke(methodName, args.tail)
      case _ =>
        throw new SparkException(s"Unsupported function: " +
          s"LocalDateTime.${methodName}")
    }
  }
}
