From 345d7ac7c977a7310750b4fde42abdda9a3b3509 Mon Sep 17 00:00:00 2001 From: ngjaying Date: Sun, 29 Sep 2024 13:52:32 +0800 Subject: [PATCH] fix(onnx): type check problems (#3257) Signed-off-by: Jiyong Huang --- docs/directory.json | 8 ++ docs/en_US/guide/ai/onnx.md | 110 +++++++++--------- docs/{zh_CN => }/resources/mqttx_mnist.png | Bin .../resources/mqttx_sum_and_difference.png | Bin docs/zh_CN/guide/ai/onnx.md | 56 +++------ extensions/functions/onnx/install.sh | 6 +- extensions/functions/onnx/onnx.go | 11 +- 7 files changed, 84 insertions(+), 107 deletions(-) rename docs/{zh_CN => }/resources/mqttx_mnist.png (100%) rename docs/{zh_CN => }/resources/mqttx_sum_and_difference.png (100%) diff --git a/docs/directory.json b/docs/directory.json index f8a76c7ced..3d08784d56 100644 --- a/docs/directory.json +++ b/docs/directory.json @@ -157,6 +157,10 @@ "title": "使用 tensorflow_lite 原生插件调用 AI 模型", "path": "guide/ai/tensorflow_lite" }, + { + "title": "使用 onnx 原生插件调用 AI 模型", + "path": "guide/ai/onnx" + }, { "title": "使用外部函数调用 tensorflow_lite AI 模型", "path": "guide/ai/tensorflow_lite_external_function_tutorial" @@ -929,6 +933,10 @@ "title": "Running AI Model with TensorFlow Lite Function Plugin", "path": "guide/ai/tensorflow_lite" }, + { + "title": "Running AI Model with Onnx Function Plugin", + "path": "guide/ai/onnx" + }, { "title": "Running tensorflow_lite AI Model with External Function", "path": "guide/ai/tensorflow_lite_external_function_tutorial" diff --git a/docs/en_US/guide/ai/onnx.md b/docs/en_US/guide/ai/onnx.md index c1d5581b06..581e9ad924 100644 --- a/docs/en_US/guide/ai/onnx.md +++ b/docs/en_US/guide/ai/onnx.md @@ -10,7 +10,7 @@ By integrating eKuiper and ONNX, users can simply upload pre-built ONNX models a ## Prerequisites -### 模型下载 +### Download Models To run the ONNX interpreter, a trained model is needed. This tutorial will not cover training or model specifics; you can learn how to do this by checking the [ONNX tutorials](https://github.com/onnx/tutorials#converting-to-onnx-format). We can either train a new model or choose an existing one. @@ -36,9 +36,8 @@ Note that the model input data format must be a float array, so the data type mu ```shell POST /streams -Host: 192.168.116.128:9081 Content-Type: application/json -Content-Length: 109 + { "sql": "CREATE STREAM onnxPubImg (data array(float)) WITH (DATASOURCE=\"onnxPubImg\", FORMAT=\"json\")" } @@ -65,12 +64,11 @@ Rest API rule creation to call the model: { "log": {}, "mqtt": { - "server": "127.0.0.1:1883", + "server": "tcp://127.0.0.1:1883", "topic": "demoresult" } } - ], - "type": "string" + ] } ``` @@ -78,34 +76,55 @@ Rest API rule creation to call the model: The results are shown in the image below, indicating the predicted probabilities of different digits in the input image. -![result query](../../resources/mqttx_mnist.png) +![result query](../../../resources/mqttx_mnist.png) You can use a program like the one below to send images located in the ONNX directory. ```go -func TestSum(t *testing.T) { - const TOPIC = "sum_diff_pub" +func TestPic(t *testing.T) { +const TOPIC = "onnxPubImg" +images := []string{ +"img.png", +// 其他你需要的图像 +} opts := mqtt.NewClientOptions().AddBroker("tcp://localhost:1883") client := mqtt.NewClient(opts) if token := client.Connect(); token.Wait() && token.Error() != nil { panic(token.Error()) } - payloadF32 := []float32{0.2, 0.3, 0.6, 0.9} - payloadUnMarshal := MqttPayLoadFloat32Slice{ - Data: payloadF32, - } - payload, err := json.Marshal(payloadUnMarshal) - if err != nil { - fmt.Println(err) - return - } +for _, image := range images { +fmt.Println("Publishing " + image) +inputImage, err := NewProcessedImage(image, false) - if token := client.Publish(TOPIC, 2, true, payload); token.Wait() && token.Error() != nil { - fmt.Println(token.Error()) - } else { - fmt.Println("Published ") - } +if err != nil { +fmt.Println(err) +continue +} +// payload, err := os.ReadFile(image) +payloadF32 := inputImage.GetNetworkInput() + +data := make([]any, len(payloadF32)) +for i := 0; i < len(data); i++ { +data[i] = payloadF32[i] +} +payloadUnMarshal := MqttPayLoadFloat32Slice{ +Data: payloadF32, +} +payload, err := json.Marshal(payloadUnMarshal) +if err != nil { +fmt.Println(err) +continue +} else { +fmt.Println(string(payload)) +} +if token := client.Publish(TOPIC, 2, true, payload); token.Wait() && token.Error() != nil { +fmt.Println(token.Error()) +} else { +fmt.Println("Published " + image) +} +time.Sleep(1 * time.Second) +} client.Disconnect(0) } ``` @@ -129,10 +148,7 @@ The following image shows using the Rest API to call the model. ```shell POST /rules -Host: 192.168.116.128:9081 -User-Agent: Apifox/1.0.0 (https://apifox.com) Content-Type: application/json -Content-Length: 319 { "id": "ruleSum", @@ -141,12 +157,11 @@ Content-Length: 319 { "log": {}, "mqtt": { - "server": "127.0.0.1:1883", + "server": "tcp://127.0.0.1:1883", "topic": "demoresult" } } - ], - "type": "string" + ] } ``` @@ -167,35 +182,18 @@ The results are shown in the image below, with the inference returning: ] ``` -![result query](../../resources/mqttx_sum_and_difference.png) +![result query](../../../resources/mqttx_sum_and_difference.png) -You can use a program like the one below to send test data. +Send test data like below through MQTT client. -```go -func TestSum(t *testing.T) { - const TOPIC = "sum_diff_pub" - - opts := mqtt.NewClientOptions().AddBroker("tcp://localhost:1883") - client := mqtt.NewClient(opts) - if token := client.Connect(); token.Wait() && token.Error() != nil { - panic(token.Error()) - } - payloadF32 := []float32{0.2, 0.3, 0.6, 0.9} - payloadUnMarshal := MqttPayLoadFloat32Slice{ - Data: payloadF32, - } - payload, err := json.Marshal(payloadUnMarshal) - if err != nil { - fmt.Println(err) - return - } - - if token := client.Publish(TOPIC, 2, true, payload); token.Wait() && token.Error() != nil { - fmt.Println(token.Error()) - } else { - fmt.Println("Published ") - } - client.Disconnect(0) +```json +{ + "data": [ + 0.2, + 0.3, + 0.6, + 0.9 + ] } ``` diff --git a/docs/zh_CN/resources/mqttx_mnist.png b/docs/resources/mqttx_mnist.png similarity index 100% rename from docs/zh_CN/resources/mqttx_mnist.png rename to docs/resources/mqttx_mnist.png diff --git a/docs/zh_CN/resources/mqttx_sum_and_difference.png b/docs/resources/mqttx_sum_and_difference.png similarity index 100% rename from docs/zh_CN/resources/mqttx_sum_and_difference.png rename to docs/resources/mqttx_sum_and_difference.png diff --git a/docs/zh_CN/guide/ai/onnx.md b/docs/zh_CN/guide/ai/onnx.md index ff99ed1df5..59700880d4 100644 --- a/docs/zh_CN/guide/ai/onnx.md +++ b/docs/zh_CN/guide/ai/onnx.md @@ -33,9 +33,8 @@ ```shell POST /streams -Host: 192.168.116.128:9081 Content-Type: application/json -Content-Length: 109 + { "sql": "CREATE STREAM onnxPubImg (data array(float)) WITH (DATASOURCE=\"onnxPubImg\", FORMAT=\"json\")" } @@ -62,12 +61,11 @@ Rest API 创建规则以调用模型: { "log": {}, "mqtt": { - "server": "127.0.0.1:1883", + "server": "tcp://127.0.0.1:1883", "topic": "demoresult" } } - ], - "type": "string" + ] } ``` @@ -75,7 +73,7 @@ Rest API 创建规则以调用模型: 结果如下图所示,输入图片之后,推导出图片中不同数字的输出可能性。 -![结果查询](../../resources/mqttx_mnist.png) +![结果查询](../../../resources/mqttx_mnist.png) 你可以使用类似如下程序的方式来发送图片,图片位于ONNX目录下。 @@ -146,10 +144,6 @@ func TestPic(t *testing.T) { ```shell POST /rules -Host: 192.168.116.128:9081 -User-Agent: Apifox/1.0.0 (https://apifox.com) -Content-Type: application/json -Content-Length: 319 { "id": "ruleSum", @@ -158,12 +152,11 @@ Content-Length: 319 { "log": {}, "mqtt": { - "server": "127.0.0.1:1883", + "server": "tcp://127.0.0.1:1883", "topic": "demoresult" } } - ], - "type": "string" + ] } ``` @@ -175,35 +168,18 @@ Content-Length: 319 [{"onnx":[[1.9999883,0.60734314]]}] ``` -![结果查询](../../resources/mqttx_sum_and_difference.png) - -你可以使用类似如下程序的方式来发送测试数据。 - -```go -func TestSum(t *testing.T) { - const TOPIC = "sum_diff_pub" +![结果查询](../../../resources/mqttx_sum_and_difference.png) - opts := mqtt.NewClientOptions().AddBroker("tcp://localhost:1883") - client := mqtt.NewClient(opts) - if token := client.Connect(); token.Wait() && token.Error() != nil { - panic(token.Error()) - } - payloadF32 := []float32{0.2, 0.3, 0.6, 0.9} - payloadUnMarshal := MqttPayLoadFloat32Slice{ - Data: payloadF32, - } - payload, err := json.Marshal(payloadUnMarshal) - if err != nil { - fmt.Println(err) - return - } +你可以使用 MQTT 客户端发送测试数据。 - if token := client.Publish(TOPIC, 2, true, payload); token.Wait() && token.Error() != nil { - fmt.Println(token.Error()) - } else { - fmt.Println("Published ") - } - client.Disconnect(0) +```json +{ + "data": [ + 0.2, + 0.3, + 0.6, + 0.9 + ] } ``` diff --git a/extensions/functions/onnx/install.sh b/extensions/functions/onnx/install.sh index 06d8980e13..bc6d4af17a 100644 --- a/extensions/functions/onnx/install.sh +++ b/extensions/functions/onnx/install.sh @@ -17,20 +17,16 @@ dir=/usr/local/onnx -# 获取操作系统和架构信息 OS=$(uname -s) ARCH=$(uname -m) -# 检查操作系统和架构并返回对应的库路径 - - cur=$(dirname "$0") echo "Base path $cur" if [ -d "$dir" ]; then echo "SDK path $dir exists." else echo "Creating SDK path $dir" - mkdir -p $dir + mkdir -p $dir/lib echo "Created SDK path $dir" echo "Moving libs" diff --git a/extensions/functions/onnx/onnx.go b/extensions/functions/onnx/onnx.go index 234fa10ad1..eff77430c8 100644 --- a/extensions/functions/onnx/onnx.go +++ b/extensions/functions/onnx/onnx.go @@ -1,4 +1,4 @@ -// Copyright 2021-2024 思无邪. All rights reserved. +// Copyright 2021-2024 EMQ Technologies Co., Ltd. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -46,6 +46,7 @@ func (f *OnnxFunc) Validate(args []interface{}) error { } func (f *OnnxFunc) Exec(ctx api.FunctionContext, args []any) (any, bool) { + ctx.GetLogger().Debugf("onnx args %[1]T(%[1]v)", args) modelName, ok := args[0].(string) if !ok { return fmt.Errorf("onnx function first parameter must be a string, but got %[1]T(%[1]v)", args[0]), false @@ -58,19 +59,18 @@ func (f *OnnxFunc) Exec(ctx api.FunctionContext, args []any) (any, bool) { if len(args)-1 != inputCount { return fmt.Errorf("onnx function requires %d tensors but got %d", inputCount, len(args)-1), false } - ctx.GetLogger().Debugf("onnx function %s with %d tensors", modelName, inputCount) var inputTensors []ort.ArbitraryTensor // Set input tensors for i := 1; i < len(args); i++ { - inputInfo := interpreter.inputInfo[i-1] var arg []interface{} switch v := args[i].(type) { case []any: // only supports one dimensional arg. Even dim 0 must be an array of 1 element arg = v - return fmt.Errorf("onnx function parameter %d must be a bytea or array of bytea, but got %[1]T(%[1]v)", i, v), false + default: + return fmt.Errorf("onnx function parameter %d must be a bytea or array of bytea, but got %[1]T(%[1]v)", v), false } notSupportedDataLen := -1 @@ -256,10 +256,9 @@ func (f *OnnxFunc) Exec(ctx api.FunctionContext, args []any) (any, bool) { modelParaLen *= inputInfo.Dimensions[j] } ctx.GetLogger().Debugf("receive tensor %v, require %d length", arg, modelParaLen) - if modelParaLen != inputTensors[i].GetShape().FlattenedSize() { + if modelParaLen != inputTensors[i-1].GetShape().FlattenedSize() { return fmt.Errorf("onnx function input tensor %d must have %d elements but got %d", i-1, modelParaLen, len(arg)), false } - } // todo :optimize: avoid creating output tensor every time