From eb25de6e2a94ebbd4401db4c7695bbb8d0501159 Mon Sep 17 00:00:00 2001 From: Robert Findley Date: Fri, 23 Sep 2022 10:55:58 -0400 Subject: [PATCH] go/analysis/passes/loopclosure: only check statements after t.Parallel After experimenting with the new t.Run+t.Parallel check in the loopclosure analyzer, we discovered that users rely on the fact that statements before the call to t.Parallel are executed synchronously, for example by declaring test := test inside the function literal, but before the call to t.Parallel. To avoid such false positives, only consider statements occurring after the first call to t.Parallel. Change-Id: I88466ea3bfd318d42d734c320677fbe5e3f6cb00 Reviewed-on: https://go-review.googlesource.com/c/tools/+/433535 Run-TryBot: Robert Findley TryBot-Result: Gopher Robot Reviewed-by: Alan Donovan gopls-CI: kokoro --- go/analysis/passes/loopclosure/loopclosure.go | 98 +++++++++++-------- .../testdata/src/subtests/subtest.go | 26 ++++- 2 files changed, 81 insertions(+), 43 deletions(-) diff --git a/go/analysis/passes/loopclosure/loopclosure.go b/go/analysis/passes/loopclosure/loopclosure.go index 645e5895bb..35fe15c9a2 100644 --- a/go/analysis/passes/loopclosure/loopclosure.go +++ b/go/analysis/passes/loopclosure/loopclosure.go @@ -104,55 +104,64 @@ func run(pass *analysis.Pass) (interface{}, error) { // fighting against the test runner. lastStmt := len(body.List) - 1 for i, s := range body.List { - var fun ast.Expr // if non-nil, a function that escapes the loop iteration + var stmts []ast.Stmt // statements that must be checked for escaping references switch s := s.(type) { case *ast.GoStmt: if i == lastStmt { - fun = s.Call.Fun + stmts = litStmts(s.Call.Fun) } case *ast.DeferStmt: if i == lastStmt { - fun = s.Call.Fun + stmts = litStmts(s.Call.Fun) } case *ast.ExprStmt: // check for errgroup.Group.Go and testing.T.Run (with T.Parallel) if call, ok := s.X.(*ast.CallExpr); ok { if i == lastStmt { - fun = goInvoke(pass.TypesInfo, call) + stmts = litStmts(goInvoke(pass.TypesInfo, call)) } - if fun == nil && analysisinternal.LoopclosureParallelSubtests { - fun = parallelSubtest(pass.TypesInfo, call) + if stmts == nil && analysisinternal.LoopclosureParallelSubtests { + stmts = parallelSubtest(pass.TypesInfo, call) } } } - lit, ok := fun.(*ast.FuncLit) - if !ok { - continue - } - - ast.Inspect(lit.Body, func(n ast.Node) bool { - id, ok := n.(*ast.Ident) - if !ok { - return true - } - obj := pass.TypesInfo.Uses[id] - if obj == nil { - return true - } - for _, v := range vars { - if v == obj { - pass.ReportRangef(id, "loop variable %s captured by func literal", id.Name) + for _, stmt := range stmts { + ast.Inspect(stmt, func(n ast.Node) bool { + id, ok := n.(*ast.Ident) + if !ok { + return true } - } - return true - }) + obj := pass.TypesInfo.Uses[id] + if obj == nil { + return true + } + for _, v := range vars { + if v == obj { + pass.ReportRangef(id, "loop variable %s captured by func literal", id.Name) + } + } + return true + }) + } } }) return nil, nil } +// litStmts returns all statements from the function body of a function +// literal. +// +// If fun is not a function literal, it returns nil. +func litStmts(fun ast.Expr) []ast.Stmt { + lit, _ := fun.(*ast.FuncLit) + if lit == nil { + return nil + } + return lit.Body.List +} + // goInvoke returns a function expression that would be called asynchronously // (but not awaited) in another goroutine as a consequence of the call. // For example, given the g.Go call below, it returns the function literal expression. @@ -169,38 +178,45 @@ func goInvoke(info *types.Info, call *ast.CallExpr) ast.Expr { return call.Args[0] } -// parallelSubtest returns a function expression that would be called +// parallelSubtest returns statements that would would be executed // asynchronously via the go test runner, as t.Run has been invoked with a // function literal that calls t.Parallel. // -// import "testing" +// In practice, users rely on the fact that statements before the call to +// t.Parallel are synchronous. For example by declaring test := test inside the +// function literal, but before the call to t.Parallel. // -// func TestFoo(t *testing.T) { -// tests := []int{0, 1, 2} -// for i, t := range tests { -// t.Run("subtest", func(t *testing.T) { -// t.Parallel() -// println(i, t) -// }) -// } +// Therefore, we only flag references that occur after the call to t.Parallel: +// +// import "testing" +// +// func TestFoo(t *testing.T) { +// tests := []int{0, 1, 2} +// for i, test := range tests { +// t.Run("subtest", func(t *testing.T) { +// println(i, test) // OK +// t.Parallel() +// println(i, test) // Not OK +// }) // } -func parallelSubtest(info *types.Info, call *ast.CallExpr) ast.Expr { +// } +func parallelSubtest(info *types.Info, call *ast.CallExpr) []ast.Stmt { if !isMethodCall(info, call, "testing", "T", "Run") { return nil } - lit, ok := call.Args[1].(*ast.FuncLit) - if !ok { + lit, _ := call.Args[1].(*ast.FuncLit) + if lit == nil { return nil } - for _, stmt := range lit.Body.List { + for i, stmt := range lit.Body.List { exprStmt, ok := stmt.(*ast.ExprStmt) if !ok { continue } if isMethodCall(info, exprStmt.X, "testing", "T", "Parallel") { - return lit + return lit.Body.List[i+1:] } } diff --git a/go/analysis/passes/loopclosure/testdata/src/subtests/subtest.go b/go/analysis/passes/loopclosure/testdata/src/subtests/subtest.go index 4bcd495367..2a97244a1a 100644 --- a/go/analysis/passes/loopclosure/testdata/src/subtests/subtest.go +++ b/go/analysis/passes/loopclosure/testdata/src/subtests/subtest.go @@ -18,7 +18,7 @@ type T struct{} // Run should not match testing.T.Run. Note that the second argument is // intentionally a *testing.T, not a *T, so that we can check both // testing.T.Parallel inside a T.Run, and a T.Parallel inside a testing.T.Run. -func (t *T) Run(string, func(*testing.T)) { // The second argument here is testing.T +func (t *T) Run(string, func(*testing.T)) { } func (t *T) Parallel() {} @@ -38,11 +38,33 @@ func _(t *testing.T) { println(test) }) - // Check that the location of t.Parallel does not matter. + // Check that the location of t.Parallel matters. t.Run("", func(t *testing.T) { + println(i) + println(test) + t.Parallel() println(i) // want "loop variable i captured by func literal" println(test) // want "loop variable test captured by func literal" + }) + + // Check that shadowing the loop variables within the test literal is OK if + // it occurs before t.Parallel(). + t.Run("", func(t *testing.T) { + i := i + test := test t.Parallel() + println(i) + println(test) + }) + + // Check that shadowing the loop variables within the test literal is Not + // OK if it occurs after t.Parallel(). + t.Run("", func(t *testing.T) { + t.Parallel() + i := i // want "loop variable i captured by func literal" + test := test // want "loop variable test captured by func literal" + println(i) // OK + println(test) // OK }) // Check uses in nested blocks.