From 506048a2cd82120ab454f5c2f0ed1e6df636e51d Mon Sep 17 00:00:00 2001 From: Laurent Demailly Date: Tue, 27 Aug 2024 15:40:49 -0700 Subject: [PATCH] Bug with some lambdas and json_go (#186) * Bug with some lambdas and json_go * println was swallowing errors, conversly log wasn't; fixed both * fix bug with json_go and functions, also don't html escape either * godoc * AI overlord made a few suggestions, I disliked s.ConvertError least * Better names * Fix up godoc --- eval/eval.go | 70 ++++++++++++++++++++++------------------- eval/stack.go | 15 ++++++--- extensions/extension.go | 49 ++++++++++++++++------------- main_test.txtar | 8 ++++- object/object.go | 22 ++++++++++--- 5 files changed, 99 insertions(+), 65 deletions(-) diff --git a/eval/eval.go b/eval/eval.go index 614f07b7..f3354f91 100644 --- a/eval/eval.go +++ b/eval/eval.go @@ -37,30 +37,30 @@ func (s *State) evalAssignment(right object.Object, node *ast.InfixExpression) o log.LogVf("eval assign %#v to %#v", right, id.Value()) return s.env.Set(id.Literal(), right) // Propagate possible error (constant setting). default: - return s.Error("assignment to non identifier: " + node.Left.Value().DebugString()) + return s.NewError("assignment to non identifier: " + node.Left.Value().DebugString()) } } func (s *State) evalIndexAssigment(which ast.Node, index, value object.Object) object.Object { if which.Value().Type() != token.IDENT { - return s.Error("index assignment to non identifier: " + which.Value().DebugString()) + return s.NewError("index assignment to non identifier: " + which.Value().DebugString()) } id, _ := which.(*ast.Identifier) val, ok := s.env.Get(id.Literal()) if !ok { - return s.Error("identifier not found: " + id.Literal()) + return s.NewError("identifier not found: " + id.Literal()) } switch val.Type() { case object.ARRAY: if index.Type() != object.INTEGER { - return s.Error("index assignment to array with non integer index: " + index.Inspect()) + return s.NewError("index assignment to array with non integer index: " + index.Inspect()) } idx := index.(object.Integer).Value if idx < 0 { idx = int64(object.Len(val)) + idx } if idx < 0 || idx >= int64(object.Len(val)) { - return s.Error("index assignment out of bounds: " + index.Inspect()) + return s.NewError("index assignment out of bounds: " + index.Inspect()) } elements := object.Elements(val) elements[idx] = value @@ -102,12 +102,12 @@ func (s *State) evalPrefixIncrDecr(operator token.Type, node ast.Node) object.Ob log.LogVf("eval prefix %s", ast.DebugString(node)) nv := node.Value() if nv.Type() != token.IDENT { - return s.Error("can't increment/decrement " + nv.DebugString()) + return s.NewError("can't increment/decrement " + nv.DebugString()) } id := nv.Literal() val, ok := s.env.Get(id) if !ok { - return s.Error("identifier not found: " + id) + return s.NewError("identifier not found: " + id) } toAdd := int64(1) if operator == token.DECR { @@ -119,7 +119,7 @@ func (s *State) evalPrefixIncrDecr(operator token.Type, node ast.Node) object.Ob case object.Float: return s.env.Set(id, object.Float{Value: val.Value + float64(toAdd)}) // So PI++ fails not silently. default: - return s.Error("can't increment/decrement " + val.Type().String()) + return s.NewError("can't increment/decrement " + val.Type().String()) } } @@ -128,7 +128,7 @@ func (s *State) evalPostfixExpression(node *ast.PostfixExpression) object.Object id := node.Prev.Literal() val, ok := s.env.Get(id) if !ok { - return s.Error("identifier not found: " + id) + return s.NewError("identifier not found: " + id) } var toAdd int64 switch node.Type() { @@ -137,7 +137,7 @@ func (s *State) evalPostfixExpression(node *ast.PostfixExpression) object.Object case token.DECR: toAdd = -1 default: - return s.Error("unknown postfix operator: " + node.Type().String()) + return s.NewError("unknown postfix operator: " + node.Type().String()) } var oerr object.Object switch val := val.(type) { @@ -146,7 +146,7 @@ func (s *State) evalPostfixExpression(node *ast.PostfixExpression) object.Object case object.Float: oerr = s.env.Set(id, object.Float{Value: val.Value + float64(toAdd)}) // So PI++ fails not silently. default: - return s.Error("can't increment/decrement " + val.Type().String()) + return s.NewError("can't increment/decrement " + val.Type().String()) } if oerr.Type() == object.ERROR { return oerr @@ -295,7 +295,7 @@ func (s *State) evalMapLiteral(node *ast.MapLiteral) object.Object { key := s.evalInternal(keyNode) if !object.Equals(key, key) { log.Warnf("key %s is not hashable", key.Inspect()) - return s.Error("key " + key.Inspect() + " is not hashable") + return s.NewError("key " + key.Inspect() + " is not hashable") } value := s.evalInternal(valueNode) result = result.Set(key, value) @@ -314,6 +314,10 @@ func (s *State) evalPrintLogError(node *ast.Builtin) object.Object { buf.WriteString(" ") } r := s.evalInternal(v) + // If what we print/println is an error, return it instead. log can log errors. + if r.Type() == object.ERROR && !doLog { + return r + } if isString := r.Type() == object.STRING; isString { buf.WriteString(r.(object.String).Value) } else { @@ -321,7 +325,7 @@ func (s *State) evalPrintLogError(node *ast.Builtin) object.Object { } } if node.Type() == token.ERROR { - return s.Error(buf.String()) + return s.NewError(buf.String()) } if (s.NoLog && doLog) || node.Type() == token.PRINTLN { buf.WriteRune('\n') // log() has a implicit newline when using log.Xxx, print() doesn't, println() does. @@ -362,7 +366,7 @@ func (s *State) evalBuiltin(node *ast.Builtin) object.Object { if minV > 0 { val = s.evalInternal(node.Parameters[0]) rt = val.Type() - if rt == object.ERROR { + if rt == object.ERROR && t != token.LOG { // log can log (and thus catch) errors. return val } } @@ -376,7 +380,7 @@ func (s *State) evalBuiltin(node *ast.Builtin) object.Object { case token.LEN: l := object.Len(val) if l == -1 { - return s.Error("len: not supported on " + val.Type().String()) + return s.NewError("len: not supported on " + val.Type().String()) } return object.Integer{Value: int64(l)} default: @@ -395,7 +399,7 @@ func (s *State) evalIndexRangeExpression(left object.Object, leftIdx, rightIdx a log.Debugf("eval index %s[%s:%s]", left.Inspect(), leftIndex.Inspect(), rightIndex.Inspect()) } if leftIndex.Type() != object.INTEGER || (!nilRight && rightIndex.Type() != object.INTEGER) { - return s.Error("range index not integer") + return s.NewError("range index not integer") } num := object.Len(left) l := leftIndex.(object.Integer).Value @@ -412,7 +416,7 @@ func (s *State) evalIndexRangeExpression(left object.Object, leftIdx, rightIdx a } } if l > r { - return s.Error("range index invalid: left greater then right") + return s.NewError("range index invalid: left greater then right") } l = min(l, int64(num)) r = min(r, int64(num)) @@ -427,7 +431,7 @@ func (s *State) evalIndexRangeExpression(left object.Object, leftIdx, rightIdx a case left.Type() == object.NIL: return object.NULL default: - return s.Error("range index operator not supported: " + left.Type().String()) + return s.NewError("range index operator not supported: " + left.Type().String()) } } @@ -455,7 +459,7 @@ func (s *State) evalIndexExpression(left, index object.Object) object.Object { case left.Type() == object.NIL: return object.NULL default: - return s.Error("index operator not supported: " + left.Type().String() + "[" + index.Type().String() + "]") + return s.NewError("index operator not supported: " + left.Type().String() + "[" + index.Type().String() + "]") } } @@ -525,7 +529,7 @@ func (s *State) applyExtension(fn object.Extension, args []object.Object) object func (s *State) applyFunction(name string, fn object.Object, args []object.Object) object.Object { function, ok := fn.(object.Function) if !ok { - return s.Error("not a function: " + fn.Type().String() + ":" + fn.Inspect()) + return s.NewError("not a function: " + fn.Type().String() + ":" + fn.Inspect()) } if v, output, ok := s.cache.Get(function.CacheKey, args); ok { log.Debugf("Cache hit for %s %v", function.CacheKey, args) @@ -646,7 +650,7 @@ func (s *State) evalIdentifier(node *ast.Identifier) object.Object { } val, ok = s.Extensions[node.Literal()] if !ok { - return s.Error("identifier not found: " + node.Literal()) + return s.NewError("identifier not found: " + node.Literal()) } return val } @@ -661,7 +665,7 @@ func (s *State) evalIfExpression(ie *ast.IfExpression) object.Object { log.LogVf("if %s is object.FALSE, picking else branch", ie.Condition.Value().DebugString()) return s.evalInternal(ie.Alternative) default: - return s.Error("condition is not a boolean: " + condition.Inspect()) + return s.NewError("condition is not a boolean: " + condition.Inspect()) } } @@ -700,12 +704,12 @@ func (s *State) evalPrefixExpression(operator token.Type, right object.Object) o if right.Type() == object.INTEGER { return object.Integer{Value: ^right.(object.Integer).Value} } - return s.Error("bitwise not of " + right.Inspect()) + return s.NewError("bitwise not of " + right.Inspect()) case token.PLUS: // nothing do with unary plus, just return the value. return right default: - return s.Error("unknown operator: " + operator.String()) + return s.NewError("unknown operator: " + operator.String()) } } @@ -718,7 +722,7 @@ func (s *State) evalBangOperatorExpression(right object.Object) object.Object { case object.NULL: return object.TRUE // allow !nil == true default: - return s.Error("not of " + right.Inspect()) + return s.NewError("not of " + right.Inspect()) } } @@ -731,7 +735,7 @@ func (s *State) evalMinusPrefixOperatorExpression(right object.Object) object.Ob value := right.(object.Float).Value return object.Float{Value: -value} default: - return s.Error("minus of " + right.Inspect()) + return s.NewError("minus of " + right.Inspect()) } } @@ -765,7 +769,7 @@ func (s *State) evalInfixExpression(operator token.Type, left, right object.Obje case left.Type() == object.MAP && right.Type() == object.MAP: return evalMapInfixExpression(operator, left, right) default: - return s.Error("no " + operator.String() + " on left=" + left.Inspect() + " right=" + right.Inspect()) + return s.NewError("no " + operator.String() + " on left=" + left.Inspect() + " right=" + right.Inspect()) } } @@ -791,12 +795,12 @@ func (s *State) evalArrayInfixExpression(operator token.Type, left, right object switch operator { case token.ASTERISK: // repeat if right.Type() != object.INTEGER { - return s.Error("right operand of * on arrays must be an integer") + return s.NewError("right operand of * on arrays must be an integer") } // TODO: go1.23 use slices.Repeat rightVal := right.(object.Integer).Value if rightVal < 0 { - return s.Error("right operand of * on arrays must be a positive integer") + return s.NewError("right operand of * on arrays must be a positive integer") } result := object.MakeObjectSlice(len(leftVal) * int(rightVal)) for range rightVal { @@ -860,7 +864,7 @@ func (s *State) evalIntegerInfixExpression(operator token.Type, left, right obje case token.COLON: lg := rightVal - leftVal if lg < 0 { - return s.Error("range index invalid: left greater then right") + return s.NewError("range index invalid: left greater then right") } arr := object.MakeObjectSlice(int(lg)) for i := leftVal; i < rightVal; i++ { @@ -868,7 +872,7 @@ func (s *State) evalIntegerInfixExpression(operator token.Type, left, right obje } return object.NewArray(arr) default: - return s.Error("unknown operator: " + operator.String()) + return s.NewError("unknown operator: " + operator.String()) } } @@ -879,7 +883,7 @@ func (s *State) getFloatValue(o object.Object) (float64, *object.Error) { case object.FLOAT: return o.(object.Float).Value, nil default: - e := s.Error("not converting to float: " + o.Type().String()) + e := s.NewError("not converting to float: " + o.Type().String()) return math.NaN(), &e } } @@ -906,6 +910,6 @@ func (s *State) evalFloatInfixExpression(operator token.Type, left, right object case token.PERCENT: return object.Float{Value: math.Mod(leftVal, rightVal)} default: - return s.Error("unknown operator: " + operator.String()) + return s.NewError("unknown operator: " + operator.String()) } } diff --git a/eval/stack.go b/eval/stack.go index e793e186..70c54fb8 100644 --- a/eval/stack.go +++ b/eval/stack.go @@ -23,14 +23,21 @@ func (s *State) Stack() []string { return stack } -// Creates a new error object with the given message and stack. -func (s *State) Error(msg string) object.Error { - if log.LogDebug() { +// NewError creates a new error object from a plain string. +// NewError will attach the stack trace to the Error object. +func (s *State) NewError(msg string) object.Error { + if log.LogVerbose() { log.LogVf("Error %q called", msg) } return object.Error{Value: msg, Stack: s.Stack()} } +// Errorf formats and create an object.Error using given format and args. func (s *State) Errorf(format string, args ...interface{}) object.Error { - return s.Error(fmt.Sprintf(format, args...)) + return s.NewError(fmt.Sprintf(format, args...)) +} + +// Error converts from a go error to an object.Error. +func (s *State) Error(err error) object.Error { + return s.NewError(err.Error()) } diff --git a/extensions/extension.go b/extensions/extension.go index 80ce94b0..21466fbb 100644 --- a/extensions/extension.go +++ b/extensions/extension.go @@ -3,6 +3,7 @@ package extensions import ( + "bytes" "encoding/json" "fmt" "io" @@ -143,7 +144,7 @@ func initInternal(c *Config) error { //nolint:funlen,gocognit,gocyclo,maintidx / MinArgs: 1, MaxArgs: 1, ArgTypes: []object.Type{object.ANY}, - Callback: object.ShortCallback(jsonSer), + Callback: jsonSer, } err = object.CreateFunction(jsonFn) if err != nil { @@ -154,7 +155,7 @@ func initInternal(c *Config) error { //nolint:funlen,gocognit,gocyclo,maintidx / MinArgs: 1, MaxArgs: 2, ArgTypes: []object.Type{object.ANY, object.STRING}, - Callback: object.ShortCallback(jsonSerGo), + Callback: jsonSerGo, Help: `optional indent e.g json_go(m, " ")`, } err = object.CreateFunction(jsonFn) @@ -346,7 +347,7 @@ func initInternal(c *Config) error { //nolint:funlen,gocognit,gocyclo,maintidx / case object.STRING: i, serr := strconv.ParseInt(o.(object.String).Value, 0, 64) if serr != nil { - return s.Error(serr.Error()) + return s.Error(serr) } return object.Integer{Value: i} default: @@ -375,28 +376,32 @@ func sprintf(args []object.Object) object.Object { return object.String{Value: res} } -func jsonSer(args []object.Object) object.Object { +func jsonSer(env any, _ string, args []object.Object) object.Object { + s := env.(*eval.State) w := strings.Builder{} err := args[0].JSON(&w) if err != nil { - return object.Error{Value: err.Error()} + return s.Error(err) } return object.String{Value: w.String()} } -func jsonSerGo(args []object.Object) object.Object { +func jsonSerGo(env any, _ string, args []object.Object) object.Object { + s := env.(*eval.State) v := args[0].Unwrap(true) var err error - var bytes []byte - if len(args) == 1 { - bytes, err = json.Marshal(v) - } else { - bytes, err = json.MarshalIndent(v, "", args[1].(object.String).Value) - } + var buf bytes.Buffer + encoder := json.NewEncoder(&buf) + if len(args) == 2 { + encoder.SetIndent("", args[1].(object.String).Value) + } + // Disable HTML escaping + encoder.SetEscapeHTML(false) + err = encoder.Encode(v) if err != nil { - return object.Error{Value: err.Error()} + return s.Error(err) } - return object.String{Value: string(bytes)} + return object.String{Value: buf.String()} } func evalFunc(env any, name string, args []object.Object) object.Object { @@ -404,7 +409,7 @@ func evalFunc(env any, name string, args []object.Object) object.Object { s := env.(*eval.State) res, err := eval.EvalString(s, str, name == "unjson" /* empty env */) if err != nil { - return s.Error(err.Error()) + return s.Error(err) } return res } @@ -436,17 +441,17 @@ func saveFunc(env any, _ string, args []object.Object) object.Object { s := env.(*eval.State) file, err := sanitizeFileName(args) if err != nil { - return s.Error(err.Error()) + return s.Error(err) } f, err := os.Create(file) if err != nil { - return s.Error(err.Error()) + return s.Error(err) } defer f.Close() // Write to file. n, err := s.SaveGlobals(f) if err != nil { - return s.Error(err.Error()) + return s.Error(err) } log.Infof("Saved %d ids/fns to: %s", n, file) return object.MakeQuad( @@ -458,21 +463,21 @@ func loadFunc(env any, _ string, args []object.Object) object.Object { file, err := sanitizeFileName(args) s := env.(*eval.State) if err != nil { - return s.Error(err.Error()) + return s.Error(err) } f, err := os.Open(file) if err != nil { - return s.Error(err.Error()) + return s.Error(err) } defer f.Close() all, err := io.ReadAll(f) if err != nil { - return s.Error(err.Error()) + return s.Error(err) } // Eval the content. res, err := eval.EvalString(env, string(all), false) if err != nil { - return s.Error(err.Error()) + return s.Error(err) } log.Infof("Read/evaluated: %s", file) return res diff --git a/main_test.txtar b/main_test.txtar index 0ef9eb68..5d2207be 100644 --- a/main_test.txtar +++ b/main_test.txtar @@ -240,7 +240,7 @@ grol -quiet -c 'println(json_go({"abc":42, 63:63, "x": {[3]:122, true:false} })) stdout '^{"63":63,"abc":42,"x":{"\[3\]":122,"true":false}}$' # pretty print variant -grol -quiet -c 'println(json_go({"abc":42, 63:63, "x": {[3]:122, true:false} }, " "))' +grol -quiet -c 'print(json_go({"abc":42, 63:63, "x": {[3]:122, true:false} }, " "))' cmp stdout json_output # returning a map in lambda shouldn't lose needed extra {} despite being solo argument @@ -267,6 +267,12 @@ stderr '^func level3\(\){level2\(\)}$' stderr 'Total 1 error' stderr '^$' +# json with some functions +grol -quiet -c 'f=()=>{{"k":"v"}}; println(json_go(f))' +!stderr 'json: unsupported type' +!stdout 'json: unsupported type' +stdout '^"\(\)=>{{\\"k\\":\\"v\\"}}"$' + -- json_output -- { "63": 63, diff --git a/object/object.go b/object/object.go index f6730290..936cd94c 100644 --- a/object/object.go +++ b/object/object.go @@ -2,6 +2,7 @@ package object import ( "cmp" + "errors" "fmt" "io" "slices" @@ -478,9 +479,15 @@ func (e Error) JSON(w io.Writer) error { _, err := fmt.Fprintf(w, `{"err":%q}`, e.Value) return err } -func (e Error) Unwrap(_ bool) any { return e } -func (e Error) Error() string { return e.Value } -func (e Error) Type() Type { return ERROR } + +func (e Error) Unwrap(forceStringKeys bool) any { + if forceStringKeys { + return e.Value + } + return errors.New(e.Value) +} +func (e Error) Error() string { return e.Value } +func (e Error) Type() Type { return ERROR } func (e Error) Inspect() string { if len(e.Stack) == 0 { return fmt.Sprintf("", e.Value) @@ -534,8 +541,13 @@ func WriteStrings(out *strings.Builder, list []Object, before, sep, after string out.WriteString(after) } -func (f Function) Unwrap(_ bool) any { return f } -func (f Function) Type() Type { return FUNC } +func (f Function) Unwrap(forceStringKeys bool) any { + if forceStringKeys { + return f.Inspect() + } + return f +} +func (f Function) Type() Type { return FUNC } // Must be called after the function is fully initialized. // Whether a function result should be cached doesn't depend on the Name,