Skip to content

Commit

Permalink
Inference Model string support of TFNet (intel-analytics#2452)
Browse files Browse the repository at this point in the history
Note: the InferenceModel predict batchSize parameter is removed.
  • Loading branch information
Song Jiaming authored Jun 15, 2020
1 parent ad67877 commit 410d7b9
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,6 @@ class FlinkInference(params: SerParams)

// Flink operator lifecycle hook: runs once per task before any element is
// processed. Initializes the per-task inference state.
// NOTE(review): this is scraped diff context — some of these lines may be the
// removed side of the hunk (the hunk header shows 8 lines shrinking to 6);
// confirm against the actual commit before relying on the exact body.
override def open(parameters: Configuration): Unit = {
inferenceCnt = 0
model = params.model
// println("in open method, ", model)
logger = Logger.getLogger(getClass)
pre = new PreProcessing(params)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,20 +62,31 @@ class PreProcessing(param: SerParams) {
val instance = Instances.fromArrow(byteBuffer)

val kvMap = instance.instances.flatMap(insMap => {
val oneInsMap = insMap.map(kv => {
if (kv._2.isInstanceOf[String]) {
val oneInsMap = insMap.map(kv =>
if (kv._1.contains("string")) {
(kv._1, decodeString(kv._2.asInstanceOf[String]))
}
else if (kv._2.isInstanceOf[String]) {
(kv._1, decodeImage(kv._2.asInstanceOf[String]))
} else {
(kv._1, decodeTensor(kv._2.asInstanceOf[(
ArrayBuffer[Int], ArrayBuffer[Float], ArrayBuffer[Int], ArrayBuffer[Int])]))
}
}).toList
}).toList
// Seq(T(oneInsMap.head, oneInsMap.tail: _*))
val arr = oneInsMap.map(x => x._2)
Seq(T.array(arr.toArray))
})
kvMap.head
}
/**
 * Decodes a pipe-delimited string into a 1-D string tensor.
 *
 * @param s elements joined with `|`, e.g. "abc|dff|aoa"
 * @return a Tensor[String] holding one element per token
 */
def decodeString(s: String): Tensor[String] = {
  val elements = s.split("\\|")
  val result = Tensor[String](elements.length)
  // Tensor indices are 1-based, so shift the array offset by one.
  elements.zipWithIndex.foreach { case (element, offset) =>
    result.setValue(offset + 1, element)
  }
  result
}
def decodeImage(s: String, idx: Int = 0): Tensor[Float] = {
byteBuffer = java.util.Base64.getDecoder.decode(s)
val mat = OpenCVMethod.fromImageBytes(byteBuffer, Imgcodecs.CV_LOAD_IMAGE_UNCHANGED)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
/*
* Copyright 2018 Analytics Zoo Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.intel.analytics.zoo.serving

import com.intel.analytics.bigdl.tensor.Tensor
import com.intel.analytics.zoo.serving.utils.{ClusterServingHelper, SerParams}
import org.scalatest.{FlatSpec, Matchers}

class InferenceSpec extends FlatSpec with Matchers {
  "TF String input" should "work" in {
    // Placeholder test: the end-to-end prediction path is commented out
    // because it depends on a local config file; only the pipe-splitting
    // step actually executes here.
    // val configPath = "/home/litchy/pro/analytics-zoo/config.yaml"
    val input = "abc|dff|aoa"
    val tokens = input.split("\\|")
    // val helper = new ClusterServingHelper(configPath)
    // helper.initArgs()
    // val param = new SerParams(helper)
    // val model = helper.loadInferenceModel()
    // val res = model.doPredict(t)
    // res
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -58,4 +58,12 @@ class PreProcessingSpec extends FlatSpec with Matchers {
val a = pre.decodeTensor(info)
a
}
"decode string tensor" should "work" in {
val pre = new PreProcessing(null)
val str = "abc|dff|aoa"
val tensor = pre.decodeString(str)
assert(tensor.valueAt(1) == "abc")
assert(tensor.valueAt(2) == "dff")
assert(tensor.valueAt(3) == "aoa")
}
}

0 comments on commit 410d7b9

Please sign in to comment.