From 33b6dcfa531556502e226c24090622fb3016ff68 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stanis=C5=82aw=20Barzowski?= Date: Tue, 7 May 2019 22:43:55 +0200 Subject: [PATCH] Keep object locals only once in AST (#263) Keep object locals only once in AST For example this reduces the size of stdlib ast file roughly 3x. Note that this change doesn't regenerate the stdlib, so that the diff here is sane. It is likely to slightly improve performance of code using a lot of locals (~10% on bench.05.gen.jsonnet). The desugaring is more strightforward now, and we're back to desugaring each node exactly once. --- ast/ast.go | 1 + builtins.go | 24 +- desugarer.go | 251 ++++++++---------- interpreter.go | 18 +- static_analyzer.go | 39 ++- testdata/local_in_object_assertion.golden | 1 + testdata/local_in_object_assertion.jsonnet | 1 + testdata/local_within_object.golden | 1 + testdata/local_within_object.jsonnet | 1 + testdata/object_local_from_parent.golden | 5 + testdata/object_local_from_parent.jsonnet | 7 + ...ect_local_from_parent_through_local.golden | 5 + ...ct_local_from_parent_through_local.jsonnet | 8 + testdata/object_local_recursive.golden | 22 ++ testdata/object_local_recursive.jsonnet | 7 + ...bject_local_uses_local_from_outside.golden | 3 + ...ject_local_uses_local_from_outside.jsonnet | 1 + value.go | 56 +++- 18 files changed, 280 insertions(+), 171 deletions(-) create mode 100644 testdata/local_in_object_assertion.golden create mode 100644 testdata/local_in_object_assertion.jsonnet create mode 100644 testdata/local_within_object.golden create mode 100644 testdata/local_within_object.jsonnet create mode 100644 testdata/object_local_from_parent.golden create mode 100644 testdata/object_local_from_parent.jsonnet create mode 100644 testdata/object_local_from_parent_through_local.golden create mode 100644 testdata/object_local_from_parent_through_local.jsonnet create mode 100644 testdata/object_local_recursive.golden create mode 100644 testdata/object_local_recursive.jsonnet create mode 100644 testdata/object_local_uses_local_from_outside.golden create mode 100644 testdata/object_local_uses_local_from_outside.jsonnet diff --git a/ast/ast.go b/ast/ast.go index 77c9c06c..c42626ff 100644 --- a/ast/ast.go +++ b/ast/ast.go @@ -549,6 +549,7 @@ type DesugaredObject struct { NodeBase Asserts Nodes Fields DesugaredObjectFields + Locals LocalBinds } // --------------------------------------------------------------------------- diff --git a/builtins.go b/builtins.go index 73f6899b..cc666961 100644 --- a/builtins.go +++ b/builtins.go @@ -810,6 +810,9 @@ func builtinStrReplace(i *interpreter, trace TraceElement, strv, fromv, tov valu } func builtinUglyObjectFlatMerge(i *interpreter, trace TraceElement, x value) (value, error) { + // TODO(sbarzowski) consider keeping comprehensions in AST + // It will probably be way less hacky, with better error messages and better performance + objarr, err := i.getArray(x, trace) if err != nil { return nil, err @@ -818,6 +821,8 @@ func builtinUglyObjectFlatMerge(i *interpreter, trace TraceElement, x value) (va return &valueSimpleObject{}, nil } newFields := make(simpleObjectFieldMap) + var locals []objectLocal + var upValues bindingFrame for _, elem := range objarr.elements { obj, err := i.evaluateObject(elem, trace) if err != nil { @@ -825,10 +830,22 @@ func builtinUglyObjectFlatMerge(i *interpreter, trace TraceElement, x value) (va } // starts getting ugly - we mess with object internals simpleObj := obj.(*valueSimpleObject) + // there is only one field, really for fieldName, fieldVal := range simpleObj.fields { if _, alreadyExists := newFields[fieldName]; alreadyExists { return nil, i.Error(duplicateFieldNameErrMsg(fieldName), trace) } + + // Here is the tricky part. Each field in a comprehension has different + // upValues, because for example in {[v]: v for v in ["x", "y", "z"] }, + // the v is different for each field. + // Yet, even though upValues are field-specific, they are shadowed by object locals, + // so we need to make holes to let them pass through + upValues := simpleObj.upValues + for _, l := range simpleObj.locals { + delete(upValues, l.name) + } + newFields[fieldName] = simpleObjectField{ hide: fieldVal.hide, field: &bindingsUnboundField{ @@ -836,12 +853,17 @@ func builtinUglyObjectFlatMerge(i *interpreter, trace TraceElement, x value) (va bindings: simpleObj.upValues, }, } + // another ugliness - we just take the locals of our last object, + // we assume that the locals are the same for each of merged objects + locals = simpleObj.locals } } + return makeValueSimpleObject( - nil, // no binding frame + upValues, newFields, []unboundField{}, // No asserts allowed + locals, ), nil } diff --git a/desugarer.go b/desugarer.go index c11733ed..f602353a 100644 --- a/desugarer.go +++ b/desugarer.go @@ -90,27 +90,7 @@ func stringUnescape(loc *ast.LocationRange, s string) (string, error) { return buf.String(), nil } -func desugarFields(location ast.LocationRange, fields *ast.ObjectFields, objLevel int) error { - // Simplify asserts - for i := range *fields { - field := &(*fields)[i] - if field.Kind != ast.ObjectAssert { - continue - } - msg := field.Expr3 - if msg == nil { - msg = buildLiteralString("Object assertion failed.") - } - field.Expr3 = nil - onFailure := &ast.Error{Expr: msg} - assertion := &ast.Conditional{ - Cond: field.Expr2, - BranchTrue: &ast.LiteralBoolean{Value: true}, // ignored anyway - BranchFalse: onFailure, - } - field.Expr2 = assertion - } - +func desugarFields(nodeBase ast.NodeBase, fields *ast.ObjectFields, objLevel int) (*ast.DesugaredObject, error) { for i := range *fields { field := &((*fields)[i]) if field.Method == nil { @@ -121,54 +101,88 @@ func desugarFields(location ast.LocationRange, fields *ast.ObjectFields, objLeve // Body of the function already desugared through expr2 } - // Remove object-level locals - newFields := []ast.ObjectField{} - for _, field := range *fields { - if field.Kind == ast.ObjectLocal { - continue - } - var binds ast.LocalBinds - for _, local := range *fields { - if local.Kind != ast.ObjectLocal { - continue - } - binds = append(binds, ast.LocalBind{Variable: *local.Id, Body: ast.Clone(local.Expr2)}) - } - if len(binds) > 0 { - field.Expr2 = &ast.Local{ - NodeBase: ast.NewNodeBaseLoc(*field.Expr2.Loc()), - Binds: binds, - Body: field.Expr2, - } - } - newFields = append(newFields, field) - } - *fields = newFields + asserts := ast.Nodes{} + locals := ast.LocalBinds{} + desugaredFields := ast.DesugaredObjectFields{} - // Change all to FIELD_EXPR for i := range *fields { field := &(*fields)[i] switch field.Kind { case ast.ObjectAssert: - // Nothing to do. - + msg := field.Expr3 + if msg == nil { + msg = buildLiteralString("Object assertion failed.") + } + onFailure := &ast.Error{Expr: msg} + asserts = append(asserts, &ast.Conditional{ + Cond: field.Expr2, + BranchTrue: &ast.LiteralBoolean{Value: true}, // ignored anyway + BranchFalse: onFailure, + }) case ast.ObjectFieldID: - field.Expr1 = makeStr(string(*field.Id)) - field.Kind = ast.ObjectFieldExpr - - case ast.ObjectFieldExpr: - // Nothing to do. + desugaredFields = append(desugaredFields, ast.DesugaredObjectField{ + Hide: field.Hide, + Name: makeStr(string(*field.Id)), + Body: field.Expr2, + PlusSuper: field.SuperSugar, + }) - case ast.ObjectFieldStr: - // Just set the flag. - field.Kind = ast.ObjectFieldExpr + case ast.ObjectFieldExpr, ast.ObjectFieldStr: + desugaredFields = append(desugaredFields, ast.DesugaredObjectField{ + Hide: field.Hide, + Name: field.Expr1, + Body: field.Expr2, + PlusSuper: field.SuperSugar, + }) case ast.ObjectLocal: - return fmt.Errorf("INTERNAL ERROR: Locals should be removed by now") + locals = append(locals, ast.LocalBind{ + Variable: *field.Id, + Body: ast.Clone(field.Expr2), // TODO(sbarzowski) not sure if clone is needed + }) + default: + panic(fmt.Sprintf("Unexpected object field kind %v", field.Kind)) } + } - return nil + // Hidden variable to allow $ binding. + if objLevel == 0 { + locals = append(locals, ast.LocalBind{ + Variable: ast.Identifier("$"), + Body: &ast.Self{}, + }) + } + + // Desugar stuff inside + for i := range asserts { + assert := &(asserts[i]) + err := desugar(assert, objLevel+1) + if err != nil { + return nil, err + } + } + desugarLocalBinds(locals, objLevel+1) + for i := range desugaredFields { + field := &(desugaredFields[i]) + if field.Name != nil { + err := desugar(&field.Name, objLevel) + if err != nil { + return nil, err + } + } + err := desugar(&field.Body, objLevel+1) + if err != nil { + return nil, err + } + } + + return &ast.DesugaredObject{ + NodeBase: nodeBase, + Asserts: asserts, + Locals: locals, + Fields: desugaredFields, + }, nil } func simpleLambda(body ast.Node, paramName ast.Identifier) ast.Node { @@ -182,13 +196,18 @@ func buildAnd(left ast.Node, right ast.Node) ast.Node { return &ast.Binary{Op: ast.BopAnd, Left: left, Right: right} } -func desugarForSpec(inside ast.Node, forSpec *ast.ForSpec) (ast.Node, error) { +// inside is assumed to be already desugared (and cannot be desugared again) +func desugarForSpec(inside ast.Node, forSpec *ast.ForSpec, objLevel int) (ast.Node, error) { var body ast.Node if len(forSpec.Conditions) > 0 { cond := forSpec.Conditions[0].Expr for i := 1; i < len(forSpec.Conditions); i++ { cond = buildAnd(cond, forSpec.Conditions[i].Expr) } + err := desugar(&cond, objLevel) + if err != nil { + return nil, err + } body = &ast.Conditional{ Cond: cond, BranchTrue: inside, @@ -198,11 +217,15 @@ func desugarForSpec(inside ast.Node, forSpec *ast.ForSpec) (ast.Node, error) { body = inside } function := simpleLambda(body, forSpec.VarName) + err := desugar(&forSpec.Expr, objLevel) + if err != nil { + return nil, err + } current := buildStdCall("flatMap", function, forSpec.Expr) if forSpec.Outer == nil { return current, nil } - return desugarForSpec(current, forSpec.Outer) + return desugarForSpec(current, forSpec.Outer, objLevel) } func wrapInArray(inside ast.Node) ast.Node { @@ -210,27 +233,21 @@ func wrapInArray(inside ast.Node) ast.Node { } func desugarArrayComp(comp *ast.ArrayComp, objLevel int) (ast.Node, error) { - return desugarForSpec(wrapInArray(comp.Body), &comp.Spec) + return desugarForSpec(wrapInArray(comp.Body), &comp.Spec, objLevel) } func desugarObjectComp(comp *ast.ObjectComp, objLevel int) (ast.Node, error) { - - if objLevel == 0 { - dollar := ast.Identifier("$") - comp.Fields = append(comp.Fields, ast.ObjectFieldLocalNoMethod(&dollar, &ast.Self{})) - } - - err := desugarFields(*comp.Loc(), &comp.Fields, objLevel+1) + obj, err := desugarFields(comp.NodeBase, &comp.Fields, objLevel) if err != nil { return nil, err } - if len(comp.Fields) != 1 { + if len(obj.Fields) != 1 { panic("Too many fields in object comprehension, it should have been caught during parsing") } arrComp := ast.ArrayComp{ - Body: buildDesugaredObject(comp.NodeBase, comp.Fields), + Body: obj, Spec: comp.Spec, } @@ -253,7 +270,7 @@ func buildLiteralString(value string) ast.Node { func buildSimpleIndex(obj ast.Node, member ast.Identifier) ast.Node { return &ast.Index{ Target: obj, - Id: &member, + Index: buildLiteralString(string(member)), } } @@ -266,30 +283,21 @@ func buildStdCall(builtinName ast.Identifier, args ...ast.Node) ast.Node { } } -func buildDesugaredObject(nodeBase ast.NodeBase, fields ast.ObjectFields) *ast.DesugaredObject { - var newFields ast.DesugaredObjectFields - var newAsserts ast.Nodes - - for _, field := range fields { - if field.Kind == ast.ObjectAssert { - newAsserts = append(newAsserts, field.Expr2) - } else if field.Kind == ast.ObjectFieldExpr { - newFields = append(newFields, ast.DesugaredObjectField{ - Hide: field.Hide, - Name: field.Expr1, - Body: field.Expr2, - PlusSuper: field.SuperSugar, - }) - } else { - panic(fmt.Sprintf("INTERNAL ERROR: field should have been desugared: %v", field.Kind)) +func desugarLocalBinds(binds ast.LocalBinds, objLevel int) (err error) { + for i := range binds { + if binds[i].Fun != nil { + binds[i] = ast.LocalBind{ + Variable: binds[i].Variable, + Body: binds[i].Fun, + Fun: nil, + } + } + err = desugar(&binds[i].Body, objLevel) + if err != nil { + return } } - - return &ast.DesugaredObject{ - NodeBase: nodeBase, - Asserts: newAsserts, - Fields: newFields, - } + return nil } // Desugar Jsonnet expressions to reduce the number of constructs the rest of the implementation @@ -493,18 +501,9 @@ func desugar(astPtr *ast.Node, objLevel int) (err error) { } case *ast.Local: - for i := range node.Binds { - if node.Binds[i].Fun != nil { - node.Binds[i] = ast.LocalBind{ - Variable: node.Binds[i].Variable, - Body: node.Binds[i].Fun, - Fun: nil, - } - } - err = desugar(&node.Binds[i].Body, objLevel) - if err != nil { - return - } + err = desugarLocalBinds(node.Binds, objLevel) + if err != nil { + return } err = desugar(&node.Body, objLevel) if err != nil { @@ -531,55 +530,21 @@ func desugar(astPtr *ast.Node, objLevel int) (err error) { node.Kind = ast.StringDouble node.BlockIndent = "" case *ast.Object: - // Hidden variable to allow $ binding. - if objLevel == 0 { - dollar := ast.Identifier("$") - node.Fields = append(node.Fields, ast.ObjectFieldLocalNoMethod(&dollar, &ast.Self{})) - } - - err = desugarFields(*node.Loc(), &node.Fields, objLevel) - if err != nil { - return - } - - *astPtr = buildDesugaredObject(node.NodeBase, node.Fields) - err = desugar(astPtr, objLevel) + *astPtr, err = desugarFields(node.NodeBase, &node.Fields, objLevel) if err != nil { return } case *ast.DesugaredObject: - for i := range node.Fields { - field := &((node.Fields)[i]) - if field.Name != nil { - err := desugar(&field.Name, objLevel) - if err != nil { - return err - } - } - err := desugar(&field.Body, objLevel+1) - if err != nil { - return err - } - } - for i := range node.Asserts { - assert := &((node.Asserts)[i]) - err := desugar(assert, objLevel+1) - if err != nil { - return err - } - } + // Desugaring something multiple times is a bad idea. + // All functions here should desugar nodes in one go. + panic("Desugaring desugared object") case *ast.ObjectComp: - comp, err := desugarObjectComp(node, objLevel) - if err != nil { - return err - } - err = desugar(&comp, objLevel) + *astPtr, err = desugarObjectComp(node, objLevel) if err != nil { return err } - *astPtr = comp case *ast.Parens: *astPtr = node.Inner diff --git a/interpreter.go b/interpreter.go index 599b7f84..86ac23cb 100644 --- a/interpreter.go +++ b/interpreter.go @@ -203,6 +203,14 @@ func (s *callStack) getCurrentEnv(ast ast.Node) environment { ) } +func (s *callStack) getTopEnv() environment { + top := s.stack[len(s.stack)-1] + if !top.isCall { + panic("getTopEnv is allowed only for artifical nodes which are called in new environment") + } + return top.env +} + // Build a binding frame containing specified variables. func (s *callStack) capture(freeVars ast.Identifiers) bindingFrame { env := make(bindingFrame) @@ -392,8 +400,12 @@ func (i *interpreter) evaluate(a ast.Node, tc tailCallStatus) (value, error) { for _, assert := range node.Asserts { asserts = append(asserts, &codeUnboundField{assert}) } + var locals []objectLocal + for _, local := range node.Locals { + locals = append(locals, objectLocal{name: local.Variable, node: local.Body}) + } upValues := i.stack.capture(node.FreeVariables()) - return makeValueSimpleObject(upValues, fields, asserts), nil + return makeValueSimpleObject(upValues, fields, asserts, locals), nil case *ast.Error: msgVal, err := i.evaluate(node.Expr, nonTailCall) @@ -561,7 +573,7 @@ func (i *interpreter) evaluate(a ast.Node, tc tailCallStatus) (value, error) { return i.evaluateTailCall(node.function, arguments, tc, trace) default: - return nil, i.Error(fmt.Sprintf("Executing this AST type not implemented: %v", reflect.TypeOf(a)), trace) + panic(fmt.Sprintf("Executing this AST type not implemented: %v", reflect.TypeOf(a))) } } @@ -1128,7 +1140,7 @@ func buildObject(hide ast.ObjectFieldHide, fields map[string]value) valueObject for name, v := range fields { fieldMap[name] = simpleObjectField{hide, &readyValue{v}} } - return makeValueSimpleObject(bindingFrame{}, fieldMap, nil) + return makeValueSimpleObject(bindingFrame{}, fieldMap, nil, nil) } func buildInterpreter(ext vmExtMap, nativeFuncs map[string]*NativeFunction, maxStack int, importer Importer) (*interpreter, error) { diff --git a/static_analyzer.go b/static_analyzer.go index 88d93779..d0e8b279 100644 --- a/static_analyzer.go +++ b/static_analyzer.go @@ -36,6 +36,18 @@ func visitNext(a ast.Node, inObject bool, vars ast.IdentifierSet, state *analysi state.freeVars.AddIdentifiers(a.FreeVariables()) } +func enterLocal(binds ast.LocalBinds, vars ast.IdentifierSet, inObject bool, s *analysisState) ast.IdentifierSet { + newVars := vars.Clone() + for _, bind := range binds { + newVars.Add(bind.Variable) + } + // Binds in local can be mutually or even self recursive + for _, bind := range binds { + visitNext(bind.Body, inObject, newVars, s) + } + return newVars +} + func analyzeVisit(a ast.Node, inObject bool, vars ast.IdentifierSet) error { s := &analysisState{freeVars: ast.NewIdentifierSet()} @@ -99,14 +111,7 @@ func analyzeVisit(a ast.Node, inObject bool, vars ast.IdentifierSet) error { visitNext(a.Target, inObject, vars, s) visitNext(a.Index, inObject, vars, s) case *ast.Local: - newVars := vars.Clone() - for _, bind := range a.Binds { - newVars.Add(bind.Variable) - } - // Binds in local can be mutually or even self recursive - for _, bind := range a.Binds { - visitNext(bind.Body, inObject, newVars, s) - } + newVars := enterLocal(a.Binds, vars, inObject, s) visitNext(a.Body, inObject, newVars, s) // Any usage of newly created variables inside are considered free @@ -123,14 +128,24 @@ func analyzeVisit(a ast.Node, inObject bool, vars ast.IdentifierSet) error { case *ast.LiteralString: //nothing to do here case *ast.DesugaredObject: + newVars := enterLocal(a.Locals, vars, true, s) for _, field := range a.Fields { - // Field names are calculated *outside* of the object - visitNext(field.Name, inObject, vars, s) - visitNext(field.Body, true, vars, s) + visitNext(field.Body, true, newVars, s) } for _, assert := range a.Asserts { - visitNext(assert, true, vars, s) + visitNext(assert, true, newVars, s) } + + // Object local vars are not free outside of object + for _, bind := range a.Locals { + s.freeVars.Remove(bind.Variable) + } + + // Field names are calculated *outside* of the object + for _, field := range a.Fields { + visitNext(field.Name, inObject, vars, s) + } + case *ast.Self: if !inObject { return parser.MakeStaticError("Can't use self outside of an object.", *a.Loc()) diff --git a/testdata/local_in_object_assertion.golden b/testdata/local_in_object_assertion.golden new file mode 100644 index 00000000..ffcd4415 --- /dev/null +++ b/testdata/local_in_object_assertion.golden @@ -0,0 +1 @@ +{ } diff --git a/testdata/local_in_object_assertion.jsonnet b/testdata/local_in_object_assertion.jsonnet new file mode 100644 index 00000000..911dfb63 --- /dev/null +++ b/testdata/local_in_object_assertion.jsonnet @@ -0,0 +1 @@ +{ local x = 42, assert x == 42 } \ No newline at end of file diff --git a/testdata/local_within_object.golden b/testdata/local_within_object.golden new file mode 100644 index 00000000..d81cc071 --- /dev/null +++ b/testdata/local_within_object.golden @@ -0,0 +1 @@ +42 diff --git a/testdata/local_within_object.jsonnet b/testdata/local_within_object.jsonnet new file mode 100644 index 00000000..d43f0b49 --- /dev/null +++ b/testdata/local_within_object.jsonnet @@ -0,0 +1 @@ +local a = 42; {x: a,}.x \ No newline at end of file diff --git a/testdata/object_local_from_parent.golden b/testdata/object_local_from_parent.golden new file mode 100644 index 00000000..732c711f --- /dev/null +++ b/testdata/object_local_from_parent.golden @@ -0,0 +1,5 @@ +{ + "f": { + "f": 42 + } +} diff --git a/testdata/object_local_from_parent.jsonnet b/testdata/object_local_from_parent.jsonnet new file mode 100644 index 00000000..d84d0658 --- /dev/null +++ b/testdata/object_local_from_parent.jsonnet @@ -0,0 +1,7 @@ +{ + local a = 42, + f: { + local b = a, + f: b + } +} \ No newline at end of file diff --git a/testdata/object_local_from_parent_through_local.golden b/testdata/object_local_from_parent_through_local.golden new file mode 100644 index 00000000..732c711f --- /dev/null +++ b/testdata/object_local_from_parent_through_local.golden @@ -0,0 +1,5 @@ +{ + "f": { + "f": 42 + } +} diff --git a/testdata/object_local_from_parent_through_local.jsonnet b/testdata/object_local_from_parent_through_local.jsonnet new file mode 100644 index 00000000..d392eb9d --- /dev/null +++ b/testdata/object_local_from_parent_through_local.jsonnet @@ -0,0 +1,8 @@ +{ + local a = 42, + local b = { + local c = a, + f: c + }, + f: b +} \ No newline at end of file diff --git a/testdata/object_local_recursive.golden b/testdata/object_local_recursive.golden new file mode 100644 index 00000000..6725ae82 --- /dev/null +++ b/testdata/object_local_recursive.golden @@ -0,0 +1,22 @@ +{ + "field": [ + [ + 0, + 0 + ], + [ + 0, + 0, + 1 + ], + [ + 0, + 0 + ], + [ + 0, + 0, + 1 + ] + ] +} diff --git a/testdata/object_local_recursive.jsonnet b/testdata/object_local_recursive.jsonnet new file mode 100644 index 00000000..cf3060ab --- /dev/null +++ b/testdata/object_local_recursive.jsonnet @@ -0,0 +1,7 @@ +{ + local a = [a[1], 0], + local b = d, + local c = a, + local d = a + [1], + field: [a, b, c, d] +} \ No newline at end of file diff --git a/testdata/object_local_uses_local_from_outside.golden b/testdata/object_local_uses_local_from_outside.golden new file mode 100644 index 00000000..44436124 --- /dev/null +++ b/testdata/object_local_uses_local_from_outside.golden @@ -0,0 +1,3 @@ +{ + "field": 42 +} diff --git a/testdata/object_local_uses_local_from_outside.jsonnet b/testdata/object_local_uses_local_from_outside.jsonnet new file mode 100644 index 00000000..f5ac3f56 --- /dev/null +++ b/testdata/object_local_uses_local_from_outside.jsonnet @@ -0,0 +1 @@ +local a = 42; { local x = a, field: x } \ No newline at end of file diff --git a/value.go b/value.go index 71cf4ad5..068f4a86 100644 --- a/value.go +++ b/value.go @@ -456,6 +456,12 @@ func (obj *valueObjectBase) getAssertionsCheckResult() error { return obj.assertionError } +type objectLocal struct { + name ast.Identifier + // Locals may depend on self and super so they are unbound fields and not simply thunks + node ast.Node +} + // valueSimpleObject represents a flat object (no inheritance). // Note that it can be used as part of extended objects // in inheritance using operator +. @@ -469,6 +475,7 @@ type valueSimpleObject struct { upValues bindingFrame fields simpleObjectFieldMap asserts []unboundField + locals []objectLocal } func checkAssertionsHelper(i *interpreter, trace TraceElement, obj valueObject, curr valueObject, superDepth int) error { @@ -486,7 +493,8 @@ func checkAssertionsHelper(i *interpreter, trace TraceElement, obj valueObject, case *valueSimpleObject: for _, assert := range curr.asserts { sb := selfBinding{self: obj, superDepth: superDepth} - _, err := assert.evaluate(i, trace, sb, curr.upValues, "") + fieldUpValues := prepareFieldUpvalues(sb, curr.upValues, curr.locals) + _, err := assert.evaluate(i, trace, sb, fieldUpValues, "") if err != nil { return err } @@ -516,11 +524,12 @@ func (*valueSimpleObject) inheritanceSize() int { return 1 } -func makeValueSimpleObject(b bindingFrame, fields simpleObjectFieldMap, asserts []unboundField) *valueSimpleObject { +func makeValueSimpleObject(b bindingFrame, fields simpleObjectFieldMap, asserts []unboundField, locals []objectLocal) *valueSimpleObject { return &valueSimpleObject{ upValues: b, fields: fields, asserts: asserts, + locals: locals, } } @@ -579,30 +588,51 @@ func makeValueExtendedObject(left, right valueObject) *valueExtendedObject { // findField returns a field in object curr, with superDepth at least minSuperDepth // It also returns an associated bindingFrame and actual superDepth that the field // was found at. -func findField(curr value, minSuperDepth int, f string) (bool, simpleObjectField, bindingFrame, int) { +func findField(curr value, minSuperDepth int, f string) (bool, simpleObjectField, bindingFrame, []objectLocal, int) { switch curr := curr.(type) { case *valueExtendedObject: if curr.right.inheritanceSize() > minSuperDepth { - found, field, frame, counter := findField(curr.right, minSuperDepth, f) + found, field, frame, locals, counter := findField(curr.right, minSuperDepth, f) if found { - return true, field, frame, counter + return true, field, frame, locals, counter } } - found, field, frame, counter := findField(curr.left, minSuperDepth-curr.right.inheritanceSize(), f) - return found, field, frame, counter + curr.right.inheritanceSize() + found, field, frame, locals, counter := findField(curr.left, minSuperDepth-curr.right.inheritanceSize(), f) + return found, field, frame, locals, counter + curr.right.inheritanceSize() case *valueSimpleObject: if minSuperDepth <= 0 { if field, ok := curr.fields[f]; ok { - return true, field, curr.upValues, 0 + return true, field, curr.upValues, curr.locals, 0 } } - return false, simpleObjectField{}, nil, 0 + return false, simpleObjectField{}, nil, nil, 0 default: panic(fmt.Sprintf("Unknown object type %#v", curr)) } } +func prepareFieldUpvalues(sb selfBinding, upValues bindingFrame, locals []objectLocal) bindingFrame { + newUpValues := make(bindingFrame) + for k, v := range upValues { + newUpValues[k] = v + } + localThunks := make([]*cachedThunk, 0, len(locals)) + for _, l := range locals { + th := &cachedThunk{ + // We will fill upValues later + env: &environment{upValues: nil, selfBinding: sb}, + body: l.node, + } + newUpValues[l.name] = th + localThunks = append(localThunks, th) + } + for _, th := range localThunks { + th.env.upValues = newUpValues + } + return newUpValues +} + func objectIndex(i *interpreter, trace TraceElement, sb selfBinding, fieldName string) (value, error) { err := checkAssertions(i, trace, sb.self) if err != nil { @@ -612,17 +642,19 @@ func objectIndex(i *interpreter, trace TraceElement, sb selfBinding, fieldName s return nil, i.Error("Attempt to use super when there is no super class.", trace) } - found, field, upValues, foundAt := findField(sb.self, sb.superDepth, fieldName) + found, field, upValues, locals, foundAt := findField(sb.self, sb.superDepth, fieldName) if !found { return nil, i.Error(fmt.Sprintf("Field does not exist: %s", fieldName), trace) } fieldSelfBinding := selfBinding{self: sb.self, superDepth: foundAt} - return field.field.evaluate(i, trace, fieldSelfBinding, upValues, fieldName) + fieldUpValues := prepareFieldUpvalues(sb, upValues, locals) + + return field.field.evaluate(i, trace, fieldSelfBinding, fieldUpValues, fieldName) } func objectHasField(sb selfBinding, fieldName string, h Hidden) bool { - found, field, _, _ := findField(sb.self, sb.superDepth, fieldName) + found, field, _, _, _ := findField(sb.self, sb.superDepth, fieldName) if !found || (h == withoutHidden && field.hide == ast.ObjectFieldHidden) { return false }