diff --git a/go/callgraph/vta/testdata/src/callgraph_recursive_types.go b/go/callgraph/vta/testdata/src/callgraph_recursive_types.go new file mode 100644 index 0000000000..6c3fef6f7c --- /dev/null +++ b/go/callgraph/vta/testdata/src/callgraph_recursive_types.go @@ -0,0 +1,56 @@ +// Copyright 2021 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// go:build ignore + +package testdata + +type I interface { + Foo() I +} + +type A struct { + i int + a *A +} + +func (a *A) Foo() I { + return a +} + +type B **B + +type C *D +type D *C + +func Bar(a *A, b *B, c *C, d *D) { + Baz(a) + Baz(a.a) + + sink(*b) + sink(*c) + sink(*d) +} + +func Baz(i I) { + i.Foo() +} + +func sink(i interface{}) { + print(i) +} + +// Relevant SSA: +// func Baz(i I): +// t0 = invoke i.Foo() +// return +// +// func Bar(a *A, b *B): +// t0 = make I <- *A (a) +// t1 = Baz(t0) +// ... + +// WANT: +// Bar: Baz(t0) -> Baz; Baz(t4) -> Baz; sink(t10) -> sink; sink(t13) -> sink; sink(t7) -> sink +// Baz: invoke i.Foo() -> A.Foo diff --git a/go/callgraph/vta/utils.go b/go/callgraph/vta/utils.go index e7a97e2d84..0049955493 100644 --- a/go/callgraph/vta/utils.go +++ b/go/callgraph/vta/utils.go @@ -85,32 +85,52 @@ func isFunction(t types.Type) bool { // pointer to interface and if yes, returns the interface type. // Otherwise, returns nil. func interfaceUnderPtr(t types.Type) types.Type { - p, ok := t.Underlying().(*types.Pointer) - if !ok { - return nil - } + seen := make(map[types.Type]bool) + var visit func(types.Type) types.Type + visit = func(t types.Type) types.Type { + if seen[t] { + return nil + } + seen[t] = true - if isInterface(p.Elem()) { - return p.Elem() - } + p, ok := t.Underlying().(*types.Pointer) + if !ok { + return nil + } - return interfaceUnderPtr(p.Elem()) + if isInterface(p.Elem()) { + return p.Elem() + } + + return visit(p.Elem()) + } + return visit(t) } // functionUnderPtr checks if type `t` is a potentially nested // pointer to function type and if yes, returns the function type. // Otherwise, returns nil. func functionUnderPtr(t types.Type) types.Type { - p, ok := t.Underlying().(*types.Pointer) - if !ok { - return nil - } + seen := make(map[types.Type]bool) + var visit func(types.Type) types.Type + visit = func(t types.Type) types.Type { + if seen[t] { + return nil + } + seen[t] = true - if isFunction(p.Elem()) { - return p.Elem() - } + p, ok := t.Underlying().(*types.Pointer) + if !ok { + return nil + } - return functionUnderPtr(p.Elem()) + if isFunction(p.Elem()) { + return p.Elem() + } + + return visit(p.Elem()) + } + return visit(t) } // sliceArrayElem returns the element type of type `t` that is diff --git a/go/callgraph/vta/vta_test.go b/go/callgraph/vta/vta_test.go index 33ceaf9091..830390850a 100644 --- a/go/callgraph/vta/vta_test.go +++ b/go/callgraph/vta/vta_test.go @@ -24,6 +24,7 @@ func TestVTACallGraph(t *testing.T) { "testdata/src/callgraph_collections.go", "testdata/src/callgraph_fields.go", "testdata/src/callgraph_field_funcs.go", + "testdata/src/callgraph_recursive_types.go", } { t.Run(file, func(t *testing.T) { prog, want, err := testProg(file)