From 9195deef8b3c260983934010bc7f60efa93e6817 Mon Sep 17 00:00:00 2001 From: Jason Wang Date: Tue, 14 Nov 2023 14:25:34 -0800 Subject: [PATCH] fix: Support to Bool input for Onnx models (#2130) Signed-off-by: Jason Wang --- .../azure/synapse/ml/onnx/ONNXUtils.scala | 12 +++++++-- .../synapse/ml/onnx/ONNXModelSuite.scala | 25 +++++++++++++++++++ 2 files changed, 35 insertions(+), 2 deletions(-) diff --git a/deep-learning/src/main/scala/com/microsoft/azure/synapse/ml/onnx/ONNXUtils.scala b/deep-learning/src/main/scala/com/microsoft/azure/synapse/ml/onnx/ONNXUtils.scala index 9a98a5fb3f..4405e3aa40 100644 --- a/deep-learning/src/main/scala/com/microsoft/azure/synapse/ml/onnx/ONNXUtils.scala +++ b/deep-learning/src/main/scala/com/microsoft/azure/synapse/ml/onnx/ONNXUtils.scala @@ -150,6 +150,7 @@ object ONNXUtils { inferredShapes.length.toLong +: inferredShapes.head } + // scalastyle:off cyclomatic.complexity private[onnx] def loadTensorBuffer(env: OrtEnvironment, tensorInfo: TensorInfo, batchedValues: Seq[_], @@ -168,8 +169,15 @@ object ONNXUtils { assertBufferElementsWritten(size, actualCount, shape) buffer.rewind() OnnxTensor.createTensor(env, buffer, shape) + case OnnxJavaType.BOOL => + val buffer = ByteBuffer.allocateDirect(size) + val bool2byte: Boolean => Byte = b => if (b) 1.toByte else 0.toByte + val actualCount = writeNestedSeqToBuffer[Boolean](batchedValues, (bool2byte andThen buffer.put)(_)) + assertBufferElementsWritten(size, actualCount, shape) + buffer.rewind() + OnnxTensor.createTensor(env, buffer, shape, OnnxJavaType.BOOL) case OnnxJavaType.INT8 => - val buffer = ByteBuffer.allocate(size) + val buffer = ByteBuffer.allocateDirect(size) val actualCount = writeNestedSeqToBuffer[Byte](batchedValues, buffer.put(_)) assertBufferElementsWritten(size, actualCount, shape) buffer.rewind() @@ -197,7 +205,7 @@ object ONNXUtils { OnnxTensor.createTensor(env, flattened, shape) case other => throw new NotImplementedError(s"Tensor input type $other not supported. " + - s"Only FLOAT, DOUBLE, INT8, INT16, INT32, INT64, STRING types are supported.") + s"Only FLOAT, DOUBLE, BOOL, INT8, INT16, INT32, INT64, STRING types are supported.") } } diff --git a/deep-learning/src/test/scala/com/microsoft/azure/synapse/ml/onnx/ONNXModelSuite.scala b/deep-learning/src/test/scala/com/microsoft/azure/synapse/ml/onnx/ONNXModelSuite.scala index b228938f17..5cd28f5719 100644 --- a/deep-learning/src/test/scala/com/microsoft/azure/synapse/ml/onnx/ONNXModelSuite.scala +++ b/deep-learning/src/test/scala/com/microsoft/azure/synapse/ml/onnx/ONNXModelSuite.scala @@ -38,6 +38,7 @@ class ONNXModelSuite extends TestBase new TestObject(onnxMNIST, testDfMNIST), new TestObject(onnxAdultsIncome, testDfAdultsIncome), new TestObject(onnxGH1902, testDfGH1902), + new TestObject(onnxGH1996, testDfGH1996), new TestObject(onnxResNet50, testDfResNet50) ) @@ -285,6 +286,17 @@ class ONNXModelSuite extends TestBase .setMiniBatchSize(5000) } + private lazy val onnxGH1996 = { + spark + val model = downloadModel("GH1996.onnx", baseUrl) + new ONNXModel() + .setModelLocation(model.getPath) + .setDeviceType("CPU") + .setFeedDict(Map("A" -> "i1", "B" -> "i2")) + .setFetchDict(Map("Output" -> "Y")) + .setMiniBatchSize(5) + } + private lazy val testDfGH1902 = { val testDf = Seq( (39L, " State-gov", 77516L, " Bachelors", 13L, " Never-married", " Adm-clerical", @@ -296,6 +308,8 @@ class ONNXModelSuite extends TestBase testDf } + private lazy val testDfGH1996 = Seq((true, true), (true, false), (false, false)).toDF("i1", "i2") + test("ONNXModel can run transform for issue 1902") { val Array(row1, row2) = onnxGH1902.transform(testDfGH1902) .select("probability", "prediction") @@ -310,6 +324,17 @@ class ONNXModelSuite extends TestBase assert(row2._2 === 1.0) } + test("ONNXModel can run transform on boolean type (GH1996)") { + val Array(row1, row2, row3) = onnxGH1996.transform(testDfGH1996) + .orderBy(col("i1"), col("i2"), col("Output")) + .as[(Boolean, Boolean, Boolean)] + .collect() + + assert(row1 === (false, false, false)) + assert(row2 === (true, false, false)) + assert(row3 === (true, true, true)) + } + test("ONNXModel can translate zipmap output properly") { val Array(row1, row2) = onnxAdultsIncome.transform(testDfAdultsIncome) .select("probability", "prediction")