diff --git a/internal/unsafe2/unsafe.go b/internal/unsafe2/unsafe.go new file mode 100644 index 000000000..5f885d0e8 --- /dev/null +++ b/internal/unsafe2/unsafe.go @@ -0,0 +1,52 @@ +package unsafe2 + +import ( + "reflect" + "unsafe" +) + +type dummy struct{} + +// DummyType represents a stand-in for a recursive type. +var DummyType = reflect.TypeOf(dummy{}) + +type rtype struct { + _ [48]byte +} + +type emptyInterface struct { + typ *rtype + _ unsafe.Pointer +} + +type structField struct { + _ int64 + typ *rtype + _ uintptr +} + +type structType struct { + rtype + _ int64 + fields []structField +} + +// SwapFieldType swaps the type of the struct field with the given type. +// +// The struct type must have been created at runtime. This is very unsafe. +func SwapFieldType(s reflect.Type, idx int, t reflect.Type) { + if s.Kind() != reflect.Struct || idx >= s.NumField() { + return + } + + rtyp := unpackType(s) + styp := (*structType)(unsafe.Pointer(rtyp)) + f := styp.fields[idx] + f.typ = unpackType(t) + styp.fields[idx] = f +} + +func unpackType(t reflect.Type) *rtype { + v := reflect.New(t).Elem().Interface() + return (*emptyInterface)(unsafe.Pointer(&v)).typ +} diff --git a/internal/unsafe2/unsafe_test.go b/internal/unsafe2/unsafe_test.go new file mode 100644 index 000000000..64e303b33 --- /dev/null +++ b/internal/unsafe2/unsafe_test.go @@ -0,0 +1,33 @@ +package unsafe2_test + +import ( + "reflect" + "testing" + + "github.com/traefik/yaegi/internal/unsafe2" +) + +func TestSwapFieldType(t *testing.T) { + f := []reflect.StructField{ + { + Name: "A", + Type: reflect.TypeOf(int(0)), + }, + { + Name: "B", + Type: reflect.PtrTo(unsafe2.DummyType), + }, + { + Name: "C", + Type: reflect.TypeOf(int64(0)), + }, + } + typ := reflect.StructOf(f) + ntyp := reflect.PtrTo(typ) + + unsafe2.SwapFieldType(typ, 1, ntyp) + + if typ.Field(1).Type != ntyp { + t.Fatalf("unexpected field type: want %s; got %s", ntyp, typ.Field(1).Type) + } +} diff --git a/interp/run.go b/interp/run.go index ea5752712..80adbe4f4 100644 --- a/interp/run.go +++ b/interp/run.go @@ -10,7 +10,6 @@ import ( "regexp" "strings" "sync" - "unsafe" ) // bltn type defines functions which run at CFG execution. @@ -568,18 +567,6 @@ func convert(n *node) { } } -func isRecursiveType(t *itype, rtype reflect.Type) bool { - if t.cat == structT && rtype.Kind() == reflect.Interface { - return true - } - switch t.cat { - case aliasT, arrayT, mapT, ptrT, sliceT: - return isRecursiveType(t.val, t.val.rtype) - default: - return false - } -} - func assign(n *node) { next := getExec(n.tnext) dvalue := make([]func(*frame) reflect.Value, n.nleft) @@ -1038,11 +1025,7 @@ func call(n *node) { switch { case n.child[0].recv != nil: // Compute method receiver value. - if isRecursiveType(n.child[0].recv.node.typ, n.child[0].recv.node.typ.rtype) { - values = append(values, genValueRecvInterfacePtr(n.child[0])) - } else { - values = append(values, genValueRecv(n.child[0])) - } + values = append(values, genValueRecv(n.child[0])) method = true case len(n.child[0].child) > 0 && n.child[0].child[0].typ != nil && isInterfaceSrc(n.child[0].child[0].typ): recvIndexLater = true @@ -1096,8 +1079,6 @@ func call(n *node) { values = append(values, genValueInterface(c)) case isInterfaceBin(arg): values = append(values, genInterfaceWrapper(c, arg.rtype)) - case isRecursiveType(c.typ, c.typ.rtype): - values = append(values, genValueRecursiveInterfacePtrValue(c)) default: values = append(values, genValue(c)) } @@ -1852,9 +1833,6 @@ func getIndexSeq(n *node) { fnext := getExec(n.fnext) n.exec = func(f *frame) bltn { v := value(f) - if v.Type().Kind() == reflect.Interface && n.child[0].typ.recursive { - v = writableDeref(v) - } r := v.FieldByIndex(index) getFrame(f, l).data[i] = r if r.Bool() { @@ -1865,34 +1843,16 @@ func getIndexSeq(n *node) { } else { n.exec = func(f *frame) bltn { v := value(f) - if v.Type().Kind() == reflect.Interface && n.child[0].typ.recursive { - v = writableDeref(v) - } getFrame(f, l).data[i] = v.FieldByIndex(index) return tnext } } } -//go:nocheckptr -func writableDeref(v reflect.Value) reflect.Value { - // Here we have an interface to a struct. Any attempt to dereference it will - // make a copy of the struct. We need to get a Value to the actual struct. - // TODO: using unsafe is a temporary measure. Rethink this. - // TODO: InterfaceData has been depreciated, this is even less of a good idea now. - return reflect.NewAt(v.Elem().Type(), unsafe.Pointer(v.InterfaceData()[1])).Elem() //nolint:govet,staticcheck -} - func getPtrIndexSeq(n *node) { index := n.val.([]int) tnext := getExec(n.tnext) - var value func(*frame) reflect.Value - if isRecursiveType(n.child[0].typ, n.child[0].typ.rtype) { - v := genValue(n.child[0]) - value = func(f *frame) reflect.Value { return v(f).Elem().Elem() } - } else { - value = genValue(n.child[0]) - } + value := genValue(n.child[0]) i := n.findex l := n.level @@ -2546,8 +2506,6 @@ func doComposite(n *node, hasType bool, keyed bool) { values[fieldIndex] = genValueAsFunctionWrapper(val) case isArray(val.typ) && val.typ.val != nil && isInterfaceSrc(val.typ.val): values[fieldIndex] = genValueInterfaceArray(val) - case isRecursiveType(ft, rft): - values[fieldIndex] = genValueRecursiveInterface(val, rft) case isInterfaceSrc(ft) && !isEmptyInterface(ft): values[fieldIndex] = genValueInterface(val) case isInterface(ft): @@ -2946,8 +2904,6 @@ func _append(n *node) { values[i] = genValueInterface(arg) case isInterfaceBin(n.typ.val): values[i] = genInterfaceWrapper(arg, n.typ.val.rtype) - case isRecursiveType(n.typ.val, n.typ.val.rtype): - values[i] = genValueRecursiveInterface(arg, n.typ.val.rtype) case arg.typ.untyped: values[i] = genValueAs(arg, n.child[1].typ.TypeOf().Elem()) default: @@ -2972,8 +2928,6 @@ func _append(n *node) { value0 = genValueInterface(n.child[2]) case isInterfaceBin(elem): value0 = genInterfaceWrapper(n.child[2], elem.rtype) - case isRecursiveType(elem, elem.rtype): - value0 = genValueRecursiveInterface(n.child[2], elem.rtype) case n.child[2].typ.untyped: value0 = genValueAs(n.child[2], n.child[1].typ.TypeOf().Elem()) default: diff --git a/interp/type.go b/interp/type.go index 8f02f757c..a748d6a1d 100644 --- a/interp/type.go +++ b/interp/type.go @@ -8,6 +8,8 @@ import ( "strconv" "strings" "sync" + + "github.com/traefik/yaegi/internal/unsafe2" ) // tcat defines interpreter type categories. @@ -274,6 +276,11 @@ func nodeType(interp *Interpreter, sc *scope, n *node) (*itype, error) { if n.typ != nil && !n.typ.incomplete { return n.typ, nil } + if sname := typeName(n); sname != "" { + if sym, _, found := sc.lookup(sname); found && sym.kind == typeSym && sym.typ != nil && sym.typ.isComplete() { + return sym.typ, nil + } + } repr := strings.Builder{} t := &itype{node: n, scope: sc} @@ -610,6 +617,8 @@ func nodeType(interp *Interpreter, sc *scope, n *node) (*itype, error) { ident := filepath.Join(n.ident, baseName) sym, _, found = sc.lookup(ident) if !found { + t.name = n.ident + t.path = sc.pkgID t.incomplete = true sc.sym[n.ident] = &symbol{kind: typeSym, typ: t} break @@ -853,12 +862,12 @@ func nodeType(interp *Interpreter, sc *scope, n *node) (*itype, error) { switch { case t == nil: - case t.cat == nilT: - t.str = "nil" case t.name != "" && t.path != "": t.str = t.path + "." + t.name case repr.Len() > 0: t.str = repr.String() + case t.cat == nilT: + t.str = "nil" } return t, err @@ -1582,13 +1591,37 @@ var ( constVal = reflect.TypeOf((*constant.Value)(nil)).Elem() ) +type fieldRebuild struct { + typ *itype + idx int +} + +type refTypeContext struct { + defined map[string]*itype + refs map[string][]fieldRebuild + rebuilding bool +} + +// Clone creates a copy if the ref type context without the `needsRebuild` set. +func (c *refTypeContext) Clone() *refTypeContext { + return &refTypeContext{defined: c.defined, refs: c.refs, rebuilding: c.rebuilding} +} + // RefType returns a reflect.Type representation from an interpreter type. // In simple cases, reflect types are directly mapped from the interpreter // counterpart. // For recursive named struct or interfaces, as reflect does not permit to -// create a recursive named struct, an interface{} is returned in place to -// avoid infinitely nested structs. -func (t *itype) refType(defined map[string]*itype, wrapRecursive bool) reflect.Type { +// create a recursive named struct, a nil type is set temporarily for each recursive +// field. When done, the nil type fields are updated with the original reflect type +// pointer using unsafe. We thus obtain a usable recursive type definition, except +// for string representation, as created reflect types are still unnamed. +func (t *itype) refType(ctx *refTypeContext) reflect.Type { + if ctx == nil { + ctx = &refTypeContext{ + defined: map[string]*itype{}, + refs: map[string][]fieldRebuild{}, + } + } if t.incomplete || t.cat == nilT { var err error if t, err = t.finalize(); err != nil { @@ -1612,82 +1645,82 @@ func (t *itype) refType(defined map[string]*itype, wrapRecursive bool) reflect.T t.recursive = recursive } } - if wrapRecursive && t.recursive { - return interf - } - if t.rtype != nil { + if t.rtype != nil && !ctx.rebuilding { return t.rtype } - if defined[name] != nil && defined[name].rtype != nil { - return defined[name].rtype - } - if t.val != nil && t.val.cat == structT && t.val.rtype == nil && hasRecursiveStruct(t.val, copyDefined(defined)) { - // Replace reference to self (direct or indirect) by an interface{} to handle - // recursive types with reflect. - typ := *t.val - t.val = &typ - t.val.rtype = interf - recursive = true + if dt := ctx.defined[name]; dt != nil { + if dt.rtype != nil { + t.rtype = dt.rtype + return dt.rtype + } + + // To indicate that a rebuild is needed on the nearest struct + // field, create an entry with a nil type. + flds := ctx.refs[name] + ctx.refs[name] = append(flds, fieldRebuild{}) + return unsafe2.DummyType } switch t.cat { case aliasT: - t.rtype = t.val.refType(defined, wrapRecursive) + t.rtype = t.val.refType(ctx) case arrayT: - t.rtype = reflect.ArrayOf(t.length, t.val.refType(defined, wrapRecursive)) + t.rtype = reflect.ArrayOf(t.length, t.val.refType(ctx)) case sliceT, variadicT: - t.rtype = reflect.SliceOf(t.val.refType(defined, wrapRecursive)) + t.rtype = reflect.SliceOf(t.val.refType(ctx)) case chanT: - t.rtype = reflect.ChanOf(reflect.BothDir, t.val.refType(defined, wrapRecursive)) + t.rtype = reflect.ChanOf(reflect.BothDir, t.val.refType(ctx)) case chanRecvT: - t.rtype = reflect.ChanOf(reflect.RecvDir, t.val.refType(defined, wrapRecursive)) + t.rtype = reflect.ChanOf(reflect.RecvDir, t.val.refType(ctx)) case chanSendT: - t.rtype = reflect.ChanOf(reflect.SendDir, t.val.refType(defined, wrapRecursive)) + t.rtype = reflect.ChanOf(reflect.SendDir, t.val.refType(ctx)) case errorT: t.rtype = reflect.TypeOf(new(error)).Elem() case funcT: - if t.name != "" { - defined[name] = t // TODO(marc): make sure that key is name and not t.name. - } variadic := false in := make([]reflect.Type, len(t.arg)) out := make([]reflect.Type, len(t.ret)) for i, v := range t.arg { - in[i] = v.refType(defined, true) + in[i] = v.refType(ctx) variadic = v.cat == variadicT } for i, v := range t.ret { - out[i] = v.refType(defined, true) + out[i] = v.refType(ctx) } t.rtype = reflect.FuncOf(in, out, variadic) case interfaceT: t.rtype = interf case mapT: - t.rtype = reflect.MapOf(t.key.refType(defined, wrapRecursive), t.val.refType(defined, wrapRecursive)) + t.rtype = reflect.MapOf(t.key.refType(ctx), t.val.refType(ctx)) case ptrT: - t.rtype = reflect.PtrTo(t.val.refType(defined, wrapRecursive)) + t.rtype = reflect.PtrTo(t.val.refType(ctx)) case structT: if t.name != "" { - // Check against local t.name and not name to catch recursive type definitions. - if defined[t.name] != nil { - recursive = true - } - defined[t.name] = t + ctx.defined[name] = t } var fields []reflect.StructField - // TODO(mpl): make Anonymous work for recursive types too. Maybe not worth the - // effort, and we're better off just waiting for - // https://github.com/golang/go/issues/39717 to land. - for _, f := range t.field { + for i, f := range t.field { + fctx := ctx.Clone() field := reflect.StructField{ - Name: exportName(f.name), Type: f.typ.refType(defined, wrapRecursive), + Name: exportName(f.name), Type: f.typ.refType(fctx), Tag: reflect.StructTag(f.tag), Anonymous: (f.embed && !recursive), } fields = append(fields, field) + // Find any nil type refs that indicates a rebuild is needed on this field. + for _, flds := range ctx.refs { + for j, fld := range flds { + if fld.typ == nil { + flds[j] = fieldRebuild{typ: t, idx: i} + } + } + } } - if recursive && wrapRecursive { - t.rtype = interf - } else { - t.rtype = reflect.StructOf(fields) + t.rtype = reflect.StructOf(fields) + + // The rtype has now been built, we can go back and rebuild + // all the recursive types that relied on this type. + for _, f := range ctx.refs[name] { + ftyp := f.typ.field[f.idx].typ.refType(&refTypeContext{defined: ctx.defined, rebuilding: true}) + unsafe2.SwapFieldType(f.typ.rtype, f.idx, ftyp) } default: if z, _ := t.zero(); z.IsValid() { @@ -1699,7 +1732,7 @@ func (t *itype) refType(defined map[string]*itype, wrapRecursive bool) reflect.T // TypeOf returns the reflection type of dynamic interpreter type t. func (t *itype) TypeOf() reflect.Type { - return t.refType(map[string]*itype{}, false) + return t.refType(nil) } func (t *itype) frameType() (r reflect.Type) { @@ -1802,44 +1835,6 @@ func (t *itype) elem() *itype { return t.val } -func copyDefined(m map[string]*itype) map[string]*itype { - n := make(map[string]*itype, len(m)) - for k, v := range m { - n[k] = v - } - return n -} - -// hasRecursiveStruct determines if a struct is a recursion or a recursion -// intermediate. A recursion intermediate is a struct that contains a recursive -// struct. -func hasRecursiveStruct(t *itype, defined map[string]*itype) bool { - if len(defined) == 0 { - return false - } - - typ := t - for typ != nil { - if typ.cat != structT { - typ = typ.val - continue - } - - if defined[typ.path+"/"+typ.name] != nil { - return true - } - defined[typ.path+"/"+typ.name] = typ - - for _, f := range typ.field { - if hasRecursiveStruct(f.typ, copyDefined(defined)) { - return true - } - } - return false - } - return false -} - func constToInt(c constant.Value) int { if constant.BitLen(c) > 64 { panic(fmt.Sprintf("constant %s overflows int64", c.ExactString())) diff --git a/interp/typecheck.go b/interp/typecheck.go index b683ce1bd..38b629ac7 100644 --- a/interp/typecheck.go +++ b/interp/typecheck.go @@ -752,10 +752,6 @@ func (check typecheck) builtin(name string, n *node, child []*node, ellipsis boo } return nil } - // We cannot check a recursive type. - if isRecursiveType(typ, typ.TypeOf()) { - return nil - } fun := &node{ typ: &itype{ diff --git a/interp/value.go b/interp/value.go index 8c33ffb6c..8f9ea0e2e 100644 --- a/interp/value.go +++ b/interp/value.go @@ -129,25 +129,6 @@ func genValueBinRecv(n *node, recv *receiver) func(*frame) reflect.Value { } } -func genValueRecvInterfacePtr(n *node) func(*frame) reflect.Value { - v := genValue(n.recv.node) - fi := n.recv.index - - return func(f *frame) reflect.Value { - r := v(f) - r = r.Elem().Elem() - - if len(fi) == 0 { - return r - } - - if r.Kind() == reflect.Ptr { - r = r.Elem() - } - return r.FieldByIndex(fi) - } -} - func genValueAsFunctionWrapper(n *node) func(*frame) reflect.Value { value := genValue(n) typ := n.typ.TypeOf() @@ -240,10 +221,6 @@ func genDestValue(typ *itype, n *node) func(*frame) reflect.Value { return genInterfaceWrapper(n, typ.rtype) case n.kind == basicLit && n.val == nil: return func(*frame) reflect.Value { return reflect.New(typ.rtype).Elem() } - case isRecursiveType(typ, typ.rtype): - return genValueRecursiveInterface(n, typ.rtype) - case isRecursiveType(n.typ, n.typ.rtype): - return genValueRecursiveInterfacePtrValue(n) case n.typ.untyped && isComplex(typ.TypeOf()): return genValueComplex(n) case n.typ.untyped && !typ.untyped: @@ -440,63 +417,6 @@ func genValueNode(n *node) func(*frame) reflect.Value { } } -func genValueRecursiveInterface(n *node, t reflect.Type) func(*frame) reflect.Value { - value := genValue(n) - - return func(f *frame) reflect.Value { - vv := value(f) - v := reflect.New(t).Elem() - toRecursive(v, vv) - return v - } -} - -func toRecursive(dest, src reflect.Value) { - if !src.IsValid() { - return - } - - switch dest.Kind() { - case reflect.Map: - v := reflect.MakeMapWithSize(dest.Type(), src.Len()) - for _, kv := range src.MapKeys() { - vv := reflect.New(dest.Type().Elem()).Elem() - toRecursive(vv, src.MapIndex(kv)) - vv.SetMapIndex(kv, vv) - } - dest.Set(v) - case reflect.Slice: - l := src.Len() - v := reflect.MakeSlice(dest.Type(), l, l) - for i := 0; i < l; i++ { - toRecursive(v.Index(i), src.Index(i)) - } - dest.Set(v) - case reflect.Ptr: - v := reflect.New(dest.Type().Elem()).Elem() - s := src - if s.Elem().Kind() != reflect.Struct { // In the case of *interface{}, we want *struct{} - s = s.Elem() - } - toRecursive(v, s) - dest.Set(v.Addr()) - default: - dest.Set(src) - } -} - -func genValueRecursiveInterfacePtrValue(n *node) func(*frame) reflect.Value { - value := genValue(n) - - return func(f *frame) reflect.Value { - v := value(f) - if v.IsZero() { - return v - } - return v.Elem().Elem() - } -} - func vInt(v reflect.Value) (i int64) { if c := vConstantValue(v); c != nil { i, _ = constant.Int64Val(constant.ToInt(c))