diff --git a/extism_test.go b/extism_test.go index 2a35104..2a0e04f 100644 --- a/extism_test.go +++ b/extism_test.go @@ -215,19 +215,18 @@ func TestExit(t *testing.T) { assert.NotNil(t, err, fmt.Sprintf("err can't be nil. config: %v", config)) } - fmt.Printf("err: %v", err) assert.Equal(t, expected, actual, fmt.Sprintf("exit must be %v. config: '%v'", expected, config)) } } } -func TestHost(t *testing.T) { +func TestHost_simple(t *testing.T) { manifest := manifest("host.wasm") mult := HostFunction{ Name: "mult", Namespace: "env", - Callback: func(ctx context.Context, plugin *Plugin, userData interface{}, stack []uint64) { + Callback: func(ctx context.Context, plugin *CurrentPlugin, userData interface{}, stack []uint64) { a := api.DecodeI32(stack[0]) b := api.DecodeI32(stack[1]) @@ -251,6 +250,49 @@ func TestHost(t *testing.T) { } } +func TestHost_memory(t *testing.T) { + manifest := manifest("host_memory.wasm") + + mult := HostFunction{ + Name: "to_upper", + Namespace: "host", + Callback: func(ctx context.Context, plugin *CurrentPlugin, userData interface{}, stack []uint64) { + offset := stack[0] + buffer, err := plugin.ReadBytes(offset) + if err != nil { + panic(err) + } + + result := bytes.ToUpper(buffer) + plugin.Logf(Debug, "Result: %s", result) + + plugin.Free(offset) + + offset, err = plugin.WriteBytes(result) + if err != nil { + panic(err) + } + + stack[0] = offset + }, + Params: []api.ValueType{api.ValueTypeI64}, + Results: []api.ValueType{api.ValueTypeI64}, + } + + if plugin, ok := plugin(t, manifest, mult); ok { + defer plugin.Close() + + exit, output, err := plugin.Call("run_test", []byte("Frodo")) + + if assertCall(t, err, exit) { + actual := string(output) + expected := "HELLO FRODO!" + + assert.Equal(t, expected, actual) + } + } +} + func TestHTTP_allowed(t *testing.T) { manifest := manifest("http.wasm") manifest.AllowedHosts = []string{"jsonplaceholder.*.com"} diff --git a/host.go b/host.go index efdd66d..35d662c 100644 --- a/host.go +++ b/host.go @@ -45,7 +45,7 @@ const I64 = api.ValueTypeI64 // // To safely decode/encode values from/to the uint64 inputs/ouputs, users are encouraged to use // Wazero's api.EncodeXXX or api.DecodeXXX functions. -type HostFunctionCallback func(ctx context.Context, plugin *Plugin, userData interface{}, stack []uint64) +type HostFunctionCallback func(ctx context.Context, p *CurrentPlugin, userData interface{}, stack []uint64) // HostFunction represents a custom function defined by the host. // Here's an example multiplication function that loads operands from memory: @@ -53,7 +53,7 @@ type HostFunctionCallback func(ctx context.Context, plugin *Plugin, userData int // mult := HostFunction{ // Name: "mult", // Namespace: "env", -// Callback: func(ctx context.Context, plugin *Plugin, userData interface{}, stack []uint64) { +// Callback: func(ctx context.Context, plugin *CurrentPlugin, userData interface{}, stack []uint64) { // a := api.DecodeI32(stack[0]) // b := api.DecodeI32(stack[1]) // @@ -71,6 +71,106 @@ type HostFunction struct { UserData interface{} } +type CurrentPlugin struct { + plugin *Plugin +} + +func (p *Plugin) currentPlugin() *CurrentPlugin { + return &CurrentPlugin{p} +} + +func (p *CurrentPlugin) Log(level LogLevel, message string) { + p.plugin.Log(level, message) +} + +func (p *CurrentPlugin) Logf(level LogLevel, format string, args ...any) { + p.plugin.Logf(level, format, args...) +} + +// Memory returns the plugin's WebAssembly memory interface. +func (p *CurrentPlugin) Memory() api.Memory { + return p.plugin.Memory() +} + +// Alloc a new memory block of the given length, returning its offset +func (p *CurrentPlugin) Alloc(n uint64) (uint64, error) { + out, err := p.plugin.Runtime.Extism.ExportedFunction("extism_alloc").Call(p.plugin.Runtime.ctx, uint64(n)) + if err != nil { + return 0, err + } else if len(out) != 1 { + return 0, fmt.Errorf("Expected 1 return, go %v.", len(out)) + } + + return uint64(out[0]), nil +} + +// Free the memory block specified by the given offset +func (p *CurrentPlugin) Free(offset uint64) error { + _, err := p.plugin.Runtime.Extism.ExportedFunction("extism_free").Call(p.plugin.Runtime.ctx, uint64(offset)) + if err != nil { + return err + } + + return nil +} + +// Length returns the number of bytes allocated at the specified offset +func (p *CurrentPlugin) Length(offs uint64) (uint64, error) { + out, err := p.plugin.Runtime.Extism.ExportedFunction("extism_length").Call(p.plugin.Runtime.ctx, uint64(offs)) + if err != nil { + return 0, err + } else if len(out) != 1 { + return 0, fmt.Errorf("Expected 1 return, go %v.", len(out)) + } + + return uint64(out[0]), nil +} + +// Write a string to wasm memory and return the offset +func (p *CurrentPlugin) WriteString(s string) (uint64, error) { + return p.WriteBytes([]byte(s)) +} + +// WriteBytes writes a string to wasm memory and return the offset +func (p *CurrentPlugin) WriteBytes(b []byte) (uint64, error) { + ptr, err := p.Alloc(uint64(len(b))) + if err != nil { + return 0, err + } + + ok := p.Memory().Write(uint32(ptr), b) + if !ok { + return 0, fmt.Errorf("Failed to write to memory.") + } + + return ptr, nil +} + +// ReadString reads a string from wasm memory +func (p *CurrentPlugin) ReadString(offset uint64) (string, error) { + buffer, err := p.ReadBytes(offset) + if err != nil { + return "", err + } + + return string(buffer), nil +} + +// ReadBytes reads a byte array from memory +func (p *CurrentPlugin) ReadBytes(offset uint64) ([]byte, error) { + length, err := p.Length(offset) + if err != nil { + return []byte{}, err + } + + buffer, ok := p.Memory().Read(uint32(offset), uint32(length)) + if !ok { + return []byte{}, fmt.Errorf("Invalid memory block") + } + + return buffer, nil +} + func buildHostModule(ctx context.Context, rt wazero.Runtime, name string, funcs []HostFunction) (api.Module, error) { builder := rt.NewHostModuleBuilder(name) @@ -83,7 +183,7 @@ func defineCustomHostFunctions(builder wazero.HostModuleBuilder, funcs []HostFun for _, f := range funcs { builder.NewFunctionBuilder().WithGoFunction(api.GoFunc(func(ctx context.Context, stack []uint64) { if plugin, ok := ctx.Value("plugin").(*Plugin); ok { - f.Callback(ctx, plugin, f.UserData, stack) + f.Callback(ctx, &CurrentPlugin{plugin}, f.UserData, stack) return } @@ -143,9 +243,11 @@ func buildEnvModule(ctx context.Context, rt wazero.Runtime, extism api.Module, f logFunc := func(name string, level LogLevel) { hostFunc(name, func(ctx context.Context, m api.Module, offset uint64) { if plugin, ok := ctx.Value("plugin").(*Plugin); ok { - extism := plugin.Runtime.Extism + message, err := plugin.currentPlugin().ReadString(offset) + if err != nil { + panic(fmt.Errorf("Failed to read log message from memory: %v", err)) + } - message := readString(extism, ctx, offset) plugin.Log(level, message) return @@ -165,9 +267,12 @@ func buildEnvModule(ctx context.Context, rt wazero.Runtime, extism api.Module, f func configGet(ctx context.Context, m api.Module, offset uint64) uint64 { if plugin, ok := ctx.Value("plugin").(*Plugin); ok { - extism := plugin.Runtime.Extism + cp := plugin.currentPlugin() - name := readString(extism, ctx, offset) + name, err := cp.ReadString(offset) + if err != nil { + panic(fmt.Errorf("Failed to read config name from memory: %v", err)) + } value, ok := plugin.Config[name] if !ok { @@ -175,7 +280,12 @@ func configGet(ctx context.Context, m api.Module, offset uint64) uint64 { return 0 } - return writeString(extism, ctx, value) + offset, err = cp.WriteString(value) + if err != nil { + panic(fmt.Errorf("Failed to write config value to memory: %v", err)) + } + + return offset } panic("Invalid context, `plugin` key not found") @@ -183,9 +293,12 @@ func configGet(ctx context.Context, m api.Module, offset uint64) uint64 { func varGet(ctx context.Context, m api.Module, offset uint64) uint64 { if plugin, ok := ctx.Value("plugin").(*Plugin); ok { - extism := plugin.Runtime.Extism + cp := plugin.currentPlugin() - name := readString(extism, ctx, offset) + name, err := cp.ReadString(offset) + if err != nil { + panic(fmt.Errorf("Failed to read var name from memory: %v", err)) + } value, ok := plugin.Var[name] if !ok { @@ -193,7 +306,12 @@ func varGet(ctx context.Context, m api.Module, offset uint64) uint64 { return 0 } - return writeBlock(extism, ctx, value) + offset, err = cp.WriteBytes(value) + if err != nil { + panic(fmt.Errorf("Failed to write var value to memory: %v", err)) + } + + return offset } panic("Invalid context, `plugin` key not found") @@ -205,9 +323,12 @@ func varSet(ctx context.Context, m api.Module, nameOffset uint64, valueOffset ui panic("Invalid context, `plugin` key not found") } - extism := plugin.Runtime.Extism + cp := plugin.currentPlugin() - name := readString(extism, ctx, nameOffset) + name, err := cp.ReadString(nameOffset) + if err != nil { + panic(fmt.Errorf("Failed to read var name from memory: %v", err)) + } size := 0 for _, v := range plugin.Var { @@ -223,18 +344,22 @@ func varSet(ctx context.Context, m api.Module, nameOffset uint64, valueOffset ui if valueOffset == 0 { delete(plugin.Var, name) } else { - value := readBlock(extism, ctx, valueOffset) + value, err := cp.ReadBytes(valueOffset) + if err != nil { + panic(fmt.Errorf("Failed to read var value from memory: %v", err)) + } + plugin.Var[name] = value } } func httpRequest(ctx context.Context, m api.Module, requestOffset uint64, bodyOffset uint64) uint64 { if plugin, ok := ctx.Value("plugin").(*Plugin); ok { - extism := plugin.Runtime.Extism + cp := plugin.currentPlugin() - requestJson := readBlock(extism, ctx, requestOffset) + requestJson, err := cp.ReadBytes(requestOffset) var request HttpRequest - err := json.Unmarshal(requestJson, &request) + err = json.Unmarshal(requestJson, &request) if err != nil { panic(fmt.Errorf("Invalid HTTP Request: %v", err)) } @@ -265,8 +390,13 @@ func httpRequest(ctx context.Context, m api.Module, requestOffset uint64, bodyOf var bodyReader io.Reader = nil if bodyOffset != 0 { - // TODO: do we need to call extism_free on the body? - body := readBlock(extism, ctx, bodyOffset) + body, err := cp.ReadBytes(bodyOffset) + if err != nil { + panic("Failed to read response body from memory") + } + + cp.Free(bodyOffset) + bodyReader = bytes.NewReader(body) } @@ -299,57 +429,22 @@ func httpRequest(ctx context.Context, m api.Module, requestOffset uint64, bodyOf if len(body) == 0 { return 0 } else { - return writeBlock(extism, ctx, body) + offset, err := cp.WriteBytes(body) + if err != nil { + panic("Failed to write resposne body to memory") + } + + return offset } } panic("Invalid context, `plugin` key not found") } -func httpStatusCode(ctx context.Context, m api.Module) uint32 { +func httpStatusCode(ctx context.Context, m api.Module) int32 { if plugin, ok := ctx.Value("plugin").(*Plugin); ok { - return uint32(plugin.LastStatusCode) + return int32(plugin.LastStatusCode) } panic("Invalid context, `plugin` key not found") } - -func writeString(extism api.Module, ctx context.Context, value string) uint64 { - return writeBlock(extism, ctx, []byte(value)) -} - -func writeBlock(extism api.Module, ctx context.Context, buffer []byte) uint64 { - res, err := extism.ExportedFunction("extism_alloc").Call(ctx, uint64(len(buffer))) - if err != nil { - panic(err) - } - - out := res[0] - mem := extism.Memory() - mem.Write(uint32(out), buffer) - - return out -} - -func readString(extism api.Module, ctx context.Context, offset uint64) string { - return string(readBlock(extism, ctx, offset)) -} - -func readBlock(extism api.Module, ctx context.Context, offset uint64) []byte { - blockLengthResult, err := extism.ExportedFunction("extism_length").Call(ctx, uint64(offset)) - if err != nil { - panic(err) - } else if len(blockLengthResult) != 1 { - panic(fmt.Errorf("Expected 1 value, got %v values", len(blockLengthResult))) - } - - blockLength := blockLengthResult[0] - - mem := extism.Memory() - buffer, ok := mem.Read(uint32(offset), uint32(blockLength)) - if !ok { - panic("Out of bounds read") - } - - return buffer -} diff --git a/plugins/host_memory/go.mod b/plugins/host_memory/go.mod new file mode 100644 index 0000000..0a16e05 --- /dev/null +++ b/plugins/host_memory/go.mod @@ -0,0 +1,8 @@ +module github.com/extism/extism-sdk-plugins-host-memory + +go 1.20 + +require ( + github.com/extism/go-pdk v0.0.0-20230119214914-65bffbeb3e64 // indirect + github.com/valyala/fastjson v1.6.3 // indirect +) diff --git a/plugins/host_memory/go.sum b/plugins/host_memory/go.sum new file mode 100644 index 0000000..dca507f --- /dev/null +++ b/plugins/host_memory/go.sum @@ -0,0 +1,4 @@ +github.com/extism/go-pdk v0.0.0-20230119214914-65bffbeb3e64 h1:IfR1k741q+yQLvv5sLShCkvt3FgKU4wQVJfp7hhb/iY= +github.com/extism/go-pdk v0.0.0-20230119214914-65bffbeb3e64/go.mod h1:1wdiAoG8306g4WK+6laBrS+75089/0V4XRVTllt8b5U= +github.com/valyala/fastjson v1.6.3 h1:tAKFnnwmeMGPbwJ7IwxcTPCNr3uIzoIj3/Fh90ra4xc= +github.com/valyala/fastjson v1.6.3/go.mod h1:CLCAqky6SMuOcxStkYQvblddUtoRxhYMGLrsQns1aXY= diff --git a/plugins/host_memory/main.go b/plugins/host_memory/main.go new file mode 100644 index 0000000..bf95d1b --- /dev/null +++ b/plugins/host_memory/main.go @@ -0,0 +1,36 @@ +package main + +import ( + "fmt" + + "github.com/extism/go-pdk" +) + +//go:wasm-module host +//export to_upper +func to_upper(offset uint64) uint64 + +//export run_test +func run_test() int32 { + name := pdk.InputString() + + // Store the message in the wasm memory and get an pointer for the location + message := fmt.Sprintf("Hello %s!", name) + mem := pdk.AllocateString(message) + + pdk.Log(pdk.LogError, fmt.Sprintf("offset: %v, length: %v", mem.Offset(), mem.Length())) + + // Send the pointer of the message to to_upper and get back + // a new pointer for the new transformed message + offset := to_upper(mem.Offset()) + mem = pdk.FindMemory(offset) + + pdk.Log(pdk.LogError, fmt.Sprintf("offset: %v, length: %v", offset, mem.Length())) + + // zero-copy output to host + pdk.OutputMemory(mem) + + return 0 +} + +func main() {} diff --git a/wasm/host_memory.wasm b/wasm/host_memory.wasm new file mode 100644 index 0000000..a5907dc Binary files /dev/null and b/wasm/host_memory.wasm differ