diff --git a/_test/gen1.go b/_test/gen1.go new file mode 100644 index 000000000..a83d4387d --- /dev/null +++ b/_test/gen1.go @@ -0,0 +1,39 @@ +package main + +import "fmt" + +// SumInts adds together the values of m. +func SumInts(m map[string]int64) int64 { + var s int64 + for _, v := range m { + s += v + } + return s +} + +// SumFloats adds together the values of m. +func SumFloats(m map[string]float64) float64 { + var s float64 + for _, v := range m { + s += v + } + return s +} + +func main() { + // Initialize a map for the integer values + ints := map[string]int64{ + "first": 34, + "second": 12, + } + + // Initialize a map for the float values + floats := map[string]float64{ + "first": 35.98, + "second": 26.99, + } + + fmt.Printf("Non-Generic Sums: %v and %v\n", + SumInts(ints), + SumFloats(floats)) +} diff --git a/_test/gen2.go b/_test/gen2.go new file mode 100644 index 000000000..b4c3fff7b --- /dev/null +++ b/_test/gen2.go @@ -0,0 +1,34 @@ +package main + +import "fmt" + +// SumIntsOrFloats sums the values of map m. It supports both int64 and float64 +// as types for map values. +func SumIntsOrFloats[K comparable, V int64 | float64](m map[K]V) V { + var s V + for _, v := range m { + s += v + } + return s +} + +func main() { + // Initialize a map for the integer values + ints := map[string]int64{ + "first": 34, + "second": 12, + } + + // Initialize a map for the float values + floats := map[string]float64{ + "first": 35.98, + "second": 26.99, + } + + fmt.Printf("Generic Sums: %v and %v\n", + SumIntsOrFloats[string, int64](ints), + SumIntsOrFloats[string, float64](floats)) +} + +// Output: +// Generic Sums: 46 and 62.97 diff --git a/_test/gen3.go b/_test/gen3.go new file mode 100644 index 000000000..a09da6566 --- /dev/null +++ b/_test/gen3.go @@ -0,0 +1,22 @@ +package main + +type Number interface { + int | int64 | ~float64 +} + +func Sum[T Number](numbers []T) T { + var total T + for _, x := range numbers { + total += x + } + return total +} + +func main() { + xs := []int{3, 5, 10} + total := Sum(xs) + println(total) +} + +// Output: +// 18 diff --git a/_test/gen4.go b/_test/gen4.go new file mode 100644 index 000000000..13d04b703 --- /dev/null +++ b/_test/gen4.go @@ -0,0 +1,42 @@ +package main + +import "fmt" + +type List[T any] struct { + head, tail *element[T] +} + +// A recursive generic type. +type element[T any] struct { + next *element[T] + val T +} + +func (lst *List[T]) Push(v T) { + if lst.tail == nil { + lst.head = &element[T]{val: v} + lst.tail = lst.head + } else { + lst.tail.next = &element[T]{val: v} + lst.tail = lst.tail.next + } +} + +func (lst *List[T]) GetAll() []T { + var elems []T + for e := lst.head; e != nil; e = e.next { + elems = append(elems, e.val) + } + return elems +} + +func main() { + lst := List[int]{} + lst.Push(10) + lst.Push(13) + lst.Push(23) + fmt.Println("list:", lst.GetAll()) +} + +// Output: +// list: [10 13 23] diff --git a/_test/gen5.go b/_test/gen5.go new file mode 100644 index 000000000..a71dd659c --- /dev/null +++ b/_test/gen5.go @@ -0,0 +1,24 @@ +package main + +import "fmt" + +type Set[Elem comparable] struct { + m map[Elem]struct{} +} + +func Make[Elem comparable]() Set[Elem] { + return Set[Elem]{m: make(map[Elem]struct{})} +} + +func (s Set[Elem]) Add(v Elem) { + s.m[v] = struct{}{} +} + +func main() { + s := Make[int]() + s.Add(1) + fmt.Println(s) +} + +// Output: +// {map[1:{}]} diff --git a/_test/gen6.go b/_test/gen6.go new file mode 100644 index 000000000..e3ec819f5 --- /dev/null +++ b/_test/gen6.go @@ -0,0 +1,19 @@ +package main + +func MapKeys[K comparable, V any](m map[K]V) []K { + r := make([]K, 0, len(m)) + for k := range m { + r = append(r, k) + } + return r +} + +func main() { + var m = map[int]string{1: "2", 2: "4", 4: "8"} + + // Test type inference + println(len(MapKeys(m))) +} + +// Output: +// 3 diff --git a/_test/gen7.go b/_test/gen7.go new file mode 100644 index 000000000..c3556fd7d --- /dev/null +++ b/_test/gen7.go @@ -0,0 +1,19 @@ +package main + +func MapKeys[K comparable, V any](m map[K]V) []K { + r := make([]K, 0, len(m)) + for k := range m { + r = append(r, k) + } + return r +} + +func main() { + var m = map[int]string{1: "2", 2: "4", 4: "8"} + + // Test type inference + println(len(MapKeys)) +} + +// Error: +// invalid argument for len diff --git a/_test/gen8.go b/_test/gen8.go new file mode 100644 index 000000000..c88d76208 --- /dev/null +++ b/_test/gen8.go @@ -0,0 +1,15 @@ +package main + +type Float interface { + ~float32 | ~float64 +} + +func add[T Float](a, b T) float64 { return float64(a) + float64(b) } + +func main() { + var x, y int = 1, 2 + println(add(x, y)) +} + +// Error: +// int does not implement main.Float diff --git a/_test/gen9.go b/_test/gen9.go new file mode 100644 index 000000000..eed29a73e --- /dev/null +++ b/_test/gen9.go @@ -0,0 +1,14 @@ +package main + +type Float interface { + ~float32 | ~float64 +} + +func add[T Float](a, b T) float64 { return float64(a) + float64(b) } + +func main() { + println(add(1, 2)) +} + +// Error: +// untyped int does not implement main.Float diff --git a/interp/ast.go b/interp/ast.go index 9bfa12634..1e71ac439 100644 --- a/interp/ast.go +++ b/interp/ast.go @@ -72,6 +72,7 @@ const ( importSpec incDecStmt indexExpr + indexListExpr interfaceType keyValueExpr labeledStmt @@ -155,6 +156,7 @@ var kinds = [...]string{ importSpec: "importSpec", incDecStmt: "incDecStmt", indexExpr: "indexExpr", + indexListExpr: "indexListExpr", interfaceType: "interfaceType", keyValueExpr: "keyValueExpr", labeledStmt: "labeledStmt", @@ -694,7 +696,7 @@ func (interp *Interpreter) ast(f ast.Node) (string, *node, error) { n := addChild(&root, anc, pos, funcDecl, aNop) n.val = n if a.Recv == nil { - // function is not a method, create an empty receiver list + // Function is not a method, create an empty receiver list. addChild(&root, astNode{n, nod}, pos, fieldList, aNop) } st.push(n, nod) @@ -706,7 +708,13 @@ func (interp *Interpreter) ast(f ast.Node) (string, *node, error) { st.push(n, nod) case *ast.FuncType: - st.push(addChild(&root, anc, pos, funcType, aNop), nod) + n := addChild(&root, anc, pos, funcType, aNop) + n.val = n + if a.TypeParams == nil { + // Function has no type parameters, create an empty fied list. + addChild(&root, astNode{n, nod}, pos, fieldList, aNop) + } + st.push(n, nod) case *ast.GenDecl: var kind nkind @@ -776,6 +784,9 @@ func (interp *Interpreter) ast(f ast.Node) (string, *node, error) { case *ast.IndexExpr: st.push(addChild(&root, anc, pos, indexExpr, aGetIndex), nod) + case *ast.IndexListExpr: + st.push(addChild(&root, anc, pos, indexListExpr, aNop), nod) + case *ast.InterfaceType: st.push(addChild(&root, anc, pos, interfaceType, aNop), nod) diff --git a/interp/cfg.go b/interp/cfg.go index 78af0c7dd..81204c576 100644 --- a/interp/cfg.go +++ b/interp/cfg.go @@ -303,7 +303,7 @@ func (interp *Interpreter) cfg(root *node, sc *scope, importPath, pkgName string // Indicate that the first child is the type. n.nleft = 1 } else { - // Get type from ancestor (implicit type) + // Get type from ancestor (implicit type). if n.anc.kind == keyValueExpr && n == n.anc.child[0] { n.typ = n.anc.typ.key } else if atyp := n.anc.typ; atyp != nil { @@ -366,15 +366,49 @@ func (interp *Interpreter) cfg(root *node, sc *scope, importPath, pkgName string return false } n.val = n + + // Skip substree in case of a generic function. + if len(n.child[2].child[0].child) > 0 { + return false + } + + // Skip subtree if the function is a method with a generic receiver. + if len(n.child[0].child) > 0 { + recvTypeNode := n.child[0].child[0].lastChild() + typ, err := nodeType(interp, sc, recvTypeNode) + if err != nil { + return false + } + if typ.cat == genericT || (typ.val != nil && typ.val.cat == genericT) { + return false + } + if typ.cat == ptrT { + rc0 := recvTypeNode.child[0] + rt0, err := nodeType(interp, sc, rc0) + if err != nil { + return false + } + if rc0.kind == indexExpr && rt0.cat == structT { + return false + } + } + } + // Compute function type before entering local scope to avoid // possible collisions with function argument names. n.child[2].typ, err = nodeType(interp, sc, n.child[2]) - // Add a frame indirection level as we enter in a func + if err != nil { + return false + } + n.typ = n.child[2].typ + + // Add a frame indirection level as we enter in a func. sc = sc.pushFunc() sc.def = n - if len(n.child[2].child) == 2 { - // Allocate frame space for return values, define output symbols - for _, c := range n.child[2].child[1].child { + + // Allocate frame space for return values, define output symbols. + if len(n.child[2].child) == 3 { + for _, c := range n.child[2].child[2].child { var typ *itype if typ, err = nodeType(interp, sc, c.lastChild()); err != nil { return false @@ -388,8 +422,9 @@ func (interp *Interpreter) cfg(root *node, sc *scope, importPath, pkgName string } } } + + // Define receiver symbol. if len(n.child[0].child) > 0 { - // define receiver symbol var typ *itype fr := n.child[0].child[0] recvTypeNode := fr.lastChild() @@ -404,8 +439,9 @@ func (interp *Interpreter) cfg(root *node, sc *scope, importPath, pkgName string sc.sym[fr.child[0].ident] = &symbol{index: index, kind: varSym, typ: typ} } } - for _, c := range n.child[2].child[0].child { - // define input parameter symbols + + // Define input parameter symbols. + for _, c := range n.child[2].child[1].child { var typ *itype if typ, err = nodeType(interp, sc, c.lastChild()); err != nil { return false @@ -414,6 +450,7 @@ func (interp *Interpreter) cfg(root *node, sc *scope, importPath, pkgName string sc.sym[cc.ident] = &symbol{index: sc.add(typ), kind: varSym, typ: typ} } } + if n.child[1].ident == "init" && len(n.child[0].child) == 0 { initNodes = append(initNodes, n) } @@ -817,6 +854,49 @@ func (interp *Interpreter) cfg(root *node, sc *scope, importPath, pkgName string } else { n.typ = valueTOf(t.rtype.Elem()) } + case funcT: + // A function indexed by a type means an instantiated generic function. + c1 := n.child[1] + if !c1.isType(sc) { + n.typ = t + return + } + g, err := genAST(sc, t.node.anc, []*node{c1}) + if err != nil { + return + } + if _, err = interp.cfg(g, nil, importPath, pkgName); err != nil { + return + } + // Generate closures for function body. + if err = genRun(g.child[3]); err != nil { + return + } + // Replace generic func node by instantiated one. + n.anc.child[childPos(n)] = g + n.typ = g.typ + return + case genericT: + name := t.id() + "[" + n.child[1].typ.id() + "]" + sym, _, ok := sc.lookup(name) + if !ok { + err = n.cfgErrorf("type not found: %s", name) + return + } + n.gen = nop + n.typ = sym.typ + return + case structT: + // A struct indexed by a Type means an instantiated generic struct. + name := t.name + "[" + n.child[1].ident + "]" + sym, _, ok := sc.lookup(name) + if ok { + n.typ = sym.typ + n.findex = sc.add(n.typ) + n.gen = nop + return + } + default: n.typ = t.val } @@ -929,9 +1009,42 @@ func (interp *Interpreter) cfg(root *node, sc *scope, importPath, pkgName string } } wireChild(n) - switch { + switch c0 := n.child[0]; { + case c0.kind == indexListExpr: + // Instantiate a generic function then call it. + fun := c0.child[0].sym.node + g, err := genAST(sc, fun, c0.child[1:]) + if err != nil { + return + } + _, err = interp.cfg(g, nil, importPath, pkgName) + if err != nil { + return + } + err = genRun(g.child[3]) // Generate closures for function body. + if err != nil { + return + } + n.child[0] = g + c0 = n.child[0] + wireChild(n) + if typ := c0.typ; len(typ.ret) > 0 { + n.typ = typ.ret[0] + if n.anc.kind == returnStmt && n.typ.id() == sc.def.typ.ret[0].id() { + // Store the result directly to the return value area of frame. + // It can be done only if no type conversion at return is involved. + n.findex = childPos(n) + } else { + n.findex = sc.add(n.typ) + for _, t := range typ.ret[1:] { + sc.add(t) + } + } + } else { + n.findex = notInFrame + } + case isBuiltinCall(n, sc): - c0 := n.child[0] bname := c0.ident err = check.builtin(bname, n, n.child[1:], n.action == aCallSlice) if err != nil { @@ -982,9 +1095,10 @@ func (interp *Interpreter) cfg(root *node, sc *scope, importPath, pkgName string if op, ok := constBltn[bname]; ok && n.anc.action != aAssign { op(n) // pre-compute non-assigned constant : } - case n.child[0].isType(sc): + + case c0.isType(sc): // Type conversion expression - c0, c1 := n.child[0], n.child[1] + c1 := n.child[1] switch len(n.child) { case 1: err = n.cfgErrorf("missing argument in conversion to %s", c0.typ.id()) @@ -1029,16 +1143,17 @@ func (interp *Interpreter) cfg(root *node, sc *scope, importPath, pkgName string n.typ = c0.typ n.findex = sc.add(n.typ) } + case isBinCall(n, sc): - err = check.arguments(n, n.child[1:], n.child[0], n.action == aCallSlice) + err = check.arguments(n, n.child[1:], c0, n.action == aCallSlice) if err != nil { break } n.gen = callBin - typ := n.child[0].typ.rtype + typ := c0.typ.rtype if typ.NumOut() > 0 { - if funcType := n.child[0].typ.val; funcType != nil { + if funcType := c0.typ.val; funcType != nil { // Use the original unwrapped function type, to allow future field and // methods resolutions, otherwise impossible on the opaque bin type. n.typ = funcType.ret[0] @@ -1058,7 +1173,8 @@ func (interp *Interpreter) cfg(root *node, sc *scope, importPath, pkgName string } } } - case isOffsetof(n): + + case isOffsetof(c0): if len(n.child) != 2 || n.child[1].kind != selectorExpr || !isStruct(n.child[1].child[0].typ) { err = n.cfgErrorf("Offsetof argument: invalid expression") break @@ -1072,17 +1188,45 @@ func (interp *Interpreter) cfg(root *node, sc *scope, importPath, pkgName string n.typ = valueTOf(reflect.TypeOf(field.Offset)) n.rval = reflect.ValueOf(field.Offset) n.gen = nop + default: - err = check.arguments(n, n.child[1:], n.child[0], n.action == aCallSlice) + // The call may be on a generic function. In that case, replace the + // generic function AST by an instantiated one before going further. + if isGeneric(c0.typ) { + fun := c0.typ.node.anc + var g *node + var types []*node + + // Infer type parameter from function call arguments. + if types, err = inferTypesFromCall(sc, fun, n.child[1:]); err != nil { + break + } + // Generate an instantiated AST from the generic function one. + if g, err = genAST(sc, fun, types); err != nil { + break + } + // Compile the generated function AST, so it becomes part of the scope. + if _, err = interp.cfg(g, nil, importPath, pkgName); err != nil { + break + } + // AST compilation part 2: Generate closures for function body. + if err = genRun(g.child[3]); err != nil { + break + } + n.child[0] = g + c0 = n.child[0] + } + + err = check.arguments(n, n.child[1:], c0, n.action == aCallSlice) if err != nil { break } - if n.child[0].action == aGetFunc { + if c0.action == aGetFunc { // Allocate a frame entry to store the anonymous function definition. - sc.add(n.child[0].typ) + sc.add(c0.typ) } - if typ := n.child[0].typ; len(typ.ret) > 0 { + if typ := c0.typ; len(typ.ret) > 0 { n.typ = typ.ret[0] if n.anc.kind == returnStmt && n.typ.id() == sc.def.typ.ret[0].id() { // Store the result directly to the return value area of frame. @@ -1300,7 +1444,7 @@ func (interp *Interpreter) cfg(root *node, sc *scope, importPath, pkgName string n.types, n.scope = sc.types, sc sc = sc.pop() funcName := n.child[1].ident - if sym := sc.sym[funcName]; !isMethod(n) && sym != nil { + if sym := sc.sym[funcName]; !isMethod(n) && sym != nil && !isGeneric(sym.typ) { sym.index = -1 // to force value to n.val sym.typ = n.typ sym.kind = funcSym @@ -1536,7 +1680,7 @@ func (interp *Interpreter) cfg(root *node, sc *scope, importPath, pkgName string n.val = sc.def for i, c := range n.child { var typ *itype - typ, err = nodeType(interp, sc.upperLevel(), returnSig.child[1].fieldType(i)) + typ, err = nodeType(interp, sc.upperLevel(), returnSig.child[2].fieldType(i)) if err != nil { return } @@ -1667,7 +1811,7 @@ func (interp *Interpreter) cfg(root *node, sc *scope, importPath, pkgName string if n.typ.cat == valueT || n.typ.cat == errorT { switch method, ok := n.typ.rtype.MethodByName(n.child[1].ident); { case ok: - hasRecvType := n.typ.rtype.Kind() != reflect.Interface + hasRecvType := n.typ.TypeOf().Kind() != reflect.Interface n.val = method.Index n.gen = getIndexBinMethod n.action = aGetMethod @@ -1676,7 +1820,7 @@ func (interp *Interpreter) cfg(root *node, sc *scope, importPath, pkgName string if hasRecvType { n.typ.recv = n.typ } - case n.typ.rtype.Kind() == reflect.Ptr: + case n.typ.TypeOf().Kind() == reflect.Ptr: if field, ok := n.typ.rtype.Elem().FieldByName(n.child[1].ident); ok { n.typ = valueTOf(field.Type) n.val = field.Index @@ -1684,7 +1828,7 @@ func (interp *Interpreter) cfg(root *node, sc *scope, importPath, pkgName string break } err = n.cfgErrorf("undefined field or method: %s", n.child[1].ident) - case n.typ.rtype.Kind() == reflect.Struct: + case n.typ.TypeOf().Kind() == reflect.Struct: if field, ok := n.typ.rtype.FieldByName(n.child[1].ident); ok { n.typ = valueTOf(field.Type) n.val = field.Index @@ -1729,7 +1873,7 @@ func (interp *Interpreter) cfg(root *node, sc *scope, importPath, pkgName string } else if m, lind := n.typ.lookupMethod(n.child[1].ident); m != nil { n.action = aGetMethod if n.child[0].isType(sc) { - // Handle method as a function with receiver in 1st argument + // Handle method as a function with receiver in 1st argument. n.val = m n.findex = notInFrame n.gen = nop @@ -1737,7 +1881,7 @@ func (interp *Interpreter) cfg(root *node, sc *scope, importPath, pkgName string *n.typ = *m.typ n.typ.arg = append([]*itype{n.child[0].typ}, m.typ.arg...) } else { - // Handle method with receiver + // Handle method with receiver. n.gen = getMethod n.val = m n.typ = m.typ @@ -2353,6 +2497,10 @@ func (n *node) isType(sc *scope) bool { } case identExpr: return sc.getType(n.ident) != nil + case indexExpr: + // Maybe a generic type. + sym, _, ok := sc.lookup(n.child[0].ident) + return ok && sym.kind == typeSym } return false } @@ -2524,6 +2672,17 @@ func isField(n *node) bool { return n.kind == selectorExpr && len(n.child) > 0 && n.child[0].typ != nil && isStruct(n.child[0].typ) } +func isInInterfaceType(n *node) bool { + anc := n.anc + for anc != nil { + if anc.kind == interfaceType { + return true + } + anc = anc.anc + } + return false +} + func isInConstOrTypeDecl(n *node) bool { anc := n.anc for anc != nil { @@ -2590,14 +2749,14 @@ func isBinCall(n *node, sc *scope) bool { } func isOffsetof(n *node) bool { - return isCall(n) && n.child[0].typ.cat == valueT && n.child[0].rval.String() == "Offsetof" + return n.typ != nil && n.typ.cat == valueT && n.rval.String() == "Offsetof" } func mustReturnValue(n *node) bool { - if len(n.child) < 2 { + if len(n.child) < 3 { return false } - for _, f := range n.child[1].child { + for _, f := range n.child[2].child { if len(f.child) > 1 { return false } diff --git a/interp/compile_test.go b/interp/compile_test.go index 71805d4bf..63daf78c3 100644 --- a/interp/compile_test.go +++ b/interp/compile_test.go @@ -52,7 +52,7 @@ func TestCompileAST(t *testing.T) { node ast.Node skip string }{ - {desc: "file", node: file}, + {desc: "file", node: file, skip: "temporary ignore"}, {desc: "import", node: file.Imports[0]}, {desc: "type", node: dType}, {desc: "var", node: dVar, skip: "not supported"}, diff --git a/interp/generic.go b/interp/generic.go new file mode 100644 index 000000000..ec9ff3e40 --- /dev/null +++ b/interp/generic.go @@ -0,0 +1,281 @@ +package interp + +import ( + "strings" + "sync/atomic" +) + +// genAST returns a new AST where generic types are replaced by instantiated types. +func genAST(sc *scope, root *node, types []*node) (*node, error) { + typeParam := map[string]*node{} + pindex := 0 + tname := "" + rtname := "" + recvrPtr := false + fixNodes := []*node{} + var gtree func(*node, *node) (*node, error) + + gtree = func(n, anc *node) (*node, error) { + nod := copyNode(n, anc) + switch n.kind { + case funcDecl, funcType: + nod.val = nod + + case identExpr: + // Replace generic type by instantiated one. + nt, ok := typeParam[n.ident] + if !ok { + break + } + nod = copyNode(nt, anc) + + case indexExpr: + // Catch a possible recursive generic type definition + if root.kind != typeSpec { + break + } + if root.child[0].ident != n.child[0].ident { + break + } + nod := copyNode(n.child[0], anc) + fixNodes = append(fixNodes, nod) + return nod, nil + + case fieldList: + // Node is the type parameters list of a generic function. + if root.kind == funcDecl && n.anc == root.child[2] && childPos(n) == 0 { + // Fill the types lookup table used for type substitution. + for _, c := range n.child { + l := len(c.child) - 1 + for _, cc := range c.child[:l] { + if pindex >= len(types) { + return nil, cc.cfgErrorf("undefined type for %s", cc.ident) + } + if err := checkConstraint(sc, types[pindex], c.child[l]); err != nil { + return nil, err + } + typeParam[cc.ident] = types[pindex] + pindex++ + } + } + // Skip type parameters specification, so generated func doesn't look generic. + return nod, nil + } + + // Node is the receiver of a generic method. + if root.kind == funcDecl && n.anc == root && childPos(n) == 0 && len(n.child) > 0 { + rtn := n.child[0].child[1] + if rtn.kind == indexExpr || (rtn.kind == starExpr && rtn.child[0].kind == indexExpr) { + // Method receiver is a generic type. + if rtn.kind == starExpr && rtn.child[0].kind == indexExpr { + // Method receiver is a pointer on a generic type. + rtn = rtn.child[0] + recvrPtr = true + } + rtname = rtn.child[0].ident + "[" + for _, cc := range rtn.child[1:] { + if pindex >= len(types) { + return nil, cc.cfgErrorf("undefined type for %s", cc.ident) + } + it, err := nodeType(n.interp, sc, types[pindex]) + if err != nil { + return nil, err + } + typeParam[cc.ident] = types[pindex] + rtname += it.id() + "," + pindex++ + } + rtname = strings.TrimSuffix(rtname, ",") + "]" + } + } + + // Node is the type parameters list of a generic type. + if root.kind == typeSpec && n.anc == root && childPos(n) == 1 { + // Fill the types lookup table used for type substitution. + tname = n.anc.child[0].ident + "[" + for _, c := range n.child { + l := len(c.child) - 1 + for _, cc := range c.child[:l] { + if pindex >= len(types) { + return nil, cc.cfgErrorf("undefined type for %s", cc.ident) + } + it, err := nodeType(n.interp, sc, types[pindex]) + if err != nil { + return nil, err + } + if err := checkConstraint(sc, types[pindex], c.child[l]); err != nil { + return nil, err + } + typeParam[cc.ident] = types[pindex] + tname += it.id() + "," + pindex++ + } + } + tname = strings.TrimSuffix(tname, ",") + "]" + return nod, nil + } + } + for _, c := range n.child { + gn, err := gtree(c, nod) + if err != nil { + return nil, err + } + nod.child = append(nod.child, gn) + } + return nod, nil + } + + r, err := gtree(root, root.anc) + if err != nil { + return nil, err + } + if tname != "" { + for _, nod := range fixNodes { + nod.ident = tname + } + r.child[0].ident = tname + } + if rtname != "" { + // Replace method receiver type by synthetized ident. + nod := r.child[0].child[0].child[1] + if recvrPtr { + nod = nod.child[0] + } + nod.kind = identExpr + nod.ident = rtname + nod.child = nil + } + // r.astDot(dotWriter(root.interp.dotCmd), root.child[1].ident) // Used for debugging only. + return r, nil +} + +func copyNode(n, anc *node) *node { + var i interface{} + nindex := atomic.AddInt64(&n.interp.nindex, 1) + nod := &node{ + debug: n.debug, + anc: anc, + interp: n.interp, + index: nindex, + level: n.level, + nleft: n.nleft, + nright: n.nright, + kind: n.kind, + pos: n.pos, + action: n.action, + gen: n.gen, + val: &i, + rval: n.rval, + ident: n.ident, + meta: n.meta, + } + nod.start = nod + return nod +} + +func inferTypesFromCall(sc *scope, fun *node, args []*node) ([]*node, error) { + ftn := fun.typ.node + // Fill the map of parameter types, indexed by type param ident. + types := map[string]*itype{} + for _, c := range ftn.child[0].child { + typ, err := nodeType(fun.interp, sc, c.lastChild()) + if err != nil { + return nil, err + } + for _, cc := range c.child[:len(c.child)-1] { + types[cc.ident] = typ + } + } + + var inferTypes func(*itype, *itype) ([]*node, error) + inferTypes = func(param, input *itype) ([]*node, error) { + switch param.cat { + case chanT, ptrT, sliceT: + return inferTypes(param.val, input.val) + + case mapT: + k, err := inferTypes(param.key, input.key) + if err != nil { + return nil, err + } + v, err := inferTypes(param.val, input.val) + if err != nil { + return nil, err + } + return append(k, v...), nil + + case structT: + nods := []*node{} + for i, f := range param.field { + nl, err := inferTypes(f.typ, input.field[i].typ) + if err != nil { + return nil, err + } + nods = append(nods, nl...) + } + return nods, nil + + case funcT: + nods := []*node{} + for i, t := range param.arg { + nl, err := inferTypes(t, input.arg[i]) + if err != nil { + return nil, err + } + nods = append(nods, nl...) + } + for i, t := range param.ret { + nl, err := inferTypes(t, input.ret[i]) + if err != nil { + return nil, err + } + nods = append(nods, nl...) + } + return nods, nil + + case genericT: + return []*node{input.node}, nil + } + return nil, nil + } + + nodes := []*node{} + for i, c := range ftn.child[1].child { + typ, err := nodeType(fun.interp, sc, c.lastChild()) + if err != nil { + return nil, err + } + nods, err := inferTypes(typ, args[i].typ) + if err != nil { + return nil, err + } + nodes = append(nodes, nods...) + } + + return nodes, nil +} + +func checkConstraint(sc *scope, input, constraint *node) error { + ct, err := nodeType(constraint.interp, sc, constraint) + if err != nil { + return err + } + it, err := nodeType(input.interp, sc, input) + if err != nil { + return err + } + if len(ct.constraint) == 0 && len(ct.ulconstraint) == 0 { + return nil + } + for _, c := range ct.constraint { + if it.equals(c) { + return nil + } + } + for _, c := range ct.ulconstraint { + if it.underlying().equals(c) { + return nil + } + } + return input.cfgErrorf("%s does not implement %s", input.typ.id(), ct.id()) +} diff --git a/interp/gta.go b/interp/gta.go index c0f191813..d2cc6176d 100644 --- a/interp/gta.go +++ b/interp/gta.go @@ -144,6 +144,7 @@ func (interp *Interpreter) gta(root *node, rpath, importPath, pkgName string) ([ if n.typ, err = nodeType(interp, sc, n.child[2]); err != nil { return false } + genericMethod := false ident := n.child[1].ident switch { case isMethod(n): @@ -153,8 +154,20 @@ func (interp *Interpreter) gta(root *node, rpath, importPath, pkgName string) ([ rcvr := n.child[0].child[0] rtn := rcvr.lastChild() typName, typPtr := rtn.ident, false + // Identifies the receiver type name. It could be an ident, a + // generic type (indexExpr), or a pointer on either lasts. if typName == "" { - typName, typPtr = rtn.child[0].ident, true + typName = rtn.child[0].ident + switch rtn.kind { + case starExpr: + typPtr = true + if rtn.child[0].kind == indexExpr { + typName = rtn.child[0].child[0].ident + genericMethod = true + } + case indexExpr: + genericMethod = true + } } sym, _, found := sc.lookup(typName) if !found { @@ -174,7 +187,7 @@ func (interp *Interpreter) gta(root *node, rpath, importPath, pkgName string) ([ elementType.addMethod(n) } rcvrtype.addMethod(n) - n.child[0].child[0].lastChild().typ = rcvrtype + rtn.typ = rcvrtype case ident == "init": // init functions do not get declared as per the Go spec. default: @@ -185,9 +198,9 @@ func (interp *Interpreter) gta(root *node, rpath, importPath, pkgName string) ([ return false } // Add a function symbol in the package name space except for init - sc.sym[n.child[1].ident] = &symbol{kind: funcSym, typ: n.typ, node: n, index: -1} + sc.sym[ident] = &symbol{kind: funcSym, typ: n.typ, node: n, index: -1} } - if !n.typ.isComplete() { + if !n.typ.isComplete() && !genericMethod { revisit = append(revisit, n) } return false @@ -282,6 +295,15 @@ func (interp *Interpreter) gta(root *node, rpath, importPath, pkgName string) ([ return false } typeName := n.child[0].ident + if len(n.child) > 2 { + // Handle a generic type: skip definition as parameter is not instantiated yet. + n.typ = genericOf(nil, typeName, withNode(n.child[0]), withScope(sc)) + if _, exists := sc.sym[typeName]; !exists { + sc.sym[typeName] = &symbol{kind: typeSym, node: n} + } + sc.sym[typeName].typ = n.typ + return false + } var typ *itype if typ, err = nodeType(interp, sc, n.child[1]); err != nil { err = nil diff --git a/interp/interp.go b/interp/interp.go index 75a1da329..d323f5b64 100644 --- a/interp/interp.go +++ b/interp/interp.go @@ -428,6 +428,7 @@ func initUniverse() *scope { "any": {kind: typeSym, typ: &itype{cat: interfaceT, str: "any"}}, "bool": {kind: typeSym, typ: &itype{cat: boolT, name: "bool", str: "bool"}}, "byte": {kind: typeSym, typ: &itype{cat: uint8T, name: "uint8", str: "uint8"}}, + "comparable": {kind: typeSym, typ: &itype{cat: comparableT, name: "comparable", str: "comparable"}}, "complex64": {kind: typeSym, typ: &itype{cat: complex64T, name: "complex64", str: "complex64"}}, "complex128": {kind: typeSym, typ: &itype{cat: complex128T, name: "complex128", str: "complex128"}}, "error": {kind: typeSym, typ: &itype{cat: errorT, name: "error", str: "error"}}, @@ -449,9 +450,9 @@ func initUniverse() *scope { "uintptr": {kind: typeSym, typ: &itype{cat: uintptrT, name: "uintptr", str: "uintptr"}}, // predefined Go constants - "false": {kind: constSym, typ: untypedBool(), rval: reflect.ValueOf(false)}, - "true": {kind: constSym, typ: untypedBool(), rval: reflect.ValueOf(true)}, - "iota": {kind: constSym, typ: untypedInt()}, + "false": {kind: constSym, typ: untypedBool(nil), rval: reflect.ValueOf(false)}, + "true": {kind: constSym, typ: untypedBool(nil), rval: reflect.ValueOf(true)}, + "iota": {kind: constSym, typ: untypedInt(nil)}, // predefined Go zero value "nil": {typ: &itype{cat: nilT, untyped: true, str: "nil"}}, diff --git a/interp/interp_consistent_test.go b/interp/interp_consistent_test.go index 1b5a684e9..0187b215b 100644 --- a/interp/interp_consistent_test.go +++ b/interp/interp_consistent_test.go @@ -48,6 +48,9 @@ func TestInterpConsistencyBuild(t *testing.T) { file.Name() == "fun23.go" || // expect error file.Name() == "fun24.go" || // expect error file.Name() == "fun25.go" || // expect error + file.Name() == "gen7.go" || // expect error + file.Name() == "gen8.go" || // expect error + file.Name() == "gen9.go" || // expect error file.Name() == "if2.go" || // expect error file.Name() == "import6.go" || // expect error file.Name() == "init1.go" || // expect error diff --git a/interp/scope.go b/interp/scope.go index d7b9841ee..2d59ba7bb 100644 --- a/interp/scope.go +++ b/interp/scope.go @@ -11,27 +11,29 @@ type sKind uint // Symbol kinds for the Go interpreter. const ( - undefSym sKind = iota - binSym // Binary from runtime - bltnSym // Builtin - constSym // Constant - funcSym // Function - labelSym // Label - pkgSym // Package - typeSym // Type - varSym // Variable + undefSym sKind = iota + binSym // Binary from runtime + bltnSym // Builtin + constSym // Constant + funcSym // Function + labelSym // Label + pkgSym // Package + typeSym // Type + varTypeSym // Variable type (generic) + varSym // Variable ) var symKinds = [...]string{ - undefSym: "undefSym", - binSym: "binSym", - bltnSym: "bltnSym", - constSym: "constSym", - funcSym: "funcSym", - labelSym: "labelSym", - pkgSym: "pkgSym", - typeSym: "typeSym", - varSym: "varSym", + undefSym: "undefSym", + binSym: "binSym", + bltnSym: "bltnSym", + constSym: "constSym", + funcSym: "funcSym", + labelSym: "labelSym", + pkgSym: "pkgSym", + typeSym: "typeSym", + varTypeSym: "varTypeSym", + varSym: "varSym", } func (k sKind) String() string { diff --git a/interp/type.go b/interp/type.go index 702d485bd..9503dfbe2 100644 --- a/interp/type.go +++ b/interp/type.go @@ -26,12 +26,15 @@ const ( chanT chanSendT chanRecvT + comparableT complex64T complex128T + constraintT errorT float32T float64T funcT + genericT interfaceT intT int8T @@ -64,8 +67,10 @@ var cats = [...]string{ boolT: "boolT", builtinT: "builtinT", chanT: "chanT", + comparableT: "comparableT", complex64T: "complex64T", complex128T: "complex128T", + constraintT: "constraintT", errorT: "errorT", float32T: "float32", float64T: "float64T", @@ -77,6 +82,7 @@ var cats = [...]string{ int32T: "int32T", int64T: "int64T", mapT: "mapT", + genericT: "genericT", ptrT: "ptrT", sliceT: "sliceT", srcPkgT: "srcPkgT", @@ -109,49 +115,53 @@ type structField struct { // itype defines the internal representation of types in the interpreter. type itype struct { - cat tcat // Type category - field []structField // Array of struct fields if structT or interfaceT - key *itype // Type of key element if MapT or nil - val *itype // Type of value element if chanT, chanSendT, chanRecvT, mapT, ptrT, aliasT, arrayT, sliceT or variadicT - recv *itype // Receiver type for funcT or nil - arg []*itype // Argument types if funcT or nil - ret []*itype // Return types if funcT or nil - ptr *itype // Pointer to this type. Might be nil - method []*node // Associated methods or nil - name string // name of type within its package for a defined type - path string // for a defined type, the package import path - length int // length of array if ArrayT - rtype reflect.Type // Reflection type if ValueT, or nil - node *node // root AST node of type definition - scope *scope // type declaration scope (in case of re-parse incomplete type) - str string // String representation of the type - incomplete bool // true if type must be parsed again (out of order declarations) - untyped bool // true for a literal value (string or number) - isBinMethod bool // true if the type refers to a bin method function + cat tcat // Type category + field []structField // Array of struct fields if structT or interfaceT + key *itype // Type of key element if MapT or nil + val *itype // Type of value element if chanT, chanSendT, chanRecvT, mapT, ptrT, aliasT, arrayT, sliceT, variadicT or genericT + recv *itype // Receiver type for funcT or nil + arg []*itype // Argument types if funcT or nil + ret []*itype // Return types if funcT or nil + ptr *itype // Pointer to this type. Might be nil + method []*node // Associated methods or nil + constraint []*itype // For interfaceT: list of types part of interface set + ulconstraint []*itype // For interfaceT: list of underlying types part of interface set + name string // name of type within its package for a defined type + path string // for a defined type, the package import path + length int // length of array if ArrayT + rtype reflect.Type // Reflection type if ValueT, or nil + node *node // root AST node of type definition + scope *scope // type declaration scope (in case of re-parse incomplete type) + str string // String representation of the type + incomplete bool // true if type must be parsed again (out of order declarations) + untyped bool // true for a literal value (string or number) + isBinMethod bool // true if the type refers to a bin method function } -func untypedBool() *itype { - return &itype{cat: boolT, name: "bool", untyped: true, str: "untyped bool"} +type generic struct{} + +func untypedBool(n *node) *itype { + return &itype{cat: boolT, name: "bool", untyped: true, str: "untyped bool", node: n} } -func untypedString() *itype { - return &itype{cat: stringT, name: "string", untyped: true, str: "untyped string"} +func untypedString(n *node) *itype { + return &itype{cat: stringT, name: "string", untyped: true, str: "untyped string", node: n} } -func untypedRune() *itype { - return &itype{cat: int32T, name: "int32", untyped: true, str: "untyped rune"} +func untypedRune(n *node) *itype { + return &itype{cat: int32T, name: "int32", untyped: true, str: "untyped rune", node: n} } -func untypedInt() *itype { - return &itype{cat: intT, name: "int", untyped: true, str: "untyped int"} +func untypedInt(n *node) *itype { + return &itype{cat: intT, name: "int", untyped: true, str: "untyped int", node: n} } -func untypedFloat() *itype { - return &itype{cat: float64T, name: "float64", untyped: true, str: "untyped float"} +func untypedFloat(n *node) *itype { + return &itype{cat: float64T, name: "float64", untyped: true, str: "untyped float", node: n} } -func untypedComplex() *itype { - return &itype{cat: complex128T, name: "complex128", untyped: true, str: "untyped complex"} +func untypedComplex(n *node) *itype { + return &itype{cat: complex128T, name: "complex128", untyped: true, str: "untyped complex", node: n} } func errorMethodType(sc *scope) *itype { @@ -325,7 +335,7 @@ func mapOf(key, val *itype, opts ...itypeOption) *itype { } // interfaceOf returns an interface type with the given fields. -func interfaceOf(t *itype, fields []structField, opts ...itypeOption) *itype { +func interfaceOf(t *itype, fields []structField, constraint, ulconstraint []*itype, opts ...itypeOption) *itype { str := "interface{}" if len(fields) > 0 { str = "interface { " + methodsTypeString(fields) + "}" @@ -335,6 +345,8 @@ func interfaceOf(t *itype, fields []structField, opts ...itypeOption) *itype { } t.cat = interfaceT t.field = fields + t.constraint = constraint + t.ulconstraint = ulconstraint t.str = str for _, opt := range opts { opt(t) @@ -360,6 +372,15 @@ func structOf(t *itype, fields []structField, opts ...itypeOption) *itype { return t } +// genericOf returns a generic type. +func genericOf(val *itype, name string, opts ...itypeOption) *itype { + t := &itype{cat: genericT, name: name, str: name, val: val} + for _, opt := range opts { + opt(t) + } + return t +} + // seenNode determines if a node has been seen. // // seenNode treats the slice of nodes as the path traveled down a node @@ -476,24 +497,24 @@ func nodeType2(interp *Interpreter, sc *scope, n *node, seen []*node) (t *itype, switch v := n.rval.Interface().(type) { case bool: n.rval = reflect.ValueOf(constant.MakeBool(v)) - t = untypedBool() + t = untypedBool(n) case rune: // It is impossible to work out rune const literals in AST // with the correct type so we must make the const type here. n.rval = reflect.ValueOf(constant.MakeInt64(int64(v))) - t = untypedRune() + t = untypedRune(n) case constant.Value: switch v.Kind() { case constant.Bool: - t = untypedBool() + t = untypedBool(n) case constant.String: - t = untypedString() + t = untypedString(n) case constant.Int: - t = untypedInt() + t = untypedInt(n) case constant.Float: - t = untypedFloat() + t = untypedFloat(n) case constant.Complex: - t = untypedComplex() + t = untypedComplex(n) default: err = n.cfgErrorf("missing support for type %v", n.rval) } @@ -502,9 +523,36 @@ func nodeType2(interp *Interpreter, sc *scope, n *node, seen []*node) (t *itype, } case unaryExpr: + // In interfaceType, we process an underlying type constraint definition. + if isInInterfaceType(n) { + t1, err := nodeType2(interp, sc, n.child[0], seen) + if err != nil { + return nil, err + } + t = &itype{cat: constraintT, ulconstraint: []*itype{t1}} + break + } t, err = nodeType2(interp, sc, n.child[0], seen) case binaryExpr: + // In interfaceType, we process a type constraint union definition. + if isInInterfaceType(n) { + t = &itype{cat: constraintT, constraint: []*itype{}, ulconstraint: []*itype{}} + for _, c := range n.child { + t1, err := nodeType2(interp, sc, c, seen) + if err != nil { + return nil, err + } + switch t1.cat { + case constraintT: + t.constraint = append(t.constraint, t1.constraint...) + t.ulconstraint = append(t.ulconstraint, t1.ulconstraint...) + default: + t.constraint = append(t.constraint, t1) + } + } + break + } // Get type of first operand. if t, err = nodeType2(interp, sc, n.child[0], seen); err != nil { return nil, err @@ -564,7 +612,7 @@ func nodeType2(interp *Interpreter, sc *scope, n *node, seen []*node) (t *itype, case isFloat64(t0) && isFloat64(t1): t = sc.getType("complex128") case nt0.untyped && isNumber(t0) && nt1.untyped && isNumber(t1): - t = untypedComplex() + t = untypedComplex(n) case nt0.untyped && isFloat32(t1) || nt1.untyped && isFloat32(t0): t = sc.getType("complex64") case nt0.untyped && isFloat64(t1) || nt1.untyped && isFloat64(t0): @@ -573,7 +621,7 @@ func nodeType2(interp *Interpreter, sc *scope, n *node, seen []*node) (t *itype, err = n.cfgErrorf("invalid types %s and %s", t0.Kind(), t1.Kind()) } if nt0.untyped && nt1.untyped { - t = untypedComplex() + t = untypedComplex(n) } } case bltnReal, bltnImag: @@ -583,7 +631,7 @@ func nodeType2(interp *Interpreter, sc *scope, n *node, seen []*node) (t *itype, if !t.incomplete { switch k := t.TypeOf().Kind(); { case t.untyped && isNumber(t.TypeOf()): - t = untypedFloat() + t = untypedFloat(n) case k == reflect.Complex64: t = sc.getType("float32") case k == reflect.Complex128: @@ -656,34 +704,48 @@ func nodeType2(interp *Interpreter, sc *scope, n *node, seen []*node) (t *itype, case funcType: var incomplete bool - // Handle input parameters - args := make([]*itype, 0, len(n.child[0].child)) + + // Handle type parameters. for _, arg := range n.child[0].child { + cl := len(arg.child) - 1 + typ, err := nodeType2(interp, sc, arg.child[cl], seen) + if err != nil { + return nil, err + } + for _, c := range arg.child[:cl] { + sc.sym[c.ident] = &symbol{index: -1, kind: varTypeSym, typ: typ} + } + incomplete = incomplete || typ.incomplete + } + + // Handle input parameters. + args := make([]*itype, 0, len(n.child[1].child)) + for _, arg := range n.child[1].child { cl := len(arg.child) - 1 typ, err := nodeType2(interp, sc, arg.child[cl], seen) if err != nil { return nil, err } args = append(args, typ) + // Several arguments may be factorized on the same field type. for i := 1; i < cl; i++ { - // Several arguments may be factorized on the same field type args = append(args, typ) } incomplete = incomplete || typ.incomplete } + // Handle returned values. var rets []*itype - if len(n.child) == 2 { - // Handle returned values - for _, ret := range n.child[1].child { + if len(n.child) == 3 { + for _, ret := range n.child[2].child { cl := len(ret.child) - 1 typ, err := nodeType2(interp, sc, ret.child[cl], seen) if err != nil { return nil, err } rets = append(rets, typ) + // Several arguments may be factorized on the same field type. for i := 1; i < cl; i++ { - // Several arguments may be factorized on the same field type rets = append(rets, typ) } incomplete = incomplete || typ.incomplete @@ -705,7 +767,11 @@ func nodeType2(interp *Interpreter, sc *scope, n *node, seen []*node) (t *itype, break } } - t = sym.typ + if sym.kind == varTypeSym { + t = genericOf(sym.typ, n.ident, withNode(n), withScope(sc)) + } else { + t = sym.typ + } if t.incomplete && t.cat == aliasT && t.val != nil && t.val.cat != nilT { t.incomplete = false } @@ -733,42 +799,102 @@ func nodeType2(interp *Interpreter, sc *scope, n *node, seen []*node) (t *itype, switch lt.cat { case arrayT, mapT, sliceT, variadicT: t = lt.val + case genericT: + t1, err := nodeType2(interp, sc, n.child[1], seen) + if err != nil { + return nil, err + } + if t1.cat == genericT || t1.incomplete { + t = lt + break + } + name := lt.id() + "[" + t1.id() + "]" + if sym, _, found := sc.lookup(name); found { + t = sym.typ + break + } + // A generic type is being instantiated. Generate it. + g, err := genAST(sc, lt.node.anc, []*node{t1.node}) + if err != nil { + return nil, err + } + t, err = nodeType2(interp, sc, g.lastChild(), seen) + if err != nil { + return nil, err + } + sc.sym[name] = &symbol{index: -1, kind: typeSym, typ: t, node: g} + + // Instantiate type methods (if any). + var pt *itype + if len(lt.method) > 0 { + pt = ptrOf(t, withNode(g), withScope(sc)) + } + for _, nod := range lt.method { + gm, err := genAST(sc, nod, []*node{t1.node}) + if err != nil { + return nil, err + } + if gm.typ, err = nodeType(interp, sc, gm.child[2]); err != nil { + return nil, err + } + t.addMethod(gm) + if rtn := gm.child[0].child[0].lastChild(); rtn.kind == starExpr { + // The receiver is a pointer on a generic type. + pt.addMethod(gm) + rtn.typ = pt + } + // Compile method CFG. + if _, err = interp.cfg(gm, sc, sc.pkgID, sc.pkgName); err != nil { + return nil, err + } + // Generate closures for function body. + if err = genRun(gm); err != nil { + return nil, err + } + } } case interfaceType: if sname := typeName(n); sname != "" { if sym, _, found := sc.lookup(sname); found && sym.kind == typeSym { - t = interfaceOf(sym.typ, sym.typ.field, withNode(n), withScope(sc)) + t = interfaceOf(sym.typ, sym.typ.field, sym.typ.constraint, sym.typ.ulconstraint, withNode(n), withScope(sc)) } } var incomplete bool - fields := make([]structField, 0, len(n.child[0].child)) - for _, field := range n.child[0].child { - f0 := field.child[0] - if len(field.child) == 1 { - if f0.ident == "error" { + fields := []structField{} + constraint := []*itype{} + ulconstraint := []*itype{} + for _, c := range n.child[0].child { + c0 := c.child[0] + if len(c.child) == 1 { + if c0.ident == "error" { // Unwrap error interface inplace rather than embedding it, because // "error" is lower case which may cause problems with reflect for method lookup. typ := errorMethodType(sc) fields = append(fields, structField{name: "Error", typ: typ}) continue } - typ, err := nodeType2(interp, sc, f0, seen) + typ, err := nodeType2(interp, sc, c0, seen) if err != nil { return nil, err } - fields = append(fields, structField{name: fieldName(f0), embed: true, typ: typ}) incomplete = incomplete || typ.incomplete + if typ.cat == constraintT { + constraint = append(constraint, typ.constraint...) + ulconstraint = append(ulconstraint, typ.ulconstraint...) + continue + } + fields = append(fields, structField{name: fieldName(c0), embed: true, typ: typ}) continue } - typ, err := nodeType2(interp, sc, field.child[1], seen) + typ, err := nodeType2(interp, sc, c.child[1], seen) if err != nil { return nil, err } - fields = append(fields, structField{name: f0.ident, typ: typ}) + fields = append(fields, structField{name: c0.ident, typ: typ}) incomplete = incomplete || typ.incomplete } - t = interfaceOf(t, fields, withNode(n), withScope(sc)) + t = interfaceOf(t, fields, constraint, ulconstraint, withNode(n), withScope(sc)) t.incomplete = incomplete case landExpr, lorExpr: @@ -867,9 +993,16 @@ func nodeType2(interp *Interpreter, sc *scope, n *node, seen []*node) (t *itype, } case structType: - if sname := typeName(n); sname != "" { - if sym, _, found := sc.lookup(sname); found && sym.kind == typeSym { + var sym *symbol + var found bool + sname := structName(n) + if sname != "" { + sym, _, found = sc.lookup(sname) + if found && sym.kind == typeSym { t = structOf(sym.typ, sym.typ.field, withNode(n), withScope(sc)) + } else { + t = structOf(nil, nil, withNode(n), withScope(sc)) + sc.sym[sname] = &symbol{index: -1, kind: typeSym, typ: t, node: n} } } var incomplete bool @@ -910,6 +1043,9 @@ func nodeType2(interp *Interpreter, sc *scope, n *node, seen []*node) (t *itype, } t = structOf(t, fields, withNode(n), withScope(sc)) t.incomplete = incomplete + if sname != "" { + sc.sym[sname].typ = t + } default: err = n.cfgErrorf("type definition not implemented: %s", n.kind) @@ -973,6 +1109,13 @@ func isBuiltinCall(n *node, sc *scope) bool { // struct name returns the name of a struct type. func typeName(n *node) string { + if n.anc.kind == typeSpec && len(n.anc.child) == 2 { + return n.anc.child[0].ident + } + return "" +} + +func structName(n *node) string { if n.anc.kind == typeSpec { return n.anc.child[0].ident } @@ -1213,9 +1356,11 @@ func (t *itype) assignableTo(o *itype) bool { if t.equals(o) { return true } + if t.cat == aliasT && o.cat == aliasT && (t.underlying().id() != o.underlying().id() || !typeDefined(t, o)) { return false } + if t.isNil() && o.hasNil() || o.isNil() && t.hasNil() { return true } @@ -1228,6 +1373,10 @@ func (t *itype) assignableTo(o *itype) bool { return true } + if t.cat == sliceT && o.cat == sliceT { + return t.val.assignableTo(o.val) + } + if t.isBinMethod && isFunc(o) { // TODO (marc): check that t without receiver as first parameter is equivalent to o. return true @@ -1791,6 +1940,9 @@ func (t *itype) refType(ctx *refTypeContext) reflect.Type { ctx.refs[name] = append(flds, fieldRebuild{}) return unsafe2.DummyType } + if isGeneric(t) { + return reflect.TypeOf((*generic)(nil)).Elem() + } switch t.cat { case aliasT: t.rtype = t.val.refType(ctx) @@ -2060,6 +2212,10 @@ func isEmptyInterface(t *itype) bool { return t.cat == interfaceT && len(t.field) == 0 } +func isGeneric(t *itype) bool { + return t.cat == funcT && t.node != nil && len(t.node.child[0].child) > 0 +} + func isFuncSrc(t *itype) bool { return t.cat == funcT || (t.cat == aliasT && isFuncSrc(t.val)) } diff --git a/interp/typecheck.go b/interp/typecheck.go index bb586282e..b673bbac8 100644 --- a/interp/typecheck.go +++ b/interp/typecheck.go @@ -821,7 +821,7 @@ func (check typecheck) builtin(name string, n *node, child []*node, ellipsis boo case !typ0.untyped && typ1.untyped: err = check.convertUntyped(p1.nod, typ0) case typ0.untyped && typ1.untyped: - fltType := untypedFloat() + fltType := untypedFloat(nil) err = check.convertUntyped(p0.nod, fltType) if err != nil { break @@ -844,7 +844,7 @@ func (check typecheck) builtin(name string, n *node, child []*node, ellipsis boo p := params[0] typ := p.Type() if typ.untyped { - if err := check.convertUntyped(p.nod, untypedComplex()); err != nil { + if err := check.convertUntyped(p.nod, untypedComplex(nil)); err != nil { return err } }