Skip to content

Commit

Permalink
Precompile lua functions to prevent wasted CPU cycles
Browse files Browse the repository at this point in the history
`miniredis.runLuaScript` calls lua.*LState.DoString() for every lua
invocation, even if scripts are evaluated via checksums.

In turn, lua parses and compiles the same scripts on each call, which
can cause high CPU usage on script-heavy test instances.

This PR uses `sync.Map` to store pre-compiled lua functions for each
script, using the same logic as the Lua library itself:
tul/gopher-lua@7a6135d

This reduces CPU usage by about 50%.
  • Loading branch information
tonyhb committed May 19, 2023
1 parent 33a1bb4 commit 536e717
Showing 1 changed file with 44 additions and 3 deletions.
47 changes: 44 additions & 3 deletions cmd_scripting.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"io"
"strconv"
"strings"
"sync"

luajson "github.com/alicebob/gopher-json"
lua "github.com/yuin/gopher-lua"
Expand All @@ -21,6 +22,10 @@ func commandsScripting(m *Miniredis) {
m.srv.Register("SCRIPT", m.cmdScript)
}

var (
parsedScripts = sync.Map{}
)

// Execute lua. Needs to run m.Lock()ed, from within withTx().
// Returns true if the lua was OK (and hence should be cached).
func (m *Miniredis) runLuaScript(c *server.Peer, sha, script string, args []string) bool {
Expand Down Expand Up @@ -91,20 +96,56 @@ func (m *Miniredis) runLuaScript(c *server.Peer, sha, script string, args []stri
return 1
}))

l.DoString(protectGlobals)
_ = doScript(l, protectGlobals)

l.Push(lua.LString("redis"))
l.Call(1, 0)

if err := l.DoString(script); err != nil {
c.WriteError(errLuaParseError(err))
if err := doScript(l, script); err != nil {
c.WriteError(err.Error())
return false
}

luaToRedis(l, c, l.Get(1))
return true
}

// doScript pre-compiiles the given script into a Lua prototype,
// then executes the pre-compiled function against the given lua state.
//
// This is thread-safe.
func doScript(l *lua.LState, script string) error {
proto, err := compile(script)
if err != nil {
return fmt.Errorf(errLuaParseError(err))
}

lfunc := l.NewFunctionFromProto(proto)
l.Push(lfunc)
if err := l.PCall(0, lua.MultRet, nil); err != nil {
// ensure we wrap with the correct format.
return fmt.Errorf(errLuaParseError(err))
}

return nil
}

func compile(script string) (*lua.FunctionProto, error) {
if val, ok := parsedScripts.Load(script); ok {
return val.(*lua.FunctionProto), nil
}
chunk, err := parse.Parse(strings.NewReader(script), "<string>")
if err != nil {
return nil, err
}
proto, err := lua.Compile(chunk, "")
if err != nil {
return nil, err
}
parsedScripts.Store(script, proto)
return proto, nil
}

func (m *Miniredis) cmdEval(c *server.Peer, cmd string, args []string) {
if len(args) < 2 {
setDirty(c)
Expand Down

0 comments on commit 536e717

Please sign in to comment.