diff --git a/interp/cfg.go b/interp/cfg.go index 9de7db2f3..2f6935d52 100644 --- a/interp/cfg.go +++ b/interp/cfg.go @@ -572,7 +572,7 @@ func (interp *Interpreter) cfg(root *node, importPath, pkgName string) ([]*node, dest.gen = nop case isFuncField(dest): // Setting a struct field of function type requires an extra step. Do not optimize. - case isCall(src) && !isInterfaceSrc(dest.typ) && !isRecursiveField(dest) && n.kind != defineStmt: + case isCall(src) && !isInterfaceSrc(dest.typ) && n.kind != defineStmt: // Call action may perform the assignment directly. if dest.typ.id() != src.typ.id() { // Skip optimitization if returned type doesn't match assigned one. @@ -2403,20 +2403,6 @@ func isField(n *node) bool { return n.kind == selectorExpr && len(n.child) > 0 && n.child[0].typ != nil && isStruct(n.child[0].typ) } -func isRecursiveField(n *node) bool { - if !isField(n) { - return false - } - t := n.typ - for t != nil { - if t.recursive { - return true - } - t = t.val - } - return false -} - func isInConstOrTypeDecl(n *node) bool { anc := n.anc for anc != nil { diff --git a/interp/gta.go b/interp/gta.go index 036c09419..4e168475e 100644 --- a/interp/gta.go +++ b/interp/gta.go @@ -63,7 +63,7 @@ func (interp *Interpreter) gta(root *node, rpath, importPath, pkgName string) ([ } typ := atyp if typ == nil { - if typ, err = nodeType(interp, sc, src); err != nil { + if typ, err = nodeType(interp, sc, src); err != nil || typ == nil { return false } val = src.rval @@ -150,7 +150,7 @@ func (interp *Interpreter) gta(root *node, rpath, importPath, pkgName string) ([ } rcvrtype = ptrOf(elementType, withNode(rtn), withScope(sc)) rcvrtype.incomplete = elementType.incomplete - elementType.method = append(elementType.method, n) + elementType.addMethod(n) } else { rcvrtype = sc.getType(typeName) if rcvrtype == nil { @@ -159,7 +159,7 @@ func (interp *Interpreter) gta(root *node, rpath, importPath, pkgName string) ([ rcvrtype = sc.sym[typeName].typ } } - rcvrtype.method = append(rcvrtype.method, n) + rcvrtype.addMethod(n) n.child[0].child[0].lastChild().typ = rcvrtype case ident == "init": // init functions do not get declared as per the Go spec. @@ -288,7 +288,9 @@ func (interp *Interpreter) gta(root *node, rpath, importPath, pkgName string) ([ } else { if sym.typ != nil && (len(sym.typ.method) > 0) { // Type has already been seen as a receiver in a method function - n.typ.method = append(n.typ.method, sym.typ.method...) + for _, m := range sym.typ.method { + n.typ.addMethod(m) + } } else { // TODO(mpl): figure out how to detect redeclarations without breaking type aliases. // Allow redeclarations for now. diff --git a/interp/type.go b/interp/type.go index 94866bf4e..76c395e97 100644 --- a/interp/type.go +++ b/interp/type.go @@ -123,13 +123,12 @@ type itype struct { path string // for a defined type, the package import path length int // length of array if ArrayT rtype reflect.Type // Reflection type if ValueT, or nil - incomplete bool // true if type must be parsed again (out of order declarations) - recursive bool // true if the type has an element which refer to itself - untyped bool // true for a literal value (string or number) - isBinMethod bool // true if the type refers to a bin method function node *node // root AST node of type definition scope *scope // type declaration scope (in case of re-parse incomplete type) str string // String representation of the type + incomplete bool // true if type must be parsed again (out of order declarations) + untyped bool // true for a literal value (string or number) + isBinMethod bool // true if the type refers to a bin method function } func untypedBool() *itype { @@ -213,6 +212,14 @@ func wrapperValueTOf(rtype reflect.Type, val *itype, opts ...itypeOption) *itype return t } +func variadicOf(val *itype, opts ...itypeOption) *itype { + t := &itype{cat: variadicT, val: val, str: "..." + val.str} + for _, opt := range opts { + opt(t) + } + return t +} + // ptrOf returns a pointer to t. func ptrOf(val *itype, opts ...itypeOption) *itype { if val.ptr != nil { @@ -319,12 +326,17 @@ func mapOf(key, val *itype, opts ...itypeOption) *itype { } // interfaceOf returns an interface type with the given fields. -func interfaceOf(fields []structField, opts ...itypeOption) *itype { +func interfaceOf(t *itype, fields []structField, opts ...itypeOption) *itype { str := "interface{}" if len(fields) > 0 { str = "interface { " + methodsTypeString(fields) + "}" } - t := &itype{cat: interfaceT, field: fields, str: str} + if t == nil { + t = &itype{} + } + t.cat = interfaceT + t.field = fields + t.str = str for _, opt := range opts { opt(t) } @@ -332,12 +344,17 @@ func interfaceOf(fields []structField, opts ...itypeOption) *itype { } // structOf returns a struct type with the given fields. -func structOf(fields []structField, opts ...itypeOption) *itype { +func structOf(t *itype, fields []structField, opts ...itypeOption) *itype { str := "struct {}" if len(fields) > 0 { str = "struct { " + fieldsTypeString(fields) + "}" } - t := &itype{cat: structT, field: fields, str: str} + if t == nil { + t = &itype{} + } + t.cat = structT + t.field = fields + t.str = str for _, opt := range opts { opt(t) } @@ -346,21 +363,31 @@ func structOf(fields []structField, opts ...itypeOption) *itype { // nodeType returns a type definition for the corresponding AST subtree. func nodeType(interp *Interpreter, sc *scope, n *node) (*itype, error) { + return nodeType2(interp, sc, n, map[*node]bool{}) +} + +func nodeType2(interp *Interpreter, sc *scope, n *node, seen map[*node]bool) (t *itype, err error) { if n.typ != nil && !n.typ.incomplete { return n.typ, nil } if sname := typeName(n); sname != "" { - if sym, _, found := sc.lookup(sname); found && sym.kind == typeSym && sym.typ != nil && sym.typ.isComplete() { - return sym.typ, nil + sym, _, found := sc.lookup(sname) + if found && sym.kind == typeSym && sym.typ != nil { + if sym.typ.isComplete() { + return sym.typ, nil + } + if seen[n] { + // TODO (marc): find a better way to distinguish recursive vs incomplete types. + sym.typ.incomplete = false + return sym.typ, nil + } } } + seen[n] = true - t := &itype{node: n, scope: sc} - - var err error switch n.kind { case addressExpr, starExpr: - val, err := nodeType(interp, sc, n.child[0]) + val, err := nodeType2(interp, sc, n.child[0], seen) if err != nil { return nil, err } @@ -370,7 +397,7 @@ func nodeType(interp *Interpreter, sc *scope, n *node) (*itype, error) { case arrayType: c0 := n.child[0] if len(n.child) == 1 { - val, err := nodeType(interp, sc, c0) + val, err := nodeType2(interp, sc, c0, seen) if err != nil { return nil, err } @@ -422,7 +449,7 @@ func nodeType(interp *Interpreter, sc *scope, n *node) (*itype, error) { } length = constToInt(v) } - val, err := nodeType(interp, sc, n.child[1]) + val, err := nodeType2(interp, sc, n.child[1], seen) if err != nil { return nil, err } @@ -459,17 +486,17 @@ func nodeType(interp *Interpreter, sc *scope, n *node) (*itype, error) { } case unaryExpr: - t, err = nodeType(interp, sc, n.child[0]) + t, err = nodeType2(interp, sc, n.child[0], seen) case binaryExpr: // Get type of first operand. - if t, err = nodeType(interp, sc, n.child[0]); err != nil { + if t, err = nodeType2(interp, sc, n.child[0], seen); err != nil { return nil, err } // For operators other than shift, get the type from the 2nd operand if the first is untyped. if t.untyped && !isShiftNode(n) { var t1 *itype - t1, err = nodeType(interp, sc, n.child[1]) + t1, err = nodeType2(interp, sc, n.child[1], seen) if !(t1.untyped && isInt(t1.TypeOf()) && isFloat(t.TypeOf())) { t = t1 } @@ -487,7 +514,7 @@ func nodeType(interp *Interpreter, sc *scope, n *node) (*itype, error) { a.child[0].typ = &itype{cat: interfaceT, val: dt, str: "interface{}"} case a.kind == defineStmt && len(a.child) > a.nleft+a.nright: - if dt, err = nodeType(interp, sc, a.child[a.nleft]); err != nil { + if dt, err = nodeType2(interp, sc, a.child[a.nleft], seen); err != nil { return nil, err } @@ -503,14 +530,13 @@ func nodeType(interp *Interpreter, sc *scope, n *node) (*itype, error) { case callExpr: if isBuiltinCall(n, sc) { // Builtin types are special and may depend from their input arguments. - t.cat = builtinT switch n.child[0].ident { case bltnComplex: var nt0, nt1 *itype - if nt0, err = nodeType(interp, sc, n.child[1]); err != nil { + if nt0, err = nodeType2(interp, sc, n.child[1], seen); err != nil { return nil, err } - if nt1, err = nodeType(interp, sc, n.child[2]); err != nil { + if nt1, err = nodeType2(interp, sc, n.child[2], seen); err != nil { return nil, err } if nt0.incomplete || nt1.incomplete { @@ -535,7 +561,7 @@ func nodeType(interp *Interpreter, sc *scope, n *node) (*itype, error) { } } case bltnReal, bltnImag: - if t, err = nodeType(interp, sc, n.child[1]); err != nil { + if t, err = nodeType2(interp, sc, n.child[1], seen); err != nil { return nil, err } if !t.incomplete { @@ -553,20 +579,22 @@ func nodeType(interp *Interpreter, sc *scope, n *node) (*itype, error) { case bltnCap, bltnCopy, bltnLen: t = sc.getType("int") case bltnAppend, bltnMake: - t, err = nodeType(interp, sc, n.child[1]) + t, err = nodeType2(interp, sc, n.child[1], seen) case bltnNew: - t, err = nodeType(interp, sc, n.child[1]) + t, err = nodeType2(interp, sc, n.child[1], seen) incomplete := t.incomplete t = ptrOf(t, withScope(sc)) t.incomplete = incomplete case bltnRecover: t = sc.getType("interface{}") + default: + t = &itype{cat: builtinT} } if err != nil { return nil, err } } else { - if t, err = nodeType(interp, sc, n.child[0]); err != nil { + if t, err = nodeType2(interp, sc, n.child[0], seen); err != nil || t == nil { return nil, err } switch t.cat { @@ -582,7 +610,7 @@ func nodeType(interp *Interpreter, sc *scope, n *node) (*itype, error) { } case compositeLitExpr: - t, err = nodeType(interp, sc, n.child[0]) + t, err = nodeType2(interp, sc, n.child[0], seen) case chanType, chanTypeRecv, chanTypeSend: dir := chanSendRecv @@ -592,7 +620,7 @@ func nodeType(interp *Interpreter, sc *scope, n *node) (*itype, error) { case chanTypeSend: dir = chanSend } - val, err := nodeType(interp, sc, n.child[0]) + val, err := nodeType2(interp, sc, n.child[0], seen) if err != nil { return nil, err } @@ -600,15 +628,15 @@ func nodeType(interp *Interpreter, sc *scope, n *node) (*itype, error) { t.incomplete = val.incomplete case ellipsisExpr: - t.cat = variadicT - if t.val, err = nodeType(interp, sc, n.child[0]); err != nil { + val, err := nodeType2(interp, sc, n.child[0], seen) + if err != nil { return nil, err } - t.str = "..." + t.val.str + t = variadicOf(val, withNode(n), withScope(sc)) t.incomplete = t.val.incomplete case funcLit: - t, err = nodeType(interp, sc, n.child[2]) + t, err = nodeType2(interp, sc, n.child[2], seen) case funcType: var incomplete bool @@ -616,7 +644,7 @@ func nodeType(interp *Interpreter, sc *scope, n *node) (*itype, error) { args := make([]*itype, 0, len(n.child[0].child)) for _, arg := range n.child[0].child { cl := len(arg.child) - 1 - typ, err := nodeType(interp, sc, arg.child[cl]) + typ, err := nodeType2(interp, sc, arg.child[cl], seen) if err != nil { return nil, err } @@ -633,7 +661,7 @@ func nodeType(interp *Interpreter, sc *scope, n *node) (*itype, error) { // Handle returned values for _, ret := range n.child[1].child { cl := len(ret.child) - 1 - typ, err := nodeType(interp, sc, ret.child[cl]) + typ, err := nodeType2(interp, sc, ret.child[cl], seen) if err != nil { return nil, err } @@ -656,9 +684,7 @@ func nodeType(interp *Interpreter, sc *scope, n *node) (*itype, error) { ident := filepath.Join(n.ident, baseName) sym, _, found = sc.lookup(ident) if !found { - t.name = n.ident - t.path = sc.pkgName - t.incomplete = true + t = &itype{name: n.ident, path: sc.pkgName, node: n, incomplete: true, scope: sc} sc.sym[n.ident] = &symbol{kind: typeSym, typ: t} break } @@ -669,7 +695,7 @@ func nodeType(interp *Interpreter, sc *scope, n *node) (*itype, error) { } if t.incomplete && t.node != n { m := t.method - if t, err = nodeType(interp, sc, t.node); err != nil { + if t, err = nodeType2(interp, sc, t.node, seen); err != nil { return nil, err } t.method = m @@ -681,7 +707,7 @@ func nodeType(interp *Interpreter, sc *scope, n *node) (*itype, error) { case indexExpr: var lt *itype - if lt, err = nodeType(interp, sc, n.child[0]); err != nil { + if lt, err = nodeType2(interp, sc, n.child[0], seen); err != nil { return nil, err } if lt.incomplete { @@ -694,13 +720,12 @@ func nodeType(interp *Interpreter, sc *scope, n *node) (*itype, error) { } case interfaceType: - t.cat = interfaceT - var incomplete bool if sname := typeName(n); sname != "" { if sym, _, found := sc.lookup(sname); found && sym.kind == typeSym { - sym.typ = t + t = interfaceOf(sym.typ, sym.typ.field, withNode(n), withScope(sc)) } } + var incomplete bool fields := make([]structField, 0, len(n.child[0].child)) for _, field := range n.child[0].child { f0 := field.child[0] @@ -712,7 +737,7 @@ func nodeType(interp *Interpreter, sc *scope, n *node) (*itype, error) { fields = append(fields, structField{name: "Error", typ: typ}) continue } - typ, err := nodeType(interp, sc, f0) + typ, err := nodeType2(interp, sc, f0, seen) if err != nil { return nil, err } @@ -720,25 +745,25 @@ func nodeType(interp *Interpreter, sc *scope, n *node) (*itype, error) { incomplete = incomplete || typ.incomplete continue } - typ, err := nodeType(interp, sc, field.child[1]) + typ, err := nodeType2(interp, sc, field.child[1], seen) if err != nil { return nil, err } fields = append(fields, structField{name: f0.ident, typ: typ}) incomplete = incomplete || typ.incomplete } - *t = *interfaceOf(fields, withNode(n), withScope(sc)) + t = interfaceOf(t, fields, withNode(n), withScope(sc)) t.incomplete = incomplete case landExpr, lorExpr: t = sc.getType("bool") case mapType: - key, err := nodeType(interp, sc, n.child[0]) + key, err := nodeType2(interp, sc, n.child[0], seen) if err != nil { return nil, err } - val, err := nodeType(interp, sc, n.child[1]) + val, err := nodeType2(interp, sc, n.child[1], seen) if err != nil { return nil, err } @@ -746,7 +771,7 @@ func nodeType(interp *Interpreter, sc *scope, n *node) (*itype, error) { t.incomplete = key.incomplete || val.incomplete case parenExpr: - t, err = nodeType(interp, sc, n.child[0]) + t, err = nodeType2(interp, sc, n.child[0], seen) case selectorExpr: // Resolve the left part of selector, then lookup the right part on it @@ -772,12 +797,11 @@ func nodeType(interp *Interpreter, sc *scope, n *node) (*itype, error) { } } - if lt, err = nodeType(interp, localScope, n.child[0]); err != nil { + if lt, err = nodeType2(interp, localScope, n.child[0], seen); err != nil { return nil, err } if lt.incomplete { - t.incomplete = true break } name := n.child[1].ident @@ -803,7 +827,7 @@ func nodeType(interp *Interpreter, sc *scope, n *node) (*itype, error) { } default: if m, _ := lt.lookupMethod(name); m != nil { - t, err = nodeType(interp, sc, m.child[2]) + t, err = nodeType2(interp, sc, m.child[2], seen) } else if bm, _, _, ok := lt.lookupBinMethod(name); ok { t = valueTOf(bm.Type, isBinMethod(), withRecv(lt), withScope(sc)) } else if ti := lt.lookupField(name); len(ti) > 0 { @@ -816,7 +840,7 @@ func nodeType(interp *Interpreter, sc *scope, n *node) (*itype, error) { } case sliceExpr: - t, err = nodeType(interp, sc, n.child[0]) + t, err = nodeType2(interp, sc, n.child[0], seen) if err != nil { return nil, err } @@ -830,22 +854,17 @@ func nodeType(interp *Interpreter, sc *scope, n *node) (*itype, error) { } case structType: - t.cat = structT - var ( - methods []*node - incomplete bool - ) if sname := typeName(n); sname != "" { if sym, _, found := sc.lookup(sname); found && sym.kind == typeSym { - methods = sym.typ.method - sym.typ = t + t = structOf(sym.typ, sym.typ.field, withNode(n), withScope(sc)) } } + var incomplete bool fields := make([]structField, 0, len(n.child[0].child)) for _, c := range n.child[0].child { switch { case len(c.child) == 1: - typ, err := nodeType(interp, sc, c.child[0]) + typ, err := nodeType2(interp, sc, c.child[0], seen) if err != nil { return nil, err } @@ -853,7 +872,7 @@ func nodeType(interp *Interpreter, sc *scope, n *node) (*itype, error) { incomplete = incomplete || typ.incomplete case len(c.child) == 2 && c.child[1].kind == basicLit: tag := vString(c.child[1].rval) - typ, err := nodeType(interp, sc, c.child[0]) + typ, err := nodeType2(interp, sc, c.child[0], seen) if err != nil { return nil, err } @@ -866,7 +885,7 @@ func nodeType(interp *Interpreter, sc *scope, n *node) (*itype, error) { tag = vString(c.lastChild().rval) l-- } - typ, err := nodeType(interp, sc, c.child[l-1]) + typ, err := nodeType2(interp, sc, c.child[l-1], seen) if err != nil { return nil, err } @@ -876,15 +895,14 @@ func nodeType(interp *Interpreter, sc *scope, n *node) (*itype, error) { } } } - *t = *structOf(fields, withNode(n), withScope(sc)) - t.method = methods // Recover the symbol methods. + t = structOf(t, fields, withNode(n), withScope(sc)) t.incomplete = incomplete default: err = n.cfgErrorf("type definition not implemented: %s", n.kind) } - if err == nil && t.cat == nilT && !t.incomplete { + if err == nil && t != nil && t.cat == nilT && !t.incomplete { err = n.cfgErrorf("use of untyped nil %s", t.name) } @@ -997,40 +1015,13 @@ func (t *itype) finalize() (*itype, error) { return t, err } -// ReferTo returns true if the type contains a reference to a -// full type name. It allows to assess a type recursive status. -func (t *itype) referTo(name string, seen map[*itype]bool) bool { - if t.path+"/"+t.name == name { - return true - } - if seen[t] { - return false - } - seen[t] = true - switch t.cat { - case aliasT, arrayT, chanT, chanRecvT, chanSendT, ptrT, sliceT, variadicT: - return t.val.referTo(name, seen) - case funcT: - for _, a := range t.arg { - if a.referTo(name, seen) { - return true - } - } - for _, a := range t.ret { - if a.referTo(name, seen) { - return true - } - } - case mapT: - return t.key.referTo(name, seen) || t.val.referTo(name, seen) - case structT, interfaceT: - for _, f := range t.field { - if f.typ.referTo(name, seen) { - return true - } +func (t *itype) addMethod(n *node) { + for _, m := range t.method { + if m == n { + return } } - return false + t.method = append(t.method, n) } func (t *itype) numIn() int { @@ -1100,27 +1091,6 @@ func (t *itype) concrete() *itype { return t } -// IsRecursive returns true if type is recursive. -// Only a named struct or interface can be recursive. -func (t *itype) isRecursive() bool { - if t.name == "" { - return false - } - switch t.cat { - case structT, interfaceT: - for _, f := range t.field { - if f.typ.referTo(t.path+"/"+t.name, map[*itype]bool{}) { - return true - } - } - } - return false -} - -func (t *itype) isIndirectRecursive() bool { - return t.isRecursive() || t.val != nil && t.val.isIndirectRecursive() -} - // isVariadic returns true if the function type is variadic. // If the type is not a function or is not variadic, it will // return false. @@ -1664,23 +1634,8 @@ func (t *itype) refType(ctx *refTypeContext) reflect.Type { panic(err) } } - recursive := false name := t.path + "/" + t.name - // Predefined types from universe or runtime may have a nil scope. - if t.scope != nil { - if st := t.scope.sym[t.name]; st != nil { - // Update the type recursive status. Several copies of type - // may exist per symbol, as a new type is created at each GTA - // pass (several needed due to out of order declarations), and - // a node can still point to a previous copy. - st.typ.recursive = st.typ.recursive || st.typ.isRecursive() - recursive = st.typ.isRecursive() - // It is possible that t.recursive is not inline with st.typ.recursive - // which will break recursion detection. Set it here to make sure it - // is correct. - t.recursive = recursive - } - } + if t.rtype != nil && !ctx.rebuilding { return t.rtype } @@ -1738,7 +1693,7 @@ func (t *itype) refType(ctx *refTypeContext) reflect.Type { fctx := ctx.Clone() field := reflect.StructField{ Name: exportName(f.name), Type: f.typ.refType(fctx), - Tag: reflect.StructTag(f.tag), Anonymous: (f.embed && !recursive), + Tag: reflect.StructTag(f.tag), Anonymous: f.embed, } fields = append(fields, field) // Find any nil type refs that indicates a rebuild is needed on this field. diff --git a/interp/typecheck.go b/interp/typecheck.go index 38b629ac7..0620a6851 100644 --- a/interp/typecheck.go +++ b/interp/typecheck.go @@ -53,10 +53,6 @@ func (check typecheck) assignment(n *node, typ *itype, context string) error { return nil } - if typ.isIndirectRecursive() || n.typ.isIndirectRecursive() { - return nil - } - if !n.typ.assignableTo(typ) { if context == "" { return n.cfgErrorf("cannot use type %s as type %s", n.typ.id(), typ.id())