Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

compiler: fix compiledModule leak #1608

Merged
merged 8 commits into from
Aug 2, 2023
4 changes: 2 additions & 2 deletions internal/engine/compiler/compiler_controlflow_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -822,7 +822,7 @@ func TestCompiler_callIndirect_largeTypeIndex(t *testing.T) {

makeExecutable(code1.Bytes())
f := function{
parent: &compiledFunction{parent: &compiledModule{executable: code1}},
parent: &compiledFunction{parent: &compiledCode{executable: code1}},
codeInitialAddress: uintptr(unsafe.Pointer(&code1.Bytes()[0])),
moduleInstance: env.moduleInstance,
}
Expand Down Expand Up @@ -896,7 +896,7 @@ func TestCompiler_compileCall(t *testing.T) {

makeExecutable(code.Bytes())
me.functions = append(me.functions, function{
parent: &compiledFunction{parent: &compiledModule{executable: code}},
parent: &compiledFunction{parent: &compiledCode{executable: code}},
codeInitialAddress: uintptr(unsafe.Pointer(&code.Bytes()[0])),
moduleInstance: env.moduleInstance,
})
Expand Down
6 changes: 3 additions & 3 deletions internal/engine/compiler/compiler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ func (j *compilerEnv) callEngine() *callEngine {
}

func (j *compilerEnv) exec(machineCode []byte) {
cm := new(compiledModule)
cm := &compiledModule{compiledCode: &compiledCode{}}
if err := cm.executable.Map(len(machineCode)); err != nil {
panic(err)
}
Expand All @@ -211,7 +211,7 @@ func (j *compilerEnv) exec(machineCode []byte) {
makeExecutable(executable)

f := &function{
parent: &compiledFunction{parent: cm},
parent: &compiledFunction{parent: cm.compiledCode},
codeInitialAddress: uintptr(unsafe.Pointer(&executable[0])),
moduleInstance: j.moduleInstance,
}
Expand Down Expand Up @@ -268,7 +268,7 @@ func newCompilerEnvironment() *compilerEnv {
Globals: []*wasm.GlobalInstance{},
Engine: me,
},
ce: me.newCallEngine(initialStackSize, &function{parent: &compiledFunction{parent: &compiledModule{}}}),
ce: me.newCallEngine(initialStackSize, &function{parent: &compiledFunction{parent: &compiledCode{}}}),
}
}

Expand Down
64 changes: 42 additions & 22 deletions internal/engine/compiler/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,10 @@ type (
// as the underlying memory region is accessed by assembly directly by using
// codesElement0Address.
functions []function

// Keep a reference to the compiled module to prevent the GC from reclaiming
// it while the code may still be needed.
module *compiledModule
}

// callEngine holds context per moduleEngine.Call, and shared across all the
Expand Down Expand Up @@ -130,11 +134,13 @@ type (
// initialFn is the initial function for this call engine.
initialFn *function

// Keep a reference to the compiled module to prevent the GC from reclaiming
// it while the code may still be needed.
module *compiledModule

// stackIterator provides a way to iterate over the stack for Listeners.
// It is setup and valid only during a call to a Listener hook.
stackIterator stackIterator

ensureTermination bool
}

// moduleContext holds the per-function call specific module information.
Expand Down Expand Up @@ -264,12 +270,27 @@ type (
}

compiledModule struct {
executable asm.CodeSegment
functions []compiledFunction
source *wasm.Module
// The data that need to be accessed by compiledFunction.parent are
// separated in an embedded field because we use finalizers to manage
// the lifecycle of compiledModule instances and having cyclic pointers
// prevents the Go runtime from calling them, which results in memory
// leaks since the memory mapped code segments cannot be released.
//
// The indirection guarantees that the finalizer set on compiledModule
// instances can run when all references are gone, and the Go GC can
// manage to reclaim the compiledCode when all compiledFunction objects
// referencing it have been freed.
*compiledCode
functions []compiledFunction

ensureTermination bool
}

compiledCode struct {
source *wasm.Module
executable asm.CodeSegment
}

// compiledFunction corresponds to a function in a module (not instantiated one). This holds the machine code
// compiled by wazero compiler.
compiledFunction struct {
Expand All @@ -282,7 +303,7 @@ type (
index wasm.Index
goFunc interface{}
listener experimental.FunctionListener
parent *compiledModule
parent *compiledCode
sourceOffsetMap sourceOffsetMap
}

Expand Down Expand Up @@ -496,13 +517,6 @@ func (e *engine) Close() (err error) {
e.mux.Lock()
defer e.mux.Unlock()
// Releasing the references to compiled codes including the memory-mapped machine codes.

for i := range e.codes {
for j := range e.codes[i].functions {
e.codes[i].functions[j].parent = nil
}
}

e.codes = nil
return
}
Expand All @@ -523,9 +537,11 @@ func (e *engine) CompileModule(_ context.Context, module *wasm.Module, listeners
var withGoFunc bool
localFuncs, importedFuncs := len(module.FunctionSection), module.ImportFunctionCount
cm := &compiledModule{
compiledCode: &compiledCode{
source: module,
},
functions: make([]compiledFunction, localFuncs),
ensureTermination: ensureTermination,
source: module,
}

if localFuncs == 0 {
Expand Down Expand Up @@ -559,7 +575,7 @@ func (e *engine) CompileModule(_ context.Context, module *wasm.Module, listeners
funcIndex := wasm.Index(i)
compiledFn := &cm.functions[i]
compiledFn.executableOffset = executable.Size()
compiledFn.parent = cm
compiledFn.parent = cm.compiledCode
compiledFn.index = importedFuncs + funcIndex
if i < ln {
compiledFn.listener = listeners[i]
Expand Down Expand Up @@ -628,6 +644,8 @@ func (e *engine) NewModuleEngine(module *wasm.Module, instance *wasm.ModuleInsta
parent: c,
}
}

me.module = cm
return me, nil
}

Expand Down Expand Up @@ -720,7 +738,7 @@ func (ce *callEngine) CallWithStack(ctx context.Context, stack []uint64) error {

func (ce *callEngine) call(ctx context.Context, params, results []uint64) (_ []uint64, err error) {
m := ce.initialFn.moduleInstance
if ce.ensureTermination {
if ce.module.ensureTermination {
select {
case <-ctx.Done():
// If the provided context is already done, close the call context
Expand All @@ -741,12 +759,14 @@ func (ce *callEngine) call(ctx context.Context, params, results []uint64) (_ []u
// If the module closed during the call, and the call didn't err for another reason, set an ExitError.
err = m.FailIfClosed()
}
// Ensure that the compiled module will never be GC'd before this method returns.
runtime.KeepAlive(ce.module)
}()

ft := ce.initialFn.funcType
ce.initializeStack(ft, params)

if ce.ensureTermination {
if ce.module.ensureTermination {
done := m.CloseModuleOnCanceledOrTimeout(ctx)
defer done()
}
Expand Down Expand Up @@ -959,11 +979,11 @@ var initialStackSize uint64 = 512

func (e *moduleEngine) newCallEngine(stackSize uint64, fn *function) *callEngine {
ce := &callEngine{
stack: make([]uint64, stackSize),
archContext: newArchContext(),
initialFn: fn,
moduleContext: moduleContext{fn: fn},
ensureTermination: fn.parent.parent.ensureTermination,
stack: make([]uint64, stackSize),
archContext: newArchContext(),
initialFn: fn,
moduleContext: moduleContext{fn: fn},
module: e.module,
}

stackHeader := (*reflect.SliceHeader)(unsafe.Pointer(&ce.stack))
Expand Down
2 changes: 1 addition & 1 deletion internal/engine/compiler/engine_bench_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ func BenchmarkCallEngine_builtinFunctionFunctionListener(b *testing.B) {
},
},
index: 0,
parent: &compiledModule{
parent: &compiledCode{
source: &wasm.Module{
TypeSection: []wasm.FunctionType{{}},
FunctionSection: []wasm.Index{0},
Expand Down
9 changes: 7 additions & 2 deletions internal/engine/compiler/engine_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
func (e *engine) deleteCompiledModule(module *wasm.Module) {
e.mux.Lock()
defer e.mux.Unlock()

delete(e.codes, module.ID)

// Note: we do not call e.Cache.Delete, as the lifetime of
Expand Down Expand Up @@ -158,14 +159,18 @@ func deserializeCompiledModule(wazeroVersion string, reader io.ReadCloser, modul

ensureTermination := header[cachedVersionEnd] != 0
functionsNum := binary.LittleEndian.Uint32(header[len(header)-4:])
cm = &compiledModule{functions: make([]compiledFunction, functionsNum), ensureTermination: ensureTermination}
cm = &compiledModule{
compiledCode: new(compiledCode),
functions: make([]compiledFunction, functionsNum),
ensureTermination: ensureTermination,
}

imported := module.ImportFunctionCount

var eightBytes [8]byte
for i := uint32(0); i < functionsNum; i++ {
f := &cm.functions[i]
f.parent = cm
f.parent = cm.compiledCode

// Read the stack pointer ceil.
if f.stackPointerCeil, err = readUint64(reader, &eightBytes); err != nil {
Expand Down
54 changes: 35 additions & 19 deletions internal/engine/compiler/engine_cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,9 @@ func TestSerializeCompiledModule(t *testing.T) {
}{
{
in: &compiledModule{
executable: makeCodeSegment(1, 2, 3, 4, 5),
compiledCode: &compiledCode{
executable: makeCodeSegment(1, 2, 3, 4, 5),
},
functions: []compiledFunction{
{executableOffset: 0, stackPointerCeil: 12345},
},
Expand All @@ -57,11 +59,13 @@ func TestSerializeCompiledModule(t *testing.T) {
},
{
in: &compiledModule{
ensureTermination: true,
executable: makeCodeSegment(1, 2, 3, 4, 5),
compiledCode: &compiledCode{
executable: makeCodeSegment(1, 2, 3, 4, 5),
},
functions: []compiledFunction{
{executableOffset: 0, stackPointerCeil: 12345},
},
ensureTermination: true,
},
exp: concat(
[]byte(wazeroMagic),
Expand All @@ -77,12 +81,14 @@ func TestSerializeCompiledModule(t *testing.T) {
},
{
in: &compiledModule{
ensureTermination: true,
executable: makeCodeSegment(1, 2, 3, 4, 5, 1, 2, 3),
compiledCode: &compiledCode{
executable: makeCodeSegment(1, 2, 3, 4, 5, 1, 2, 3),
},
functions: []compiledFunction{
{executableOffset: 0, stackPointerCeil: 12345},
{executableOffset: 5, stackPointerCeil: 0xffffffff},
},
ensureTermination: true,
},
exp: concat(
[]byte(wazeroMagic),
Expand Down Expand Up @@ -159,7 +165,9 @@ func TestDeserializeCompiledModule(t *testing.T) {
[]byte{1, 2, 3, 4, 5}, // machine code.
),
expCompiledModule: &compiledModule{
executable: makeCodeSegment(1, 2, 3, 4, 5),
compiledCode: &compiledCode{
executable: makeCodeSegment(1, 2, 3, 4, 5),
},
functions: []compiledFunction{
{executableOffset: 0, stackPointerCeil: 12345, index: 0},
},
Expand All @@ -181,9 +189,11 @@ func TestDeserializeCompiledModule(t *testing.T) {
[]byte{1, 2, 3, 4, 5}, // code.
),
expCompiledModule: &compiledModule{
ensureTermination: true,
executable: makeCodeSegment(1, 2, 3, 4, 5),
compiledCode: &compiledCode{
executable: makeCodeSegment(1, 2, 3, 4, 5),
},
functions: []compiledFunction{{executableOffset: 0, stackPointerCeil: 12345, index: 0}},
ensureTermination: true,
},
expStaleCache: false,
expErr: "",
Expand All @@ -208,7 +218,9 @@ func TestDeserializeCompiledModule(t *testing.T) {
),
importedFunctionCount: 1,
expCompiledModule: &compiledModule{
executable: makeCodeSegment(1, 2, 3, 4, 5, 6, 7, 8, 9, 10),
compiledCode: &compiledCode{
executable: makeCodeSegment(1, 2, 3, 4, 5, 6, 7, 8, 9, 10),
},
functions: []compiledFunction{
{executableOffset: 0, stackPointerCeil: 12345, index: 1},
{executableOffset: 7, stackPointerCeil: 0xffffffff, index: 2},
Expand Down Expand Up @@ -279,8 +291,8 @@ func TestDeserializeCompiledModule(t *testing.T) {
if tc.expCompiledModule != nil {
require.Equal(t, len(tc.expCompiledModule.functions), len(cm.functions))
for i := 0; i < len(cm.functions); i++ {
require.Equal(t, cm, cm.functions[i].parent)
tc.expCompiledModule.functions[i].parent = cm
require.Equal(t, cm.compiledCode, cm.functions[i].parent)
tc.expCompiledModule.functions[i].parent = cm.compiledCode
}
}

Expand Down Expand Up @@ -361,13 +373,13 @@ func TestEngine_getCompiledModuleFromCache(t *testing.T) {
},
expHit: true,
expCompiledModule: &compiledModule{
executable: makeCodeSegment(1, 2, 3, 4, 5, 6, 7, 8, 9, 10),
compiledCode: &compiledCode{
executable: makeCodeSegment(1, 2, 3, 4, 5, 6, 7, 8, 9, 10),
},
functions: []compiledFunction{
{stackPointerCeil: 12345, executableOffset: 0, index: 0},
{stackPointerCeil: 0xffffffff, executableOffset: 5, index: 1},
},
source: nil,
ensureTermination: false,
},
},
}
Expand All @@ -379,7 +391,7 @@ func TestEngine_getCompiledModuleFromCache(t *testing.T) {
if exp := tc.expCompiledModule; exp != nil {
exp.source = m
for i := range tc.expCompiledModule.functions {
tc.expCompiledModule.functions[i].parent = exp
tc.expCompiledModule.functions[i].parent = exp.compiledCode
}
}

Expand Down Expand Up @@ -422,8 +434,10 @@ func TestEngine_addCompiledModuleToCache(t *testing.T) {
tc := filecache.New(t.TempDir())
e := engine{fileCache: tc}
cm := &compiledModule{
executable: makeCodeSegment(1, 2, 3),
functions: []compiledFunction{{stackPointerCeil: 123}},
compiledCode: &compiledCode{
executable: makeCodeSegment(1, 2, 3),
},
functions: []compiledFunction{{stackPointerCeil: 123}},
}
m := &wasm.Module{ID: sha256.Sum256(nil), IsHostModule: true} // Host module!
err := e.addCompiledModuleToCache(m, cm)
Expand All @@ -438,8 +452,10 @@ func TestEngine_addCompiledModuleToCache(t *testing.T) {
e := engine{fileCache: tc}
m := &wasm.Module{}
cm := &compiledModule{
executable: makeCodeSegment(1, 2, 3),
functions: []compiledFunction{{stackPointerCeil: 123}},
compiledCode: &compiledCode{
executable: makeCodeSegment(1, 2, 3),
},
functions: []compiledFunction{{stackPointerCeil: 123}},
}
err := e.addCompiledModuleToCache(m, cm)
require.NoError(t, err)
Expand Down
Loading
Loading