Skip to content

Commit

Permalink
fix: Support Bool input for ONNX models (#2130)
Browse files Browse the repository at this point in the history
Signed-off-by: Jason Wang <[email protected]>
  • Loading branch information
memoryz authored Nov 14, 2023
1 parent 4c4fc8a commit 9195dee
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ object ONNXUtils {
inferredShapes.length.toLong +: inferredShapes.head
}

// scalastyle:off cyclomatic.complexity
private[onnx] def loadTensorBuffer(env: OrtEnvironment,
tensorInfo: TensorInfo,
batchedValues: Seq[_],
Expand All @@ -168,8 +169,15 @@ object ONNXUtils {
assertBufferElementsWritten(size, actualCount, shape)
buffer.rewind()
OnnxTensor.createTensor(env, buffer, shape)
case OnnxJavaType.BOOL =>
val buffer = ByteBuffer.allocateDirect(size)
val bool2byte: Boolean => Byte = b => if (b) 1.toByte else 0.toByte
val actualCount = writeNestedSeqToBuffer[Boolean](batchedValues, (bool2byte andThen buffer.put)(_))
assertBufferElementsWritten(size, actualCount, shape)
buffer.rewind()
OnnxTensor.createTensor(env, buffer, shape, OnnxJavaType.BOOL)
case OnnxJavaType.INT8 =>
val buffer = ByteBuffer.allocate(size)
val buffer = ByteBuffer.allocateDirect(size)
val actualCount = writeNestedSeqToBuffer[Byte](batchedValues, buffer.put(_))
assertBufferElementsWritten(size, actualCount, shape)
buffer.rewind()
Expand Down Expand Up @@ -197,7 +205,7 @@ object ONNXUtils {
OnnxTensor.createTensor(env, flattened, shape)
case other =>
throw new NotImplementedError(s"Tensor input type $other not supported. " +
s"Only FLOAT, DOUBLE, INT8, INT16, INT32, INT64, STRING types are supported.")
s"Only FLOAT, DOUBLE, BOOL, INT8, INT16, INT32, INT64, STRING types are supported.")
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ class ONNXModelSuite extends TestBase
new TestObject(onnxMNIST, testDfMNIST),
new TestObject(onnxAdultsIncome, testDfAdultsIncome),
new TestObject(onnxGH1902, testDfGH1902),
new TestObject(onnxGH1996, testDfGH1996),
new TestObject(onnxResNet50, testDfResNet50)
)

Expand Down Expand Up @@ -285,6 +286,17 @@ class ONNXModelSuite extends TestBase
.setMiniBatchSize(5000)
}

private lazy val onnxGH1996 = {
spark
val model = downloadModel("GH1996.onnx", baseUrl)
new ONNXModel()
.setModelLocation(model.getPath)
.setDeviceType("CPU")
.setFeedDict(Map("A" -> "i1", "B" -> "i2"))
.setFetchDict(Map("Output" -> "Y"))
.setMiniBatchSize(5)
}

private lazy val testDfGH1902 = {
val testDf = Seq(
(39L, " State-gov", 77516L, " Bachelors", 13L, " Never-married", " Adm-clerical",
Expand All @@ -296,6 +308,8 @@ class ONNXModelSuite extends TestBase
testDf
}

// Three boolean input pairs exposed as columns "i1" and "i2" for the GH1996 test.
private lazy val testDfGH1996 = {
  val boolPairs = Seq((true, true), (true, false), (false, false))
  boolPairs.toDF("i1", "i2")
}

test("ONNXModel can run transform for issue 1902") {
val Array(row1, row2) = onnxGH1902.transform(testDfGH1902)
.select("probability", "prediction")
Expand All @@ -310,6 +324,17 @@ class ONNXModelSuite extends TestBase
assert(row2._2 === 1.0)
}

test("ONNXModel can run transform on boolean type (GH1996)") {
  // Score the boolean fixture, then sort on every column so the collected
  // rows come back in a deterministic order for the assertions below.
  val scored = onnxGH1996.transform(testDfGH1996)
  val sorted = scored.orderBy(col("i1"), col("i2"), col("Output"))
  val Array(first, second, third) = sorted.as[(Boolean, Boolean, Boolean)].collect()

  // Each tuple is (i1, i2, Output).
  assert(first === (false, false, false))
  assert(second === (true, false, false))
  assert(third === (true, true, true))
}

test("ONNXModel can translate zipmap output properly") {
val Array(row1, row2) = onnxAdultsIncome.transform(testDfAdultsIncome)
.select("probability", "prediction")
Expand Down

0 comments on commit 9195dee

Please sign in to comment.