fix(onnx): type check problems (#3257)
Signed-off-by: Jiyong Huang <[email protected]>
ngjaying authored Sep 29, 2024
1 parent b560b08 commit 345d7ac
Showing 7 changed files with 84 additions and 107 deletions.
8 changes: 8 additions & 0 deletions docs/directory.json
@@ -157,6 +157,10 @@
"title": "使用 tensorflow_lite 原生插件调用 AI 模型",
"path": "guide/ai/tensorflow_lite"
},
{
"title": "使用 onnx 原生插件调用 AI 模型",
"path": "guide/ai/onnx"
},
{
"title": "使用外部函数调用 tensorflow_lite AI 模型",
"path": "guide/ai/tensorflow_lite_external_function_tutorial"
@@ -929,6 +933,10 @@
"title": "Running AI Model with TensorFlow Lite Function Plugin",
"path": "guide/ai/tensorflow_lite"
},
{
"title": "Running AI Model with Onnx Function Plugin",
"path": "guide/ai/onnx"
},
{
"title": "Running tensorflow_lite AI Model with External Function",
"path": "guide/ai/tensorflow_lite_external_function_tutorial"
110 changes: 54 additions & 56 deletions docs/en_US/guide/ai/onnx.md
@@ -10,7 +10,7 @@ By integrating eKuiper and ONNX, users can simply upload pre-built ONNX models a

## Prerequisites

### 模型下载
### Download Models

To run the ONNX interpreter, a trained model is needed. This tutorial will not cover training or model specifics; you can learn how to do this by checking the [ONNX tutorials](https://github.com/onnx/tutorials#converting-to-onnx-format).
We can either train a new model or choose an existing one.
@@ -36,9 +36,8 @@ Note that the model input data format must be a float array, so the data type mu

```shell
POST /streams
Host: 192.168.116.128:9081
Content-Type: application/json
Content-Length: 109

{
"sql": "CREATE STREAM onnxPubImg (data array(float)) WITH (DATASOURCE=\"onnxPubImg\", FORMAT=\"json\")"
}
@@ -65,47 +64,67 @@ Rest API rule creation to call the model:
{
"log": {},
"mqtt": {
"server": "127.0.0.1:1883",
"server": "tcp://127.0.0.1:1883",
"topic": "demoresult"
}
}
],
"type": "string"
]
}
```
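
Besides a REST client, the rule can also be created from code. The snippet below is a minimal, untested sketch: it assumes eKuiper's REST API on its default port 9081, and the rule id and SQL are placeholders that must be adapted to your own stream and model names.

```go
package main

import (
    "bytes"
    "fmt"
    "net/http"
)

func main() {
    // Hypothetical rule body; "ruleOnnx" and the SQL are placeholders and
    // must match your own stream and model.
    rule := []byte(`{
        "id": "ruleOnnx",
        "sql": "SELECT onnx(\"mnist\", data) FROM onnxPubImg",
        "actions": [
            {"log": {}},
            {"mqtt": {"server": "tcp://127.0.0.1:1883", "topic": "demoresult"}}
        ]
    }`)

    // POST /rules is the same endpoint used above; 9081 is eKuiper's default REST port.
    resp, err := http.Post("http://127.0.0.1:9081/rules", "application/json", bytes.NewReader(rule))
    if err != nil {
        panic(err)
    }
    defer resp.Body.Close()
    fmt.Println("rule creation status:", resp.Status)
}
```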

### Verifying Results

The results are shown in the image below, indicating the predicted probabilities of different digits in the input image.

![result query](../../resources/mqttx_mnist.png)
![result query](../../../resources/mqttx_mnist.png)

You can use a program like the one below to send images located in the ONNX directory.

```go
func TestSum(t *testing.T) {
const TOPIC = "sum_diff_pub"
func TestPic(t *testing.T) {
const TOPIC = "onnxPubImg"

images := []string{
"img.png",
// other images you need
}
opts := mqtt.NewClientOptions().AddBroker("tcp://localhost:1883")
client := mqtt.NewClient(opts)
if token := client.Connect(); token.Wait() && token.Error() != nil {
panic(token.Error())
}
payloadF32 := []float32{0.2, 0.3, 0.6, 0.9}
payloadUnMarshal := MqttPayLoadFloat32Slice{
Data: payloadF32,
}
payload, err := json.Marshal(payloadUnMarshal)
if err != nil {
fmt.Println(err)
return
}
for _, image := range images {
fmt.Println("Publishing " + image)
inputImage, err := NewProcessedImage(image, false)

if token := client.Publish(TOPIC, 2, true, payload); token.Wait() && token.Error() != nil {
fmt.Println(token.Error())
} else {
fmt.Println("Published ")
}
if err != nil {
fmt.Println(err)
continue
}
// payload, err := os.ReadFile(image)
payloadF32 := inputImage.GetNetworkInput()

data := make([]any, len(payloadF32))
for i := 0; i < len(data); i++ {
data[i] = payloadF32[i]
}
payloadUnMarshal := MqttPayLoadFloat32Slice{
Data: payloadF32,
}
payload, err := json.Marshal(payloadUnMarshal)
if err != nil {
fmt.Println(err)
continue
} else {
fmt.Println(string(payload))
}
if token := client.Publish(TOPIC, 2, true, payload); token.Wait() && token.Error() != nil {
fmt.Println(token.Error())
} else {
fmt.Println("Published " + image)
}
time.Sleep(1 * time.Second)
}
client.Disconnect(0)
}
```
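
To check the rule output without the MQTT X client shown in the screenshot, a small subscriber in the same style also works. This is an untested sketch; it assumes the broker at tcp://127.0.0.1:1883, the demoresult sink topic configured in the rule above, and the same imports as the publisher example (github.com/eclipse/paho.mqtt.golang as mqtt, plus fmt, testing and time).

```go
func TestSubscribeResult(t *testing.T) {
    opts := mqtt.NewClientOptions().AddBroker("tcp://127.0.0.1:1883")
    client := mqtt.NewClient(opts)
    if token := client.Connect(); token.Wait() && token.Error() != nil {
        t.Fatal(token.Error())
    }
    // Print every inference result the rule's MQTT sink publishes to demoresult.
    if token := client.Subscribe("demoresult", 1, func(_ mqtt.Client, msg mqtt.Message) {
        fmt.Println(string(msg.Payload()))
    }); token.Wait() && token.Error() != nil {
        t.Fatal(token.Error())
    }
    // Keep the subscription open while test images are being published.
    time.Sleep(30 * time.Second)
    client.Disconnect(0)
}
```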
@@ -129,10 +148,7 @@ The following image shows using the Rest API to call the model.

```shell
POST /rules
Host: 192.168.116.128:9081
User-Agent: Apifox/1.0.0 (https://apifox.com)
Content-Type: application/json
Content-Length: 319

{
"id": "ruleSum",
@@ -141,12 +157,11 @@ Content-Length: 319
{
"log": {},
"mqtt": {
"server": "127.0.0.1:1883",
"server": "tcp://127.0.0.1:1883",
"topic": "demoresult"
}
}
],
"type": "string"
]
}
```

@@ -167,35 +182,18 @@ The results are shown in the image below, with the inference returning:
]
```

![result query](../../resources/mqttx_sum_and_difference.png)
![result query](../../../resources/mqttx_sum_and_difference.png)

You can use a program like the one below to send test data.
Send test data like the following through an MQTT client.

```go
func TestSum(t *testing.T) {
const TOPIC = "sum_diff_pub"

opts := mqtt.NewClientOptions().AddBroker("tcp://localhost:1883")
client := mqtt.NewClient(opts)
if token := client.Connect(); token.Wait() && token.Error() != nil {
panic(token.Error())
}
payloadF32 := []float32{0.2, 0.3, 0.6, 0.9}
payloadUnMarshal := MqttPayLoadFloat32Slice{
Data: payloadF32,
}
payload, err := json.Marshal(payloadUnMarshal)
if err != nil {
fmt.Println(err)
return
}

if token := client.Publish(TOPIC, 2, true, payload); token.Wait() && token.Error() != nil {
fmt.Println(token.Error())
} else {
fmt.Println("Published ")
}
client.Disconnect(0)
```json
{
"data": [
0.2,
0.3,
0.6,
0.9
]
}
```
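
If you prefer publishing from code instead of an MQTT client GUI, a sketch like the following can send the same payload. It is untested and assumes the broker at tcp://127.0.0.1:1883, the same paho.mqtt.golang import (as mqtt) used in the examples above, and a stream DATASOURCE topic of sum_diff_pub; adjust these to match your deployment.

```go
func TestSum(t *testing.T) {
    // Assumed DATASOURCE topic of the stream feeding ruleSum; change it to match your stream.
    const topic = "sum_diff_pub"

    opts := mqtt.NewClientOptions().AddBroker("tcp://127.0.0.1:1883")
    client := mqtt.NewClient(opts)
    if token := client.Connect(); token.Wait() && token.Error() != nil {
        t.Fatal(token.Error())
    }
    // Payload identical to the JSON document shown above.
    payload := []byte(`{"data": [0.2, 0.3, 0.6, 0.9]}`)
    if token := client.Publish(topic, 1, false, payload); token.Wait() && token.Error() != nil {
        t.Fatal(token.Error())
    }
    client.Disconnect(250)
}
```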

File renamed without changes
56 changes: 16 additions & 40 deletions docs/zh_CN/guide/ai/onnx.md
@@ -33,9 +33,8 @@

```shell
POST /streams
Host: 192.168.116.128:9081
Content-Type: application/json
Content-Length: 109

{
"sql": "CREATE STREAM onnxPubImg (data array(float)) WITH (DATASOURCE=\"onnxPubImg\", FORMAT=\"json\")"
}
@@ -62,20 +61,19 @@ Rest API 创建规则以调用模型:
{
"log": {},
"mqtt": {
"server": "127.0.0.1:1883",
"server": "tcp://127.0.0.1:1883",
"topic": "demoresult"
}
}
],
"type": "string"
]
}
```

### 验证结果

结果如下图所示,输入图片之后,推导出图片中不同数字的输出可能性。

![结果查询](../../resources/mqttx_mnist.png)
![结果查询](../../../resources/mqttx_mnist.png)

你可以使用类似如下程序的方式来发送图片,图片位于ONNX目录下。

@@ -146,10 +144,6 @@ func TestPic(t *testing.T) {

```shell
POST /rules
Host: 192.168.116.128:9081
User-Agent: Apifox/1.0.0 (https://apifox.com)
Content-Type: application/json
Content-Length: 319

{
"id": "ruleSum",
@@ -158,12 +152,11 @@
{
"log": {},
"mqtt": {
"server": "127.0.0.1:1883",
"server": "tcp://127.0.0.1:1883",
"topic": "demoresult"
}
}
],
"type": "string"
]
}
```

@@ -175,35 +168,18 @@ Content-Length: 319
[{"onnx":[[1.9999883,0.60734314]]}]
```

![结果查询](../../resources/mqttx_sum_and_difference.png)

你可以使用类似如下程序的方式来发送测试数据。

```go
func TestSum(t *testing.T) {
const TOPIC = "sum_diff_pub"
![结果查询](../../../resources/mqttx_sum_and_difference.png)

opts := mqtt.NewClientOptions().AddBroker("tcp://localhost:1883")
client := mqtt.NewClient(opts)
if token := client.Connect(); token.Wait() && token.Error() != nil {
panic(token.Error())
}
payloadF32 := []float32{0.2, 0.3, 0.6, 0.9}
payloadUnMarshal := MqttPayLoadFloat32Slice{
Data: payloadF32,
}
payload, err := json.Marshal(payloadUnMarshal)
if err != nil {
fmt.Println(err)
return
}
你可以使用 MQTT 客户端发送测试数据。

if token := client.Publish(TOPIC, 2, true, payload); token.Wait() && token.Error() != nil {
fmt.Println(token.Error())
} else {
fmt.Println("Published ")
}
client.Disconnect(0)
```json
{
"data": [
0.2,
0.3,
0.6,
0.9
]
}
```

6 changes: 1 addition & 5 deletions extensions/functions/onnx/install.sh
@@ -17,20 +17,16 @@

dir=/usr/local/onnx

# 获取操作系统和架构信息
OS=$(uname -s)
ARCH=$(uname -m)

# 检查操作系统和架构并返回对应的库路径


cur=$(dirname "$0")
echo "Base path $cur"
if [ -d "$dir" ]; then
echo "SDK path $dir exists."
else
echo "Creating SDK path $dir"
mkdir -p $dir
mkdir -p $dir/lib
echo "Created SDK path $dir"
echo "Moving libs"

11 changes: 5 additions & 6 deletions extensions/functions/onnx/onnx.go
@@ -1,4 +1,4 @@
// Copyright 2021-2024 思无邪. All rights reserved.
// Copyright 2021-2024 EMQ Technologies Co., Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -46,6 +46,7 @@ func (f *OnnxFunc) Validate(args []interface{}) error {
}

func (f *OnnxFunc) Exec(ctx api.FunctionContext, args []any) (any, bool) {
ctx.GetLogger().Debugf("onnx args %[1]T(%[1]v)", args)
modelName, ok := args[0].(string)
if !ok {
return fmt.Errorf("onnx function first parameter must be a string, but got %[1]T(%[1]v)", args[0]), false
@@ -58,19 +59,18 @@ func (f *OnnxFunc) Exec(ctx api.FunctionContext, args []any) (any, bool) {
if len(args)-1 != inputCount {
return fmt.Errorf("onnx function requires %d tensors but got %d", inputCount, len(args)-1), false
}

ctx.GetLogger().Debugf("onnx function %s with %d tensors", modelName, inputCount)

var inputTensors []ort.ArbitraryTensor
// Set input tensors
for i := 1; i < len(args); i++ {

inputInfo := interpreter.inputInfo[i-1]
var arg []interface{}
switch v := args[i].(type) {
case []any: // only supports one dimensional arg. Even dim 0 must be an array of 1 element
arg = v
return fmt.Errorf("onnx function parameter %d must be a bytea or array of bytea, but got %[1]T(%[1]v)", i, v), false
default:
return fmt.Errorf("onnx function parameter %d must be a bytea or array of bytea, but got %[1]T(%[1]v)", v), false
}

notSupportedDataLen := -1
@@ -256,10 +256,9 @@ func (f *OnnxFunc) Exec(ctx api.FunctionContext, args []any) (any, bool) {
modelParaLen *= inputInfo.Dimensions[j]
}
ctx.GetLogger().Debugf("receive tensor %v, require %d length", arg, modelParaLen)
if modelParaLen != inputTensors[i].GetShape().FlattenedSize() {
if modelParaLen != inputTensors[i-1].GetShape().FlattenedSize() {
return fmt.Errorf("onnx function input tensor %d must have %d elements but got %d", i-1, modelParaLen, len(arg)), false
}

}
// todo :optimize: avoid creating output tensor every time
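
For readers following the fix above: args[0] carries the model name, so the i-th data argument maps to input tensor i-1, and its length must equal the flattened size of that tensor's shape. A standalone, hedged sketch of that validation logic, written without the onnxruntime bindings and assuming only the standard fmt package, could look like this; the helper name is hypothetical.

```go
// validateInputs mirrors the checks in OnnxFunc.Exec above: args[0] is the model
// name, every remaining argument is a one-dimensional array, and its length must
// equal the product of the matching input's dimensions. Note the i-1 offset.
func validateInputs(args []any, inputDims [][]int64) error {
    if len(args)-1 != len(inputDims) {
        return fmt.Errorf("onnx function requires %d tensors but got %d", len(inputDims), len(args)-1)
    }
    for i := 1; i < len(args); i++ {
        arg, ok := args[i].([]any)
        if !ok {
            return fmt.Errorf("onnx function parameter %d must be an array, but got %T", i, args[i])
        }
        // Compute the flattened size required by input tensor i-1.
        want := int64(1)
        for _, d := range inputDims[i-1] {
            want *= d
        }
        if int64(len(arg)) != want {
            return fmt.Errorf("onnx function input tensor %d must have %d elements but got %d", i-1, want, len(arg))
        }
    }
    return nil
}
```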

