diff --git a/go/callgraph/vta/graph.go b/go/callgraph/vta/graph.go index 6c9e6a576d..f846418f1c 100644 --- a/go/callgraph/vta/graph.go +++ b/go/callgraph/vta/graph.go @@ -190,6 +190,25 @@ func (l nestedPtrInterface) String() string { return fmt.Sprintf("PtrInterface(%v)", l.typ) } +// nestedPtrFunction node represents all references and dereferences of locals +// and globals that have a nested pointer to function type. We merge such +// constructs into a single node for simplicity and without much precision +// sacrifice as such variables are rare in practice. Both a and b would be +// represented as the same PtrFunction(func()) node in: +// var a *func() +// var b **func() +type nestedPtrFunction struct { + typ types.Type +} + +func (p nestedPtrFunction) Type() types.Type { + return p.typ +} + +func (p nestedPtrFunction) String() string { + return fmt.Sprintf("PtrFunction(%v)", p.typ) +} + // panicArg models types of all arguments passed to panic. type panicArg struct{} @@ -615,12 +634,16 @@ func (b *builder) addInFlowEdge(s, d node) { // Creates const, pointer, global, func, and local nodes based on register instructions. func (b *builder) nodeFromVal(val ssa.Value) node { - if p, ok := val.Type().(*types.Pointer); ok && !isInterface(p.Elem()) { + if p, ok := val.Type().(*types.Pointer); ok && !isInterface(p.Elem()) && !isFunction(p.Elem()) { // Nested pointer to interfaces are modeled as a special // nestedPtrInterface node. if i := interfaceUnderPtr(p.Elem()); i != nil { return nestedPtrInterface{typ: i} } + // The same goes for nested function types. + if f := functionUnderPtr(p.Elem()); f != nil { + return nestedPtrFunction{typ: f} + } return pointer{typ: p} } @@ -665,6 +688,8 @@ func (b *builder) representative(n node) node { return channelElem{typ: t} case nestedPtrInterface: return nestedPtrInterface{typ: t} + case nestedPtrFunction: + return nestedPtrFunction{typ: t} case field: return field{StructType: canonicalize(i.StructType, &b.canon), index: i.index} case indexedLocal: diff --git a/go/callgraph/vta/graph_test.go b/go/callgraph/vta/graph_test.go index 61bb05ad6e..7ccfe492f8 100644 --- a/go/callgraph/vta/graph_test.go +++ b/go/callgraph/vta/graph_test.go @@ -43,6 +43,8 @@ func TestNodeInterface(t *testing.T) { pint := types.NewPointer(bint) i := types.NewInterface(nil, nil) + voidFunc := main.Signature.Underlying() + for _, test := range []struct { n node s string @@ -59,8 +61,9 @@ func TestNodeInterface(t *testing.T) { {global{val: gl}, "Global(gl)", gl.Type()}, {local{val: reg}, "Local(t0)", bint}, {indexedLocal{val: reg, typ: X, index: 0}, "Local(t0[0])", X}, - {function{f: main}, "Function(main)", main.Signature.Underlying()}, + {function{f: main}, "Function(main)", voidFunc}, {nestedPtrInterface{typ: i}, "PtrInterface(interface{})", i}, + {nestedPtrFunction{typ: voidFunc}, "PtrFunction(func())", voidFunc}, {panicArg{}, "Panic", nil}, {recoverReturn{}, "Recover", nil}, } { @@ -181,6 +184,7 @@ func TestVTAGraphConstruction(t *testing.T) { "testdata/maps.go", "testdata/ranges.go", "testdata/closures.go", + "testdata/function_alias.go", "testdata/static_calls.go", "testdata/dynamic_calls.go", "testdata/returns.go", diff --git a/go/callgraph/vta/testdata/callgraph_field_funcs.go b/go/callgraph/vta/testdata/callgraph_field_funcs.go new file mode 100644 index 0000000000..cf4c0f1d74 --- /dev/null +++ b/go/callgraph/vta/testdata/callgraph_field_funcs.go @@ -0,0 +1,67 @@ +// Copyright 2021 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// go:build ignore + +package testdata + +type WrappedFunc struct { + F func() complex64 +} + +func callWrappedFunc(f WrappedFunc) { + f.F() +} + +func foo() complex64 { + println("foo") + return -1 +} + +func Foo(b bool) { + callWrappedFunc(WrappedFunc{foo}) + x := func() {} + y := func() {} + var a *func() + if b { + a = &x + } else { + a = &y + } + (*a)() +} + +// Relevant SSA: +// func Foo(b bool): +// t0 = local WrappedFunc (complit) +// t1 = &t0.F [#0] +// *t1 = foo +// t2 = *t0 +// t3 = callWrappedFunc(t2) +// t4 = new func() (x) +// *t4 = Foo$1 +// t5 = new func() (y) +// *t5 = Foo$2 +// if b goto 1 else 3 +// 1: +// jump 2 +// 2: +// t6 = phi [1: t4, 3: t5] #a +// t7 = *t6 +// t8 = t7() +// return +// 3: +// jump 2 +// +// func callWrappedFunc(f WrappedFunc): +// t0 = local WrappedFunc (f) +// *t0 = f +// t1 = &t0.F [#0] +// t2 = *t1 +// t3 = t2() +// return + +// WANT: +// callWrappedFunc: t2() -> foo +// Foo: callWrappedFunc(t2) -> callWrappedFunc; t7() -> Foo$1, Foo$2 diff --git a/go/callgraph/vta/testdata/function_alias.go b/go/callgraph/vta/testdata/function_alias.go new file mode 100644 index 0000000000..b38e0e00d6 --- /dev/null +++ b/go/callgraph/vta/testdata/function_alias.go @@ -0,0 +1,74 @@ +// Copyright 2021 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// go:build ignore + +package testdata + +type Doer func() + +type A struct { + foo func() + do Doer +} + +func Baz(f func()) { + j := &f + k := &j + **k = func() {} + a := A{} + a.foo = **k + a.foo() + a.do = a.foo + a.do() +} + +// Relevant SSA: +// func Baz(f func()): +// t0 = new func() (f) +// *t0 = f +// t1 = new *func() (j) +// *t1 = t0 +// t2 = *t1 +// *t2 = Baz$1 +// t3 = local A (a) +// t4 = &t3.foo [#0] +// t5 = *t1 +// t6 = *t5 +// *t4 = t6 +// t7 = &t3.foo [#0] +// t8 = *t7 +// t9 = t8() +// t10 = &t3.do [#1] *Doer +// t11 = &t3.foo [#0] *func() +// t12 = *t11 func() +// t13 = changetype Doer <- func() (t12) Doer +// *t10 = t13 +// t14 = &t3.do [#1] *Doer +// t15 = *t14 Doer +// t16 = t15() () + +// Flow chain showing that Baz$1 reaches t8(): +// Baz$1 -> t2 <-> PtrFunction(func()) <-> t5 -> t6 -> t4 <-> Field(testdata.A:foo) <-> t7 -> t8 +// Flow chain showing that Baz$1 reaches t15(): +// Field(testdata.A:foo) <-> t11 -> t12 -> t13 -> t10 <-> Field(testdata.A:do) <-> t14 -> t15 + +// WANT: +// Local(f) -> Local(t0) +// Local(t0) -> PtrFunction(func()) +// Function(Baz$1) -> Local(t2) +// PtrFunction(func()) -> Local(t0), Local(t2), Local(t5) +// Local(t2) -> PtrFunction(func()) +// Local(t4) -> Field(testdata.A:foo) +// Local(t5) -> Local(t6), PtrFunction(func()) +// Local(t6) -> Local(t4) +// Local(t7) -> Field(testdata.A:foo), Local(t8) +// Field(testdata.A:foo) -> Local(t11), Local(t4), Local(t7) +// Local(t4) -> Field(testdata.A:foo) +// Field(testdata.A:do) -> Local(t10), Local(t14) +// Local(t10) -> Field(testdata.A:do) +// Local(t11) -> Field(testdata.A:foo), Local(t12) +// Local(t12) -> Local(t13) +// Local(t13) -> Local(t10) +// Local(t14) -> Field(testdata.A:do), Local(t15) diff --git a/go/callgraph/vta/utils.go b/go/callgraph/vta/utils.go index cabc93be61..9633b8680b 100644 --- a/go/callgraph/vta/utils.go +++ b/go/callgraph/vta/utils.go @@ -19,6 +19,9 @@ func isReferenceNode(n node) bool { if _, ok := n.(nestedPtrInterface); ok { return true } + if _, ok := n.(nestedPtrFunction); ok { + return true + } if _, ok := n.Type().(*types.Pointer); ok { return true @@ -33,7 +36,9 @@ func isReferenceNode(n node) bool { // 2) is a (nested) pointer to interface (needed for, say, // slice elements of nested pointers to interface type) // 3) is a function type (needed for higher-order type flow) -// 4) is a global Recover or Panic node +// 4) is a (nested) pointer to function (needed for, say, +// slice elements of nested pointers to function type) +// 5) is a global Recover or Panic node func hasInFlow(n node) bool { if _, ok := n.(panicArg); ok { return true @@ -44,15 +49,14 @@ func hasInFlow(n node) bool { t := n.Type() - if _, ok := t.Underlying().(*types.Signature); ok { - return true - } - if i := interfaceUnderPtr(t); i != nil { return true } + if f := functionUnderPtr(t); f != nil { + return true + } - return isInterface(t) + return isInterface(t) || isFunction(t) } // hasInitialTypes check if a node can have initial types. @@ -72,6 +76,11 @@ func isInterface(t types.Type) bool { return ok } +func isFunction(t types.Type) bool { + _, ok := t.Underlying().(*types.Signature) + return ok +} + // interfaceUnderPtr checks if type `t` is a potentially nested // pointer to interface and if yes, returns the interface type. // Otherwise, returns nil. @@ -88,6 +97,22 @@ func interfaceUnderPtr(t types.Type) types.Type { return interfaceUnderPtr(p.Elem()) } +// functionUnderPtr checks if type `t` is a potentially nested +// pointer to function type and if yes, returns the function type. +// Otherwise, returns nil. +func functionUnderPtr(t types.Type) types.Type { + p, ok := t.Underlying().(*types.Pointer) + if !ok { + return nil + } + + if isFunction(p.Elem()) { + return p.Elem() + } + + return functionUnderPtr(p.Elem()) +} + // sliceArrayElem returns the element type of type `t` that is // expected to be a (pointer to) array or slice, consistent with // the ssa.Index and ssa.IndexAddr instructions. Panics otherwise. diff --git a/go/callgraph/vta/vta_test.go b/go/callgraph/vta/vta_test.go index b0d2de7836..e5a9b41e0e 100644 --- a/go/callgraph/vta/vta_test.go +++ b/go/callgraph/vta/vta_test.go @@ -20,6 +20,7 @@ func TestVTACallGraph(t *testing.T) { "testdata/callgraph_pointers.go", "testdata/callgraph_collections.go", "testdata/callgraph_fields.go", + "testdata/callgraph_field_funcs.go", } { t.Run(file, func(t *testing.T) { prog, want, err := testProg(file)