From 2739898271ebc9fea789db9391aba3f8c6037bbe Mon Sep 17 00:00:00 2001 From: Yang Wang Date: Wed, 3 Jun 2020 09:18:38 +0800 Subject: [PATCH] Support bool value in tf dataset (#2401) * support bool value in tf dataset * fix style --- .../scala/com/intel/analytics/bigdl/common/TFUtils.scala | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/spark/dl/src/main/scala/com/intel/analytics/bigdl/common/TFUtils.scala b/spark/dl/src/main/scala/com/intel/analytics/bigdl/common/TFUtils.scala index 8bde8aa6e24..d7f23f6f0c0 100644 --- a/spark/dl/src/main/scala/com/intel/analytics/bigdl/common/TFUtils.scala +++ b/spark/dl/src/main/scala/com/intel/analytics/bigdl/common/TFUtils.scala @@ -96,7 +96,7 @@ object TFUtils { val dataType = t.dataType() val numericDataTypes = Set(DataType.FLOAT, - DataType.UINT8, DataType.INT32, DataType.INT64, DataType.DOUBLE) + DataType.UINT8, DataType.INT32, DataType.INT64, DataType.DOUBLE, DataType.BOOL) if (dataType == DataType.STRING) { val outputTensor = output.asInstanceOf[Tensor[Array[Byte]]] @@ -148,6 +148,13 @@ object TFUtils { val buffer = DoubleBuffer.wrap(arr) t.writeTo(buffer) double2float(arr, outputTensor.storage().array(), outputTensor.storageOffset() - 1) + case DataType.BOOL => + val outputTensor = output.asInstanceOf[Tensor[Float]] + val arr = new Array[Byte](t.numBytes()) + assert(t.numBytes() == shape.product, "sanity check") + val buffer = ByteBuffer.wrap(arr) + t.writeTo(buffer) + byte2float(arr, outputTensor.storage().array(), outputTensor.storageOffset() - 1) } } else {