diff --git a/go/ssa/builder.go b/go/ssa/builder.go index 5c2a2efc36..b36775a4e3 100644 --- a/go/ssa/builder.go +++ b/go/ssa/builder.go @@ -840,16 +840,15 @@ func (b *builder) expr0(fn *Function, e ast.Expr, tv types.TypeAndValue) Value { panic("unexpected expression-relative selector") case *typeparams.IndexListExpr: - if ident, ok := e.X.(*ast.Ident); ok { - // IndexListExpr is an instantiation. It will be handled by the *Ident case. - return b.expr(fn, ident) + // f[X, Y] must be a generic function + if !instance(fn.info, e.X) { + panic("unexpected expression-could not match index list to instantiation") } + return b.expr(fn, e.X) // Handle instantiation within the *Ident or *SelectorExpr cases. + case *ast.IndexExpr: - if ident, ok := e.X.(*ast.Ident); ok { - if _, ok := typeparams.GetInstances(fn.info)[ident]; ok { - // If the IndexExpr is an instantiation, it will be handled by the *Ident case. - return b.expr(fn, ident) - } + if instance(fn.info, e.X) { + return b.expr(fn, e.X) // Handle instantiation within the *Ident or *SelectorExpr cases. } // not a generic instantiation. switch t := fn.typeOf(e.X).Underlying().(type) { diff --git a/go/ssa/builder_test.go b/go/ssa/builder_test.go index 6b9c798037..ba9aaf768f 100644 --- a/go/ssa/builder_test.go +++ b/go/ssa/builder_test.go @@ -20,6 +20,7 @@ import ( "strings" "testing" + "golang.org/x/tools/go/buildutil" "golang.org/x/tools/go/loader" "golang.org/x/tools/go/ssa" "golang.org/x/tools/go/ssa/ssautil" @@ -823,3 +824,56 @@ func sliceMax(s []int) []int { return s[a():b():c()] } }) } } + +// TestGenericFunctionSelector ensures generic functions from other packages can be selected. +func TestGenericFunctionSelector(t *testing.T) { + if !typeparams.Enabled { + t.Skip("TestGenericFunctionSelector uses type parameters.") + } + + pkgs := map[string]map[string]string{ + "main": {"m.go": `package main; import "a"; func main() { a.F[int](); a.G[int,string](); a.H(0) }`}, + "a": {"a.go": `package a; func F[T any](){}; func G[S, T any](){}; func H[T any](a T){} `}, + } + + for _, mode := range []ssa.BuilderMode{ + ssa.SanityCheckFunctions, + ssa.SanityCheckFunctions | ssa.InstantiateGenerics, + } { + conf := loader.Config{ + Build: buildutil.FakeContext(pkgs), + } + conf.Import("main") + + lprog, err := conf.Load() + if err != nil { + t.Errorf("Load failed: %s", err) + } + if lprog == nil { + t.Fatalf("Load returned nil *Program") + } + // Create and build SSA + prog := ssautil.CreateProgram(lprog, mode) + p := prog.Package(lprog.Package("main").Pkg) + p.Build() + + var callees []string // callees of the CallInstruction.String() in main(). + for _, b := range p.Func("main").Blocks { + for _, i := range b.Instrs { + if call, ok := i.(ssa.CallInstruction); ok { + if callee := call.Common().StaticCallee(); call != nil { + callees = append(callees, callee.String()) + } else { + t.Errorf("CallInstruction without StaticCallee() %q", call) + } + } + } + } + sort.Strings(callees) // ignore the order in the code. + + want := "[a.F[[int]] a.G[[int string]] a.H[[int]]]" + if got := fmt.Sprint(callees); got != want { + t.Errorf("Expected main() to contain calls %v. got %v", want, got) + } + } +} diff --git a/go/ssa/util.go b/go/ssa/util.go index dfeaeebdbc..80c7d5cbec 100644 --- a/go/ssa/util.go +++ b/go/ssa/util.go @@ -175,6 +175,24 @@ func recvAsFirstArg(sig *types.Signature) *types.Signature { return typeparams.NewSignatureType(nil, nil, nil, types.NewTuple(params...), sig.Results(), sig.Variadic()) } +// instance returns whether an expression is a simple or qualified identifier +// that is a generic instantiation. +func instance(info *types.Info, expr ast.Expr) bool { + // Compare the logic here against go/types.instantiatedIdent, + // which also handles *IndexExpr and *IndexListExpr. + var id *ast.Ident + switch x := expr.(type) { + case *ast.Ident: + id = x + case *ast.SelectorExpr: + id = x.Sel + default: + return false + } + _, ok := typeparams.GetInstances(info)[id] + return ok +} + // instanceArgs returns the Instance[id].TypeArgs as a slice. func instanceArgs(info *types.Info, id *ast.Ident) []types.Type { targList := typeparams.GetInstances(info)[id].TypeArgs