Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

matchOverloadNamedTypeCast #342

Merged
merged 3 commits into from
Jan 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 40 additions & 21 deletions ast.go
Original file line number Diff line number Diff line change
Expand Up @@ -228,12 +228,10 @@ var (
)

func toInterface(pkg *Package, t *types.Interface) ast.Expr {
if enableTypeParams {
if t == universeAny.Type() {
return ast.NewIdent("any")
} else if interfaceIsImplicit(t) && t.NumEmbeddeds() == 1 {
return toType(pkg, t.EmbeddedType(0))
}
if t == universeAny.Type() {
return ast.NewIdent("any")
} else if interfaceIsImplicit(t) && t.NumEmbeddeds() == 1 {
return toType(pkg, t.EmbeddedType(0))
}
var flds []*ast.Field
for i, n := 0, t.NumEmbeddeds(); i < n; i++ {
Expand Down Expand Up @@ -593,20 +591,23 @@ func matchFuncCall(pkg *Package, fn *internal.Elem, args []*internal.Elem, flags
var cval constant.Value
retry:
switch t := fnType.(type) {
case *inferFuncType:
sig = t.InstanceWithArgs(args, flags)
if debugMatch {
log.Println("==> InferFunc", sig)
}
case *types.Signature:
if enableTypeParams && funcHasTypeParams(t) {
rt, err := inferFunc(pkg, fn, t, nil, args, flags)
if err != nil {
pkg.cb.panicCodeError(getSrcPos(fn.Src), err.Error())
}
sig = rt.(*types.Signature)
if debugMatch {
log.Println("==> InferFunc", sig)
if t.TypeParams() != nil {
if (flags & instrFlagGopxFunc) == 0 {
rt, err := inferFunc(pkg, fn, t, nil, args, flags)
if err != nil {
pkg.cb.panicCodeError(getSrcPos(fn.Src), err.Error())
}
sig = rt.(*types.Signature)
if debugMatch {
log.Println("==> InferFunc", sig)
}
} else {
fn, sig, args, err = boundTypeParams(pkg, fn, t, args)
if err != nil {
return
}
flags &= ^instrFlagGopxFunc
}
break
}
Expand Down Expand Up @@ -683,8 +684,6 @@ retry:
} else {
sig = t
}
case *TypeType: // type convert
return matchTypeCast(pkg, t.Type(), fn, args, flags)
case *TemplateSignature: // template function
sig, it = t.instantiate()
if t.isUnaryOp() {
Expand All @@ -696,9 +695,18 @@ retry:
}
case *TyInstruction:
return t.instr.Call(pkg, args, flags, fn.Src)
case *TypeType: // type convert
return matchTypeCast(pkg, t.Type(), fn, args, flags)
case *TyOverloadNamed:
return matchOverloadNamedTypeCast(pkg, t.Obj, fn.Src, args, flags)
case *types.Named:
fnType = pkg.cb.getUnderlying(t)
goto retry
case *inferFuncType:
sig = t.InstanceWithArgs(args, flags)
if debugMatch {
log.Println("==> InferFunc", sig)
}
default:
src, pos := pkg.cb.loadExpr(fn.Src)
pkg.cb.panicCodeErrorf(pos, "cannot call non-function %s (type %v)", src, fn.Type)
Expand Down Expand Up @@ -736,6 +744,17 @@ retry:
}, nil
}

func matchOverloadNamedTypeCast(pkg *Package, t *types.TypeName, src ast.Node, args []*internal.Elem, flags InstrFlags) (ret *internal.Elem, err error) {
cast := gopxPrefix + t.Name() + "_Cast"
o := t.Pkg().Scope().Lookup(cast)
if o == nil {
err := pkg.cb.newCodeErrorf(getSrcPos(src), "typecast %v not found", t.Type())
return nil, err
}
fn := toObject(pkg, o, src)
return matchFuncCall(pkg, fn, args, flags|instrFlagGopxFunc)
}

func matchTypeCast(pkg *Package, typ types.Type, fn *internal.Elem, args []*internal.Elem, flags InstrFlags) (ret *internal.Elem, err error) {
fnVal := fn.Val
switch typ.(type) {
Expand Down
11 changes: 11 additions & 0 deletions builtin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,17 @@ func getConf() *Config {
return &Config{Fset: fset, Importer: imp}
}

func TestMatchOverloadNamedTypeCast(t *testing.T) {
pkg := NewPackage("", "foo", nil)
foo := types.NewPackage("github.com/bar/foo", "foo")
tn := types.NewTypeName(0, foo, "t", nil)
types.NewNamed(tn, types.Typ[types.Int], nil)
_, err := matchOverloadNamedTypeCast(pkg, tn, nil, nil, 0)
if err == nil || err.Error() != "-: typecast github.com/bar/foo.t not found" {
t.Fatal("TestMatchOverloadNamedTypeCast:", err)
}
}

func TestSetTypeParams(t *testing.T) {
pkg := types.NewPackage("", "")
tn := types.NewTypeName(0, pkg, "foo__1", nil)
Expand Down
2 changes: 1 addition & 1 deletion codebuild.go
Original file line number Diff line number Diff line change
Expand Up @@ -1165,7 +1165,7 @@ func (p *CodeBuilder) Index(nidx int, twoValue bool, src ...ast.Node) *CodeBuild
log.Println("Index", nidx, twoValue)
}
args := p.stk.GetArgs(nidx + 1)
if enableTypeParams && nidx > 0 {
if nidx > 0 {
if _, ok := args[1].Type.(*TypeType); ok {
return p.instantiate(nidx, args, src...)
}
Expand Down
1 change: 1 addition & 0 deletions func.go
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,7 @@ const (

instrFlagApproxType // restricts to all types whose underlying type is T
instrFlagOpFunc // from callOpFunc
instrFlagGopxFunc // call Gopx_xxx functions
)

type Instruction interface {
Expand Down
1 change: 1 addition & 0 deletions import.go
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,7 @@ func checkTemplateMethod(pkg *types.Package, name string, o types.Object) {
const (
goptPrefix = "Gopt_" // template method
gopoPrefix = "Gopo_" // overload function/method
gopxPrefix = "Gopx_"
gopPackage = "GopPackage"
)

Expand Down
8 changes: 6 additions & 2 deletions type_ext.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,14 +57,18 @@ func isTypeType(t types.Type) bool {

type TyOverloadNamed struct {
Types []*types.Named
Obj *types.TypeName
}

func (p *TyOverloadNamed) typeEx() {}
func (p *TyOverloadNamed) Underlying() types.Type { return p }
func (p *TyOverloadNamed) String() string { return "TyOverloadNamed" }
func (p *TyOverloadNamed) String() string { return p.Obj.Name() }

func NewOverloadNamed(pos token.Pos, pkg *types.Package, name string, typs ...*types.Named) *types.TypeName {
return types.NewTypeName(pos, pkg, name, &TyOverloadNamed{Types: typs})
t := &TyOverloadNamed{Types: typs}
o := types.NewTypeName(pos, pkg, name, t)
t.Obj = o
return o
}

type TyInstruction struct {
Expand Down
34 changes: 28 additions & 6 deletions typeparams.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,6 @@ import (

// ----------------------------------------------------------------------------

const enableTypeParams = true

type TypeParam = types.TypeParam
type Union = types.Union
type Term = types.Term
Expand Down Expand Up @@ -354,10 +352,6 @@ func inferFuncTargs(pkg *Package, fn *internal.Elem, sig *types.Signature, targs
return types.Instantiate(pkg.cb.ctxt, sig, targs, true)
}

func funcHasTypeParams(t *types.Signature) bool {
return t.TypeParams() != nil
}

func toFieldListX(pkg *Package, t *types.TypeParamList) *ast.FieldList {
if t == nil {
return nil
Expand Down Expand Up @@ -443,6 +437,34 @@ func interfaceIsImplicit(t *types.Interface) bool {

// ----------------------------------------------------------------------------

func boundTypeParams(p *Package, fn *Element, sig *types.Signature, args []*Element) (*Element, *types.Signature, []*Element, error) {
params := sig.TypeParams()
if n := params.Len(); n > 0 {
targs := make([]types.Type, n)
for i, arg := range args {
t, ok := arg.Type.(*TypeType)
if !ok {
src, pos := p.cb.loadExpr(arg.Src)
err := p.cb.newCodeErrorf(pos, "%s (type %v) is not a type", src, arg.Type)
return fn, sig, args, err
}
targs[i] = t.typ
}
ret, err := types.Instantiate(p.cb.ctxt, sig, targs, true)
if err != nil {
return fn, sig, args, err
}
indices := make([]ast.Expr, n)
for i, arg := range args {
indices[i] = arg.Val
}
fn = &Element{Val: &ast.IndexListExpr{X: fn.Val, Indices: indices}, Type: ret, Src: fn.Src}
sig = ret.(*types.Signature)
args = args[n:]
}
return fn, sig, args, nil
}

func (p *Package) Instantiate(orig types.Type, targs []types.Type, src ...ast.Node) types.Type {
p.cb.ensureLoaded(orig)
for _, targ := range targs {
Expand Down
44 changes: 35 additions & 9 deletions typeparams_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
//go:build go1.18
// +build go1.18

/*
Copyright 2022 The GoPlus Authors (goplus.org)
Licensed under the Apache License, Version 2.0 (the "License");
Expand Down Expand Up @@ -44,6 +41,14 @@ type Var__0[T basetype] struct {
type Var__1[T map[string]any] struct {
val T
}

func Gopx_Var_Cast__0[T basetype]() *Var__0[T] {
return new(Var__0[T])
}

func Gopx_Var_Cast__1[T map[string]any]() *Var__1[T] {
return new(Var__1[T])
}
`
gt := newGoxTest()
_, err := gt.LoadGoPackage("foo", "foo.go", src)
Expand All @@ -69,22 +74,43 @@ type Var__1[T map[string]any] struct {
ty2 := pkg.Instantiate(on, []types.Type{tyM})
pkg.NewTypeDefs().NewType("t1").InitType(pkg, ty1)
pkg.NewTypeDefs().NewType("t2").InitType(pkg, ty2)
pkg.NewFunc(nil, "main", nil, nil, false).BodyStart(pkg).
Val(objVar).Typ(tyInt).Call(1).EndStmt().
Val(objVar).Typ(tyM).Call(1).EndStmt().
End()

domTest(t, pkg, `package main

import "foo"

type t1 foo.Var__0[int]
type t2 foo.Var__1[map[string]any]

func main() {
foo.Gopx_Var_Cast__0[int]()
foo.Gopx_Var_Cast__1[map[string]any]()
}
`)

defer func() {
if e := recover(); e == nil {
t.Fatal("TestOverloadNamed failed: no error?")
}
func() {
defer func() {
if e := recover(); e == nil {
t.Fatal("TestOverloadNamed failed: no error?")
}
}()
ty3 := pkg.Instantiate(on, []types.Type{gox.TyByte})
pkg.NewTypeDefs().NewType("t3").InitType(pkg, ty3)
}()
func() {
defer func() {
if e := recover(); e != nil && e.(error).Error() != "-: 1 (type untyped int) is not a type" {
t.Fatal("TestOverloadNamed failed:", e)
}
}()
pkg.NewFunc(nil, "bar", nil, nil, false).BodyStart(pkg).
Val(objVar).Val(1, source("1")).Call(1).EndStmt().
End()
}()
ty3 := pkg.Instantiate(on, []types.Type{gox.TyByte})
pkg.NewTypeDefs().NewType("t3").InitType(pkg, ty3)
}

func TestInstantiate(t *testing.T) {
Expand Down