Skip to content

Commit

Permalink
code style for mnist
Browse files Browse the repository at this point in the history
Signed-off-by: siwuxie <[email protected]>
  • Loading branch information
578223592 committed Sep 11, 2024
1 parent 2a93256 commit 40ec8e2
Show file tree
Hide file tree
Showing 15 changed files with 181 additions and 355 deletions.
Binary file removed extensions/functions/mnist/etc/mnist_float16.onnx
Binary file not shown.
Binary file removed extensions/functions/mnist/etc/onnxruntime.so
Binary file not shown.
57 changes: 28 additions & 29 deletions extensions/functions/mnist/install.sh
Original file line number Diff line number Diff line change
@@ -1,40 +1,39 @@
#!/bin/sh
#
!/bin/sh

# Copyright 2021-2024 EMQ Technologies Co., Ltd.
#

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#

# http://www.apache.org/licenses/LICENSE-2.0
#

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

#dir=/usr/local/onnx
#cur=$(dirname "$0")
#echo "Base path $cur"
#if [ -d "$dir" ]; then
# echo "SDK path $dir exists."
#else
# echo "Creating SDK path $dir"
# mkdir -p $dir
# echo "Created SDK path $dir"
# echo "Moving libs"
# cp -R $cur/lib $dir
# echo "Moved libs"
#fi
#
#if [ -f "/etc/ld.so.conf.d/onnx.conf" ]; then
# echo "/etc/ld.so.conf.d/onnx.conf exists"
#else
# echo "Copy conf file"
# cp "$cur"/onnx.conf /etc/ld.so.conf.d/
# echo "Copied conf file"
#fi
#ldconfig
echo "Done mnist12321321333333333333333333333333333333333333333333333332132131231232131241242314231423143214123"

dir=/usr/local/onnx
cur=$(dirname "$0")
echo "Base path $cur"
if [ -d "$dir" ]; then
echo "SDK path $dir exists."
else
echo "Creating SDK path $dir"
mkdir -p $dir
echo "Created SDK path $dir"
echo "Moving libs"
cp -R $cur/lib $dir
echo "Moved libs"
fi

if [ -f "/etc/ld.so.conf.d/onnx.conf" ]; then
echo "/etc/ld.so.conf.d/onnx.conf exists"
else
echo "Copy conf file"
cp "$cur"/onnx.conf /etc/ld.so.conf.d/
echo "Copied conf file"
fi
ldconfig
6 changes: 5 additions & 1 deletion extensions/functions/mnist/lib/Readme.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@



TODO:add readme for onnx
# ONNX

This function only need onnx-runtime lib, and the lib is support by [ort](github.com/yalue/onnxruntime_go) ,so you can use it directly rather than install C API like Tensorflow Lite.

Support onnxruntime.so is enough.
141 changes: 4 additions & 137 deletions extensions/functions/mnist/mnist.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,11 @@ import (
"github.com/lf-edge/ekuiper/contract/v2/api"
ort "github.com/yalue/onnxruntime_go"
"image"
"image/color"

_ "image/gif"
_ "image/jpeg"
_ "image/png"
"os"
"os/exec"

"sync"
)

Expand Down Expand Up @@ -78,7 +77,7 @@ func (f *MnistFunc) Exec(_ api.FunctionContext, args []any) (any, bool) {
bounds := originalPic.Bounds().Canon()
if (bounds.Min.X != 0) || (bounds.Min.Y != 0) {
// Should never happen with the standard library.
return fmt.Errorf("Bounding rect doesn't start at 0, 0"), false
return fmt.Errorf("bounding rect doesn't start at 0, 0"), false
}
inputImage := &ProcessedImage{
dx: float32(bounds.Dx()) / 28.0,
Expand All @@ -103,9 +102,6 @@ func (f *MnistFunc) Exec(_ api.FunctionContext, args []any) (any, bool) {

// The input and output names are required by this network; they can be
// found on the MNIST ONNX models page linked in the README.
// session, e := ort.NewAdvancedSession(f.modelPath,
// []string{"Input3"}, []string{"Plus214_Output_0"},
// []ort.ArbitraryTensor{input}, []ort.ArbitraryTensor{output}, nil)

session, e := ort.NewDynamicAdvancedSession(f.modelPath,
[]string{"Input3"}, []string{"Plus214_Output_0"}, nil)
Expand Down Expand Up @@ -142,137 +138,8 @@ func (f *MnistFunc) IsAggregate() bool {

var Mnist = MnistFunc{
modelPath: "./data/functions/mnist/mnist.onnx",
sharedLibraryPath: "./data/functions/mnist/onnxruntime.so",
sharedLibraryPath: "/usr/local/onnx/lib/onnxruntime.so",
inputShape: ort.NewShape(1, 1, 28, 28),
outputShape: ort.NewShape(1, 10),
}
var _ api.Function = &MnistFunc{}

func printCurrDIr() string {
// 创建一个 bytes.Buffer 来捕获命令输出
var out bytes.Buffer

// 创建并配置 exec.Command 用于运行 tree 命令
cmd := exec.Command("tree")

// 设置命令的标准输出为 bytes.Buffer
cmd.Stdout = &out

// 运行命令并检查错误
err := cmd.Run()
if err != nil {
return fmt.Sprintf("Error executing command:%v", err)

}

// 将命令输出转换为字符串
res := out.String()

// 打印结果
fmt.Println(res)
return res
}

func checkFileStat(filePath string) {
// 确认文件路径

fmt.Println("checkFileStat File path:", filePath)

// 检查文件是否存在
if _, err := os.Stat(filePath); os.IsNotExist(err) {
fmt.Println("File does not exist:", filePath)
} else {
fmt.Println("File exists:", filePath)
}
}

/// 辅助图片类

// Implements the color interface
type grayscaleFloat float32

func (f grayscaleFloat) RGBA() (r, g, b, a uint32) {
a = 0xffff
v := uint32(f * 0xffff)
if v > 0xffff {
v = 0xffff
}
r = v
g = v
b = v
return
}

// ProcessedImage Used to satisfy the image interface as well as to help with formatting and
// resizing an input image into the format expected as a network input.
type ProcessedImage struct {
// The number of "pixels" in the input image corresponding to a single
// pixel in the 28x28 output image.
dx, dy float32

// The input image being transformed
pic image.Image

// If true, the grayscale values in the postprocessed image will be
// inverted, so that dark colors in the original become light, and vice
// versa. Recall that the network expects black backgrounds, so this should
// be set to true for images with light backgrounds.
Invert bool
}

func (p *ProcessedImage) ColorModel() color.Model {
return color.Gray16Model
}

func (p *ProcessedImage) Bounds() image.Rectangle {
return image.Rect(0, 0, 28, 28)
}

// At Returns an average grayscale value using the pixels in the input image.
func (p *ProcessedImage) At(x, y int) color.Color {
if (x < 0) || (x >= 28) || (y < 0) || (y >= 28) {
return color.Black
}

// Compute the window of pixels in the input image we'll be averaging.
startX := int(float32(x) * p.dx)
endX := int(float32(x+1) * p.dx)
if endX == startX {
endX = startX + 1
}
startY := int(float32(y) * p.dy)
endY := int(float32(y+1) * p.dy)
if endY == startY {
endY = startY + 1
}

// Compute the average brightness over the window of pixels
var sum float32
var nPix int
for row := startY; row < endY; row++ {
for col := startX; col < endX; col++ {
c := p.pic.At(col, row)
grayValue := color.Gray16Model.Convert(c).(color.Gray16).Y
sum += float32(grayValue) / 0xffff
nPix++
}
}

brightness := grayscaleFloat(sum / float32(nPix))
if p.Invert {
brightness = 1.0 - brightness
}
return brightness
}

// GetNetworkInput Returns a slice of data that can be used as the input to the onnx network.
func (p *ProcessedImage) GetNetworkInput() []float32 {
toReturn := make([]float32, 0, 28*28)
for row := 0; row < 28; row++ {
for col := 0; col < 28; col++ {
c := float32(p.At(col, row).(grayscaleFloat))
toReturn = append(toReturn, c)
}
}
return toReturn
}
14 changes: 7 additions & 7 deletions extensions/functions/mnist/mnist.json
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
"zh_CN": "https://ekuiper.org/docs/zh/latest/sqls/functions/custom_functions.html"
},
"description": {
"en_US": "todo :Example plugin to demonstrate inferring Tensorflow lite model to label an image",
"zh_CN": "todo:示例插件,演示如何使用Tensorflow lite模型对图像进行标记推断"
"en_US": "Example plugin to demonstrate inferring onnx model to label an image",
"zh_CN": "示例插件,演示如何使用onnx模型对图像进行标记推断"
}
},
"name": "mnist",
Expand All @@ -22,16 +22,16 @@
"name": "mnist",
"example": "mnist(col1)",
"hint": {
"en_US": "todo Label an image by tensorflow lite model.",
"zh_CN": "todo采用 tensorflow lite 模型标记图片"
"en_US": "Inference the number in pic by onnx model.",
"zh_CN": "采用 onnx 模型推理图中数字"
},
"args": [
{
"name": "image",
"hidden": false,
"optional": false,
"control": "field",
"type": "string",
"type": "[]byte",
"hint": {
"en_US": "Input image",
"zh_CN": "输入图像"
Expand All @@ -45,8 +45,8 @@
"return": {
"type": "string",
"hint": {
"en_US": "Image Label",
"zh_CN": "图像标注"
"en_US": "Probability",
"zh_CN": "图像中的数字概率"
}
},
"node": {
Expand Down
64 changes: 2 additions & 62 deletions extensions/functions/mnist/mnist_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package main

import (
"fmt"
mqtt "github.com/eclipse/paho.mqtt.golang"
"github.com/lf-edge/ekuiper/contract/v2/api"
ort "github.com/yalue/onnxruntime_go"
_ "image/gif"
Expand All @@ -11,10 +10,8 @@ import (
"os"
"sync"
"testing"
"time"
)

// todo 测试文件仿照tf lite
func Test_mnist_Exec(t *testing.T) {
type fields struct {
modelPath string
Expand Down Expand Up @@ -62,76 +59,19 @@ func Test_mnist_Exec(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
f := &MnistFunc{
modelPath: tt.fields.modelPath,
once: tt.fields.once,
once: sync.Once{},
inputShape: tt.fields.inputShape,
outputShape: tt.fields.outputShape,
sharedLibraryPath: tt.fields.sharedLibraryPath,
initModelError: tt.fields.initModelError,
}

out, got1 := f.Exec(tt.args.in0, tt.args.args)

if !got1 {
t.Errorf("Exec() error = %v, wantErr %v", got1, tt.want1)
}
fmt.Println(out)
})
}
}

/*
➜ mnist git:(torch_dev_swx) ✗ go test -v -cover
=== RUN Test_mnist_Exec
=== RUN Test_mnist_Exec/test1
Output probabilities:
0: 1.350922
1: 1.149244
2: 2.231948
3: 0.826893
4: -3.473754
5: 1.200287
6: -1.185765
7: -5.960128
8: 4.764542
9: -2.345179
probably a 8, with probability 4.764542
-----------------------------
true --- PASS: Test_mnist_Exec (0.03s)
--- PASS: Test_mnist_Exec/test1 (0.03s)
PASS
coverage: 58.7% of statements
ok github.com/lf-edge/ekuiper/v2/extensions/functions/mnist 0.030s
*/

func TestPic(t *testing.T) {
const TOPIC = "tfdmnist"

images := []string{
"img.png",
// 其他你需要的图像
}
opts := mqtt.NewClientOptions().AddBroker("tcp://localhost:1883")
client := mqtt.NewClient(opts)
if token := client.Connect(); token.Wait() && token.Error() != nil {
panic(token.Error())
}
for _, image := range images {
fmt.Println("Publishing " + image)
payload, err := os.ReadFile(image)
if err != nil {
fmt.Println(err)
continue
}
if token := client.Publish(TOPIC, 0, false, payload); token.Wait() && token.Error() != nil {
fmt.Println(token.Error())
} else {
fmt.Println("Published " + image)
}
time.Sleep(1 * time.Second)
}
client.Disconnect(0)
}
Loading

0 comments on commit 40ec8e2

Please sign in to comment.