Skip to content

Commit

Permalink
go/ssa: substitute type parameterized aliases
Browse files Browse the repository at this point in the history
Adds support to substitute type parameterized aliases in
generic functions.

Change-Id: I4fb2e5f5fd9b626781efdc4db808c52cb22ba241
Reviewed-on: https://go-review.googlesource.com/c/tools/+/602195
Reviewed-by: Alan Donovan <[email protected]>
LUCI-TryBot-Result: Go LUCI <[email protected]>
  • Loading branch information
timothy-king committed Aug 1, 2024
1 parent f6a2390 commit 6a6fd99
Show file tree
Hide file tree
Showing 6 changed files with 273 additions and 26 deletions.
26 changes: 15 additions & 11 deletions go/ssa/builder_generic_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -550,7 +550,13 @@ func TestGenericBodies(t *testing.T) {
}

// Collect calls to the builtin print function.
probes := callsTo(p, "print")
fns := make(map[*ssa.Function]bool)
for _, mem := range p.Members {
if fn, ok := mem.(*ssa.Function); ok {
fns[fn] = true
}
}
probes := callsTo(fns, "print")
expectations := matchNotes(prog.Fset, notes, probes)

for call := range probes {
Expand All @@ -576,17 +582,15 @@ func TestGenericBodies(t *testing.T) {

// callsTo finds all calls to an SSA value named fname,
// and returns a map from each call site to its enclosing function.
func callsTo(p *ssa.Package, fname string) map[*ssa.CallCommon]*ssa.Function {
func callsTo(fns map[*ssa.Function]bool, fname string) map[*ssa.CallCommon]*ssa.Function {
callsites := make(map[*ssa.CallCommon]*ssa.Function)
for _, mem := range p.Members {
if fn, ok := mem.(*ssa.Function); ok {
for _, bb := range fn.Blocks {
for _, i := range bb.Instrs {
if i, ok := i.(ssa.CallInstruction); ok {
call := i.Common()
if call.Value.Name() == fname {
callsites[call] = fn
}
for fn := range fns {
for _, bb := range fn.Blocks {
for _, i := range bb.Instrs {
if i, ok := i.(ssa.CallInstruction); ok {
call := i.Common()
if call.Value.Name() == fname {
callsites[call] = fn
}
}
}
Expand Down
8 changes: 7 additions & 1 deletion go/ssa/builder_go122_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,13 @@ func TestRangeOverInt(t *testing.T) {
}

// Collect calls to the built-in print function.
probes := callsTo(p, "print")
fns := make(map[*ssa.Function]bool)
for _, mem := range p.Members {
if fn, ok := mem.(*ssa.Function); ok {
fns[fn] = true
}
}
probes := callsTo(fns, "print")
expectations := matchNotes(fset, notes, probes)

for call := range probes {
Expand Down
141 changes: 141 additions & 0 deletions go/ssa/builder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"go/token"
"go/types"
"os"
"os/exec"
"path/filepath"
"reflect"
"sort"
Expand Down Expand Up @@ -1260,3 +1261,143 @@ func TestIssue67079(t *testing.T) {

g.Wait() // ignore error
}

func TestGenericAliases(t *testing.T) {
testenv.NeedsGo1Point(t, 23)

if os.Getenv("GENERICALIASTEST_CHILD") == "1" {
testGenericAliases(t)
return
}

testenv.NeedsExec(t)
testenv.NeedsTool(t, "go")

cmd := exec.Command(os.Args[0], "-test.run=TestGenericAliases")
cmd.Env = append(os.Environ(),
"GENERICALIASTEST_CHILD=1",
"GODEBUG=gotypesalias=1",
"GOEXPERIMENT=aliastypeparams",
)
out, err := cmd.CombinedOutput()
if len(out) > 0 {
t.Logf("out=<<%s>>", out)
}
var exitcode int
if err, ok := err.(*exec.ExitError); ok {
exitcode = err.ExitCode()
}
const want = 0
if exitcode != want {
t.Errorf("exited %d, want %d", exitcode, want)
}
}

func testGenericAliases(t *testing.T) {
t.Setenv("GOEXPERIMENT", "aliastypeparams=1")

const source = `
package P
type A = uint8
type B[T any] = [4]T
var F = f[string]
func f[S any]() {
// Two copies of f are made: p.f[S] and p.f[string]
var v A // application of A that is declared outside of f without no type arguments
print("p.f", "String", "p.A", v)
print("p.f", "==", v, uint8(0))
print("p.f[string]", "String", "p.A", v)
print("p.f[string]", "==", v, uint8(0))
var u B[S] // application of B that is declared outside declared outside of f with type arguments
print("p.f", "String", "p.B[S]", u)
print("p.f", "==", u, [4]S{})
print("p.f[string]", "String", "p.B[string]", u)
print("p.f[string]", "==", u, [4]string{})
type C[T any] = struct{ s S; ap *B[T]} // declaration within f with type params
var w C[int] // application of C with type arguments
print("p.f", "String", "p.C[int]", w)
print("p.f", "==", w, struct{ s S; ap *[4]int}{})
print("p.f[string]", "String", "p.C[int]", w)
print("p.f[string]", "==", w, struct{ s string; ap *[4]int}{})
}
`

conf := loader.Config{Fset: token.NewFileSet()}
f, err := parser.ParseFile(conf.Fset, "p.go", source, 0)
if err != nil {
t.Fatal(err)
}
conf.CreateFromFiles("p", f)
iprog, err := conf.Load()
if err != nil {
t.Fatal(err)
}

// Create and build SSA program.
prog := ssautil.CreateProgram(iprog, ssa.InstantiateGenerics)
prog.Build()

probes := callsTo(ssautil.AllFunctions(prog), "print")
if got, want := len(probes), 3*4*2; got != want {
t.Errorf("Found %v probes, expected %v", got, want)
}

const debug = false // enable to debug skips
skipped := 0
for probe, fn := range probes {
// Each probe is of the form:
// print("within", "test", head, tail)
// The probe only matches within a function whose fn.String() is within.
// This allows for different instantiations of fn to match different probes.
// On a match, it applies the test named "test" to head::tail.
if len(probe.Args) < 3 {
t.Fatalf("probe %v did not have enough arguments", probe)
}
within, test, head, tail := constString(probe.Args[0]), probe.Args[1], probe.Args[2], probe.Args[3:]
if within != fn.String() {
skipped++
if debug {
t.Logf("Skipping %q within %q", within, fn.String())
}
continue // does not match function
}

switch test := constString(test); test {
case "==": // All of the values are types.Identical.
for _, v := range tail {
if !types.Identical(head.Type(), v.Type()) {
t.Errorf("Expected %v and %v to have identical types", head, v)
}
}
case "String": // head is a string constant that all values in tail must match Type().String()
want := constString(head)
for _, v := range tail {
if got := v.Type().String(); got != want {
t.Errorf("%s: %v had the Type().String()=%q. expected %q", within, v, got, want)
}
}
default:
t.Errorf("%q is not a test subcommand", test)
}
}
if want := 3 * 4; skipped != want {
t.Errorf("Skipped %d probes, expected to skip %d", skipped, want)
}
}

// constString returns the value of a string constant
// or "<not a constant string>" if the value is not a string constant.
func constString(v ssa.Value) string {
if c, ok := v.(*ssa.Const); ok {
str := c.Value.String()
return strings.Trim(str, `"`)
}
return "<not a constant string>"
}
83 changes: 74 additions & 9 deletions go/ssa/subst.go
Original file line number Diff line number Diff line change
Expand Up @@ -318,15 +318,80 @@ func (subst *subster) interface_(iface *types.Interface) *types.Interface {
}

func (subst *subster) alias(t *aliases.Alias) types.Type {
// TODO(go.dev/issues/46477): support TypeParameters once these are available from go/types.
u := aliases.Unalias(t)
if s := subst.typ(u); s != u {
// If there is any change, do not create a new alias.
return s
// See subster.named. This follows the same strategy.
tparams := aliases.TypeParams(t)
targs := aliases.TypeArgs(t)
tname := t.Obj()
torigin := aliases.Origin(t)

if !declaredWithin(tname, subst.origin) {
// t is declared outside of the function origin. So t is a package level type alias.
if targs.Len() == 0 {
// No type arguments so no instantiation needed.
return t
}

// Instantiate with the substituted type arguments.
newTArgs := subst.typelist(targs)
return subst.instantiate(torigin, newTArgs)
}
// If there is no change, t did not reach any type parameter.
// Keep the Alias.
return t

if targs.Len() == 0 {
// t is declared within the function origin and has no type arguments.
//
// Example: This corresponds to A or B in F, but not A[int]:
//
// func F[T any]() {
// type A[S any] = struct{t T, s S}
// type B = T
// var x A[int]
// ...
// }
//
// This is somewhat different than *Named as *Alias cannot be created recursively.

// Copy and substitute type params.
var newTParams []*types.TypeParam
for i := 0; i < tparams.Len(); i++ {
cur := tparams.At(i)
cobj := cur.Obj()
cname := types.NewTypeName(cobj.Pos(), cobj.Pkg(), cobj.Name(), nil)
ntp := types.NewTypeParam(cname, nil)
subst.cache[cur] = ntp // See the comment "Note: Subtle" in subster.named.
newTParams = append(newTParams, ntp)
}

// Substitute rhs.
rhs := subst.typ(aliases.Rhs(t))

// Create the fresh alias.
obj := aliases.NewAlias(true, tname.Pos(), tname.Pkg(), tname.Name(), rhs)
fresh := obj.Type()
if fresh, ok := fresh.(*aliases.Alias); ok {
// TODO: assume ok when aliases are always materialized (go1.27).
aliases.SetTypeParams(fresh, newTParams)
}

// Substitute into all of the constraints after they are created.
for i, ntp := range newTParams {
bound := tparams.At(i).Constraint()
ntp.SetConstraint(subst.typ(bound))
}
return fresh
}

// t is declared within the function origin and has type arguments.
//
// Example: This corresponds to A[int] in F. Cases A and B are handled above.
// func F[T any]() {
// type A[S any] = struct{t T, s S}
// type B = T
// var x A[int]
// ...
// }
subOrigin := subst.typ(torigin)
subTArgs := subst.typelist(targs)
return subst.instantiate(subOrigin, subTArgs)
}

func (subst *subster) named(t *types.Named) types.Type {
Expand Down Expand Up @@ -456,7 +521,7 @@ func (subst *subster) named(t *types.Named) types.Type {

func (subst *subster) instantiate(orig types.Type, targs []types.Type) types.Type {
i, err := types.Instantiate(subst.ctxt, orig, targs, false)
assert(err == nil, "failed to Instantiate Named type")
assert(err == nil, "failed to Instantiate named (Named or Alias) type")
if c, _ := subst.uniqueness.At(i).(types.Type); c != nil {
return c.(types.Type)
}
Expand Down
13 changes: 8 additions & 5 deletions internal/aliases/aliases_go121.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,14 @@ import (
// It will never be created by go/types.
type Alias struct{}

func (*Alias) String() string { panic("unreachable") }
func (*Alias) Underlying() types.Type { panic("unreachable") }
func (*Alias) Obj() *types.TypeName { panic("unreachable") }
func Rhs(alias *Alias) types.Type { panic("unreachable") }
func TypeParams(alias *Alias) *types.TypeParamList { panic("unreachable") }
func (*Alias) String() string { panic("unreachable") }
func (*Alias) Underlying() types.Type { panic("unreachable") }
func (*Alias) Obj() *types.TypeName { panic("unreachable") }
func Rhs(alias *Alias) types.Type { panic("unreachable") }
func TypeParams(alias *Alias) *types.TypeParamList { panic("unreachable") }
func SetTypeParams(alias *Alias, tparams []*types.TypeParam) { panic("unreachable") }
func TypeArgs(alias *Alias) *types.TypeList { panic("unreachable") }
func Origin(alias *Alias) *Alias { panic("unreachable") }

// Unalias returns the type t for go <=1.21.
func Unalias(t types.Type) types.Type { return t }
Expand Down
28 changes: 28 additions & 0 deletions internal/aliases/aliases_go122.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,34 @@ func TypeParams(alias *Alias) *types.TypeParamList {
return nil
}

// SetTypeParams sets the type parameters of the alias type.
func SetTypeParams(alias *Alias, tparams []*types.TypeParam) {
if alias, ok := any(alias).(interface {
SetTypeParams(tparams []*types.TypeParam)
}); ok {
alias.SetTypeParams(tparams) // go1.23+
} else if len(tparams) > 0 {
panic("cannot set type parameters of an Alias type in go1.22")
}
}

// TypeArgs returns the type arguments used to instantiate the Alias type.
func TypeArgs(alias *Alias) *types.TypeList {
if alias, ok := any(alias).(interface{ TypeArgs() *types.TypeList }); ok {
return alias.TypeArgs() // go1.23+
}
return nil // empty (go1.22)
}

// Origin returns the generic Alias type of which alias is an instance.
// If alias is not an instance of a generic alias, Origin returns alias.
func Origin(alias *Alias) *Alias {
if alias, ok := any(alias).(interface{ Origin() *types.Alias }); ok {
return alias.Origin() // go1.23+
}
return alias // not an instance of a generic alias (go1.22)
}

// Unalias is a wrapper of types.Unalias.
func Unalias(t types.Type) types.Type { return types.Unalias(t) }

Expand Down

0 comments on commit 6a6fd99

Please sign in to comment.