Skip to content

Commit

Permalink
Merge pull request #4 from sarchlab/3-support-float-operation
Browse files Browse the repository at this point in the history
Support float32
  • Loading branch information
syifan authored Oct 3, 2023
2 parents b3801bf + fd68c67 commit ec2b860
Show file tree
Hide file tree
Showing 5 changed files with 123 additions and 37 deletions.
1 change: 1 addition & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ jobs:
skip-pkg-cache: true
version: "latest"
args: --timeout=10m

- name: Install Ginkgo
run: go install github.com/onsi/ginkgo/v2/ginkgo

Expand Down
104 changes: 72 additions & 32 deletions core/emu.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package core

import (
"math"
"strconv"
"strings"
)
Expand Down Expand Up @@ -29,20 +30,19 @@ func (i instEmulator) RunInst(inst string, state *coreState) {
if strings.Contains(instName, "CMP") {
instName = "CMP"
}
switch instName {
case "WAIT":
i.runWait(tokens, state)
case "SEND":
i.runSend(tokens, state)
case "JMP":
i.runJmp(tokens, state)
case "CMP":
i.runCmp(tokens, state)
case "JEQ":
i.runJeq(tokens, state)
case "DONE":
i.runDone()
default:

instFuncs := map[string]func([]string, *coreState){
"WAIT": i.runWait,
"SEND": i.runSend,
"JMP": i.runJmp,
"CMP": i.runCmp,
"JEQ": i.runJeq,
"DONE": func(_ []string, _ *coreState) { i.runDone() }, // Since runDone might not have parameters
}

if instFunc, ok := instFuncs[instName]; ok {
instFunc(tokens, state)
} else {
panic("unknown instruction " + inst)
}
}
Expand Down Expand Up @@ -135,40 +135,80 @@ func (i instEmulator) writeOperand(operand string, value uint32, state *coreStat
}

func (i instEmulator) runCmp(inst []string, state *coreState) {
Itype := inst[0]
//Float or Integer
switch {
case strings.Contains(Itype, "I"):
i.parseAndCompareI(inst, state)
case strings.Contains(Itype, "F32"):
i.parseAndCompareF32(inst, state)
default:
panic("invalid cmp")
}
}

func (i instEmulator) parseAndCompareI(inst []string, state *coreState) {
instruction := inst[0]
dst := inst[1]
src := inst[2]
//Pending for float type
//Float or Integer
// switch {
// case strings.Contains(instruction, "I"):
// imme, err := strconv.ParseUint(inst[3], 10, 32)
// }

srcVal := i.readOperand(src, state)
dstVal := uint32(0)
imme, err := strconv.ParseUint(inst[3], 10, 32)
if err != nil {
panic("invalid compare number")
}

srcVal := i.readOperand(src, state)
dstVal := uint32(0)
imme32 := uint32(imme)
immeI32 := int32(uint32(imme))
srcValI := int32(srcVal)

conditionFuncs := map[string]func(uint32, uint32) bool{
"EQ": func(a, b uint32) bool { return a == b },
"NE": func(a, b uint32) bool { return a != b },
"LT": func(a, b uint32) bool { return a < b },
"LE": func(a, b uint32) bool { return a <= b },
"GT": func(a, b uint32) bool { return a > b },
"GE": func(a, b uint32) bool { return a >= b },
conditionFuncs := map[string]func(int32, int32) bool{
"EQ": func(a, b int32) bool { return a == b },
"NE": func(a, b int32) bool { return a != b },
"LE": func(a, b int32) bool { return a <= b },
"LT": func(a, b int32) bool { return a < b },
"GT": func(a, b int32) bool { return a > b },
"GE": func(a, b int32) bool { return a >= b },
}

for key, function := range conditionFuncs {
if strings.Contains(instruction, key) && function(srcVal, imme32) {
if strings.Contains(instruction, key) && function(srcValI, immeI32) {
dstVal = 1
break
}
}
i.writeOperand(dst, dstVal, state)
state.PC++
}

func (i instEmulator) parseAndCompareF32(inst []string, state *coreState) {
instruction := inst[0]
dst := inst[1]
src := inst[2]

srcVal := i.readOperand(src, state)
dstVal := uint32(0)
imme, err := strconv.ParseUint(inst[3], 10, 32)
if err != nil {
panic("invalid compare number")
}

conditionFuncsF := map[string]func(float32, float32) bool{
"EQ": func(a, b float32) bool { return a == b },
"NE": func(a, b float32) bool { return a != b },
"LT": func(a, b float32) bool { return a < b },
"LE": func(a, b float32) bool { return a <= b },
"GT": func(a, b float32) bool { return a > b },
"GE": func(a, b float32) bool { return a >= b },
}

immeF32 := math.Float32frombits(uint32(imme))
srcValF := math.Float32frombits(srcVal)

for key, function := range conditionFuncsF {
if strings.Contains(instruction, key) && function(srcValF, immeF32) {
dstVal = 1
}
}
i.writeOperand(dst, dstVal, state)
state.PC++
}
Expand Down
43 changes: 39 additions & 4 deletions samples/relu/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ package main
import (
_ "embed"
"fmt"
"math/rand"
"time"
"unsafe"

"github.com/sarchlab/akita/v3/sim"
"github.com/sarchlab/zeonica/api"
Expand All @@ -14,17 +17,32 @@ import (
var width = 16
var height = 16

//go:embed relu.cgraasm
// For float test, change reluI.cgraasm to reluF.cgraasm
//
//go:embed reluI.cgraasm
var program string

func relu(driver api.Driver) {
length := 16

rand.Seed(time.Now().UnixNano())
src := make([]uint32, length)
dst := make([]uint32, length)

//For float test
// minF := float32(-10.0)
// maxF := float32(10.0)
// for i := 0; i < length; i++ {
// FNum := minF + rand.Float32()*(maxF-minF)
// src[i] = *(*uint32)(unsafe.Pointer(&FNum))
// }

//For Int test
minI := int32(-10)
maxI := int32(10)
for i := 0; i < length; i++ {
src[i] = uint32(i)
INum := minI + rand.Int31n(maxI-minI+1)
src[i] = *(*uint32)(unsafe.Pointer(&INum))
}

driver.FeedIn(src, cgra.West, [2]int{0, height}, height)
Expand All @@ -38,8 +56,25 @@ func relu(driver api.Driver) {

driver.Run()

fmt.Println(src)
fmt.Println(dst)
//For float test
// srcF := make([]float32, length)
// dstF := make([]float32, length)
// for i := 0; i < length; i++ {
// srcF[i] = *(*float32)(unsafe.Pointer(&src[i]))
// dstF[i] = *(*float32)(unsafe.Pointer(&dst[i])) // Convert each element to float.
// }
// fmt.Println(srcF)
// fmt.Println(dstF)

//For int test
srcI := make([]int32, length)
dstI := make([]int32, length)
for i := 0; i < length; i++ {
srcI[i] = *(*int32)(unsafe.Pointer(&src[i]))
dstI[i] = *(*int32)(unsafe.Pointer(&dst[i])) // Convert each element to float.
}
fmt.Println(srcI)
fmt.Println(dstI)
}

func main() {
Expand Down
10 changes: 10 additions & 0 deletions samples/relu/reluF.cgraasm
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
WAIT, $0, NET_RECV_3
F32_CMP_LT, $1, $0, 0
JEQ, ELSE, $1, 1
IF:
SEND, NET_SEND_1, $0
JMP, END
ELSE:
SEND, NET_SEND_1, 0
END:
DONE,
2 changes: 1 addition & 1 deletion samples/relu/relu.cgraasm → samples/relu/reluI.cgraasm
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
WAIT, $0, NET_RECV_3
F_CMP_LT, $1, $0, 0
I_CMP_LT, $1, $0, 0
JEQ, ELSE, $1, 1
IF:
SEND, NET_SEND_1, $0
Expand Down

0 comments on commit ec2b860

Please sign in to comment.