diff --git a/ast.go b/ast.go index 8ff3b9d3..885025ab 100644 --- a/ast.go +++ b/ast.go @@ -228,12 +228,10 @@ var ( ) func toInterface(pkg *Package, t *types.Interface) ast.Expr { - if enableTypeParams { - if t == universeAny.Type() { - return ast.NewIdent("any") - } else if interfaceIsImplicit(t) && t.NumEmbeddeds() == 1 { - return toType(pkg, t.EmbeddedType(0)) - } + if t == universeAny.Type() { + return ast.NewIdent("any") + } else if interfaceIsImplicit(t) && t.NumEmbeddeds() == 1 { + return toType(pkg, t.EmbeddedType(0)) } var flds []*ast.Field for i, n := 0, t.NumEmbeddeds(); i < n; i++ { @@ -593,20 +591,23 @@ func matchFuncCall(pkg *Package, fn *internal.Elem, args []*internal.Elem, flags var cval constant.Value retry: switch t := fnType.(type) { - case *inferFuncType: - sig = t.InstanceWithArgs(args, flags) - if debugMatch { - log.Println("==> InferFunc", sig) - } case *types.Signature: - if enableTypeParams && funcHasTypeParams(t) { - rt, err := inferFunc(pkg, fn, t, nil, args, flags) - if err != nil { - pkg.cb.panicCodeError(getSrcPos(fn.Src), err.Error()) - } - sig = rt.(*types.Signature) - if debugMatch { - log.Println("==> InferFunc", sig) + if t.TypeParams() != nil { + if (flags & instrFlagGopxFunc) == 0 { + rt, err := inferFunc(pkg, fn, t, nil, args, flags) + if err != nil { + pkg.cb.panicCodeError(getSrcPos(fn.Src), err.Error()) + } + sig = rt.(*types.Signature) + if debugMatch { + log.Println("==> InferFunc", sig) + } + } else { + fn, sig, args, err = boundTypeParams(pkg, fn, t, args) + if err != nil { + return + } + flags &= ^instrFlagGopxFunc } break } @@ -683,8 +684,6 @@ retry: } else { sig = t } - case *TypeType: // type convert - return matchTypeCast(pkg, t.Type(), fn, args, flags) case *TemplateSignature: // template function sig, it = t.instantiate() if t.isUnaryOp() { @@ -696,9 +695,18 @@ retry: } case *TyInstruction: return t.instr.Call(pkg, args, flags, fn.Src) + case *TypeType: // type convert + return matchTypeCast(pkg, t.Type(), fn, args, flags) + case *TyOverloadNamed: + return matchOverloadNamedTypeCast(pkg, t.Obj, fn.Src, args, flags) case *types.Named: fnType = pkg.cb.getUnderlying(t) goto retry + case *inferFuncType: + sig = t.InstanceWithArgs(args, flags) + if debugMatch { + log.Println("==> InferFunc", sig) + } default: src, pos := pkg.cb.loadExpr(fn.Src) pkg.cb.panicCodeErrorf(pos, "cannot call non-function %s (type %v)", src, fn.Type) @@ -736,6 +744,17 @@ retry: }, nil } +func matchOverloadNamedTypeCast(pkg *Package, t *types.TypeName, src ast.Node, args []*internal.Elem, flags InstrFlags) (ret *internal.Elem, err error) { + cast := gopxPrefix + t.Name() + "_Cast" + o := t.Pkg().Scope().Lookup(cast) + if o == nil { + err := pkg.cb.newCodeErrorf(getSrcPos(src), "typecast %v not found", t.Type()) + return nil, err + } + fn := toObject(pkg, o, src) + return matchFuncCall(pkg, fn, args, flags|instrFlagGopxFunc) +} + func matchTypeCast(pkg *Package, typ types.Type, fn *internal.Elem, args []*internal.Elem, flags InstrFlags) (ret *internal.Elem, err error) { fnVal := fn.Val switch typ.(type) { diff --git a/builtin_test.go b/builtin_test.go index 5368bfd9..f7aa11cf 100644 --- a/builtin_test.go +++ b/builtin_test.go @@ -44,6 +44,17 @@ func getConf() *Config { return &Config{Fset: fset, Importer: imp} } +func TestMatchOverloadNamedTypeCast(t *testing.T) { + pkg := NewPackage("", "foo", nil) + foo := types.NewPackage("github.com/bar/foo", "foo") + tn := types.NewTypeName(0, foo, "t", nil) + types.NewNamed(tn, types.Typ[types.Int], nil) + _, err := matchOverloadNamedTypeCast(pkg, tn, nil, nil, 0) + if err == nil || err.Error() != "-: typecast github.com/bar/foo.t not found" { + t.Fatal("TestMatchOverloadNamedTypeCast:", err) + } +} + func TestSetTypeParams(t *testing.T) { pkg := types.NewPackage("", "") tn := types.NewTypeName(0, pkg, "foo__1", nil) diff --git a/codebuild.go b/codebuild.go index 3a5efc29..94c89484 100644 --- a/codebuild.go +++ b/codebuild.go @@ -1165,7 +1165,7 @@ func (p *CodeBuilder) Index(nidx int, twoValue bool, src ...ast.Node) *CodeBuild log.Println("Index", nidx, twoValue) } args := p.stk.GetArgs(nidx + 1) - if enableTypeParams && nidx > 0 { + if nidx > 0 { if _, ok := args[1].Type.(*TypeType); ok { return p.instantiate(nidx, args, src...) } diff --git a/func.go b/func.go index 0a89d345..25b782d1 100644 --- a/func.go +++ b/func.go @@ -247,6 +247,7 @@ const ( instrFlagApproxType // restricts to all types whose underlying type is T instrFlagOpFunc // from callOpFunc + instrFlagGopxFunc // call Gopx_xxx functions ) type Instruction interface { diff --git a/import.go b/import.go index 4d4931e7..e257e645 100644 --- a/import.go +++ b/import.go @@ -274,6 +274,7 @@ func checkTemplateMethod(pkg *types.Package, name string, o types.Object) { const ( goptPrefix = "Gopt_" // template method gopoPrefix = "Gopo_" // overload function/method + gopxPrefix = "Gopx_" gopPackage = "GopPackage" ) diff --git a/type_ext.go b/type_ext.go index 3f0c790c..806836a0 100644 --- a/type_ext.go +++ b/type_ext.go @@ -57,14 +57,18 @@ func isTypeType(t types.Type) bool { type TyOverloadNamed struct { Types []*types.Named + Obj *types.TypeName } func (p *TyOverloadNamed) typeEx() {} func (p *TyOverloadNamed) Underlying() types.Type { return p } -func (p *TyOverloadNamed) String() string { return "TyOverloadNamed" } +func (p *TyOverloadNamed) String() string { return p.Obj.Name() } func NewOverloadNamed(pos token.Pos, pkg *types.Package, name string, typs ...*types.Named) *types.TypeName { - return types.NewTypeName(pos, pkg, name, &TyOverloadNamed{Types: typs}) + t := &TyOverloadNamed{Types: typs} + o := types.NewTypeName(pos, pkg, name, t) + t.Obj = o + return o } type TyInstruction struct { diff --git a/typeparams.go b/typeparams.go index b08671f8..39a3fab2 100644 --- a/typeparams.go +++ b/typeparams.go @@ -28,8 +28,6 @@ import ( // ---------------------------------------------------------------------------- -const enableTypeParams = true - type TypeParam = types.TypeParam type Union = types.Union type Term = types.Term @@ -354,10 +352,6 @@ func inferFuncTargs(pkg *Package, fn *internal.Elem, sig *types.Signature, targs return types.Instantiate(pkg.cb.ctxt, sig, targs, true) } -func funcHasTypeParams(t *types.Signature) bool { - return t.TypeParams() != nil -} - func toFieldListX(pkg *Package, t *types.TypeParamList) *ast.FieldList { if t == nil { return nil @@ -443,6 +437,34 @@ func interfaceIsImplicit(t *types.Interface) bool { // ---------------------------------------------------------------------------- +func boundTypeParams(p *Package, fn *Element, sig *types.Signature, args []*Element) (*Element, *types.Signature, []*Element, error) { + params := sig.TypeParams() + if n := params.Len(); n > 0 { + targs := make([]types.Type, n) + for i, arg := range args { + t, ok := arg.Type.(*TypeType) + if !ok { + src, pos := p.cb.loadExpr(arg.Src) + err := p.cb.newCodeErrorf(pos, "%s (type %v) is not a type", src, arg.Type) + return fn, sig, args, err + } + targs[i] = t.typ + } + ret, err := types.Instantiate(p.cb.ctxt, sig, targs, true) + if err != nil { + return fn, sig, args, err + } + indices := make([]ast.Expr, n) + for i, arg := range args { + indices[i] = arg.Val + } + fn = &Element{Val: &ast.IndexListExpr{X: fn.Val, Indices: indices}, Type: ret, Src: fn.Src} + sig = ret.(*types.Signature) + args = args[n:] + } + return fn, sig, args, nil +} + func (p *Package) Instantiate(orig types.Type, targs []types.Type, src ...ast.Node) types.Type { p.cb.ensureLoaded(orig) for _, targ := range targs { diff --git a/typeparams_test.go b/typeparams_test.go index be2d30f0..226cb6d3 100644 --- a/typeparams_test.go +++ b/typeparams_test.go @@ -1,6 +1,3 @@ -//go:build go1.18 -// +build go1.18 - /* Copyright 2022 The GoPlus Authors (goplus.org) Licensed under the Apache License, Version 2.0 (the "License"); @@ -44,6 +41,14 @@ type Var__0[T basetype] struct { type Var__1[T map[string]any] struct { val T } + +func Gopx_Var_Cast__0[T basetype]() *Var__0[T] { + return new(Var__0[T]) +} + +func Gopx_Var_Cast__1[T map[string]any]() *Var__1[T] { + return new(Var__1[T]) +} ` gt := newGoxTest() _, err := gt.LoadGoPackage("foo", "foo.go", src) @@ -69,6 +74,10 @@ type Var__1[T map[string]any] struct { ty2 := pkg.Instantiate(on, []types.Type{tyM}) pkg.NewTypeDefs().NewType("t1").InitType(pkg, ty1) pkg.NewTypeDefs().NewType("t2").InitType(pkg, ty2) + pkg.NewFunc(nil, "main", nil, nil, false).BodyStart(pkg). + Val(objVar).Typ(tyInt).Call(1).EndStmt(). + Val(objVar).Typ(tyM).Call(1).EndStmt(). + End() domTest(t, pkg, `package main @@ -76,15 +85,32 @@ import "foo" type t1 foo.Var__0[int] type t2 foo.Var__1[map[string]any] + +func main() { + foo.Gopx_Var_Cast__0[int]() + foo.Gopx_Var_Cast__1[map[string]any]() +} `) - defer func() { - if e := recover(); e == nil { - t.Fatal("TestOverloadNamed failed: no error?") - } + func() { + defer func() { + if e := recover(); e == nil { + t.Fatal("TestOverloadNamed failed: no error?") + } + }() + ty3 := pkg.Instantiate(on, []types.Type{gox.TyByte}) + pkg.NewTypeDefs().NewType("t3").InitType(pkg, ty3) + }() + func() { + defer func() { + if e := recover(); e != nil && e.(error).Error() != "-: 1 (type untyped int) is not a type" { + t.Fatal("TestOverloadNamed failed:", e) + } + }() + pkg.NewFunc(nil, "bar", nil, nil, false).BodyStart(pkg). + Val(objVar).Val(1, source("1")).Call(1).EndStmt(). + End() }() - ty3 := pkg.Instantiate(on, []types.Type{gox.TyByte}) - pkg.NewTypeDefs().NewType("t3").InitType(pkg, ty3) } func TestInstantiate(t *testing.T) {