From 46bc274e027b115f6b27988c1d28d2e91afdfbcb Mon Sep 17 00:00:00 2001 From: Tim King Date: Thu, 14 Apr 2022 11:39:08 -0700 Subject: [PATCH] go/ssa: Update callee for wrapper function instantiation. Updates golang/go#48525 Change-Id: Iee30bee08f124118d22524e276762389c8358244 Reviewed-on: https://go-review.googlesource.com/c/tools/+/400374 Reviewed-by: Zvonimir Pavlinovic --- go/ssa/builder_test.go | 123 +++++++++++++++++++++++++++++++++++++++++ go/ssa/wrappers.go | 2 +- 2 files changed, 124 insertions(+), 1 deletion(-) diff --git a/go/ssa/builder_test.go b/go/ssa/builder_test.go index 84f0692838..3fd9a8ae2b 100644 --- a/go/ssa/builder_test.go +++ b/go/ssa/builder_test.go @@ -565,6 +565,129 @@ func LoadPointer(addr *unsafe.Pointer) (val unsafe.Pointer) } } +func TestGenericWrappers(t *testing.T) { + if !typeparams.Enabled { + t.Skip("TestGenericWrappers only works with type parameters enabled.") + } + const input = ` +package p + +type S[T any] struct { + t *T +} + +func (x S[T]) M() T { + return *(x.t) +} + +var thunk = S[int].M + +var g S[int] +var bound = g.M + +type R[T any] struct{ S[T] } + +var indirect = R[int].M +` + // The relevant SSA members for this package should look something like this: + // var bound func() int + // var thunk func(S[int]) int + // var wrapper func(R[int]) int + + // Parse + var conf loader.Config + f, err := conf.ParseFile("", input) + if err != nil { + t.Fatalf("parse: %v", err) + } + conf.CreateFromFiles("p", f) + + // Load + lprog, err := conf.Load() + if err != nil { + t.Fatalf("Load: %v", err) + } + + // Create and build SSA + prog := ssautil.CreateProgram(lprog, 0) + p := prog.Package(lprog.Package("p").Pkg) + p.Build() + + for _, entry := range []struct { + name string // name of the package variable + typ string // type of the package variable + wrapper string // wrapper function to which the package variable is set + callee string // callee within the wrapper function + }{ + { + "bound", + "*func() int", + "(p.S[int]).M$bound", + "(p.S[int]).M[[int]]", + }, + { + "thunk", + "*func(p.S[int]) int", + "(p.S[int]).M$thunk", + "(p.S[int]).M[[int]]", + }, + { + "indirect", + "*func(p.R[int]) int", + "(p.R[int]).M$thunk", + "(p.S[int]).M[[int]]", + }, + } { + entry := entry + t.Run(entry.name, func(t *testing.T) { + v := p.Var(entry.name) + if v == nil { + t.Fatalf("Did not find variable for %q in %s", entry.name, p.String()) + } + if v.Type().String() != entry.typ { + t.Errorf("Expected type for variable %s: %q. got %q", v, entry.typ, v.Type()) + } + + // Find the wrapper for v. This is stored exactly once in init. + var wrapper *ssa.Function + for _, bb := range p.Func("init").Blocks { + for _, i := range bb.Instrs { + if store, ok := i.(*ssa.Store); ok && v == store.Addr { + switch val := store.Val.(type) { + case *ssa.Function: + wrapper = val + case *ssa.MakeClosure: + wrapper = val.Fn.(*ssa.Function) + } + } + } + } + if wrapper == nil { + t.Fatalf("failed to find wrapper function for %s", entry.name) + } + if wrapper.String() != entry.wrapper { + t.Errorf("Expected wrapper function %q. got %q", wrapper, entry.wrapper) + } + + // Find the callee within the wrapper. There should be exactly one call. + var callee *ssa.Function + for _, bb := range wrapper.Blocks { + for _, i := range bb.Instrs { + if call, ok := i.(*ssa.Call); ok { + callee = call.Call.StaticCallee() + } + } + } + if callee == nil { + t.Fatalf("failed to find callee within wrapper %s", wrapper) + } + if callee.String() != entry.callee { + t.Errorf("Expected callee in wrapper %q is %q. got %q", v, entry.callee, callee) + } + }) + } +} + // TestTypeparamTest builds SSA over compilable examples in $GOROOT/test/typeparam/*.go. func TestTypeparamTest(t *testing.T) { diff --git a/go/ssa/wrappers.go b/go/ssa/wrappers.go index 799ba14ed2..deaa87f19e 100644 --- a/go/ssa/wrappers.go +++ b/go/ssa/wrappers.go @@ -126,7 +126,7 @@ func makeWrapper(prog *Program, sel *types.Selection, cr *creator) *Function { } callee := prog.originFunc(obj) if len(callee._TypeParams) > 0 { - prog.instances[callee].lookupOrCreate(receiverTypeArgs(obj), cr) + callee = prog.instances[callee].lookupOrCreate(receiverTypeArgs(obj), cr) } c.Call.Value = callee c.Call.Args = append(c.Call.Args, v)