diff --git a/cmd_scripting.go b/cmd_scripting.go index 9f5ce9ef..5a48f8b7 100644 --- a/cmd_scripting.go +++ b/cmd_scripting.go @@ -7,6 +7,7 @@ import ( "io" "strconv" "strings" + "sync" luajson "github.com/alicebob/gopher-json" lua "github.com/yuin/gopher-lua" @@ -21,6 +22,10 @@ func commandsScripting(m *Miniredis) { m.srv.Register("SCRIPT", m.cmdScript) } +var ( + parsedScripts = sync.Map{} +) + // Execute lua. Needs to run m.Lock()ed, from within withTx(). // Returns true if the lua was OK (and hence should be cached). func (m *Miniredis) runLuaScript(c *server.Peer, sha, script string, args []string) bool { @@ -91,13 +96,13 @@ func (m *Miniredis) runLuaScript(c *server.Peer, sha, script string, args []stri return 1 })) - l.DoString(protectGlobals) + _ = doScript(l, protectGlobals) l.Push(lua.LString("redis")) l.Call(1, 0) - if err := l.DoString(script); err != nil { - c.WriteError(errLuaParseError(err)) + if err := doScript(l, script); err != nil { + c.WriteError(err.Error()) return false } @@ -105,6 +110,42 @@ func (m *Miniredis) runLuaScript(c *server.Peer, sha, script string, args []stri return true } +// doScript pre-compiiles the given script into a Lua prototype, +// then executes the pre-compiled function against the given lua state. +// +// This is thread-safe. +func doScript(l *lua.LState, script string) error { + proto, err := compile(script) + if err != nil { + return fmt.Errorf(errLuaParseError(err)) + } + + lfunc := l.NewFunctionFromProto(proto) + l.Push(lfunc) + if err := l.PCall(0, lua.MultRet, nil); err != nil { + // ensure we wrap with the correct format. + return fmt.Errorf(errLuaParseError(err)) + } + + return nil +} + +func compile(script string) (*lua.FunctionProto, error) { + if val, ok := parsedScripts.Load(script); ok { + return val.(*lua.FunctionProto), nil + } + chunk, err := parse.Parse(strings.NewReader(script), "") + if err != nil { + return nil, err + } + proto, err := lua.Compile(chunk, "") + if err != nil { + return nil, err + } + parsedScripts.Store(script, proto) + return proto, nil +} + func (m *Miniredis) cmdEval(c *server.Peer, cmd string, args []string) { if len(args) < 2 { setDirty(c)