diff --git a/internal/lsp/source/completion/statements.go b/internal/lsp/source/completion/statements.go index 62d3cf0ed0..3280bb52c8 100644 --- a/internal/lsp/source/completion/statements.go +++ b/internal/lsp/source/completion/statements.go @@ -18,7 +18,7 @@ import ( // addStatementCandidates adds full statement completion candidates // appropriate for the current context. func (c *completer) addStatementCandidates() { - c.addErrCheckAndReturn() + c.addErrCheck() c.addAssignAppend() } @@ -162,27 +162,36 @@ func (c *completer) topCandidate() *CompletionItem { return bestItem } -// addErrCheckAndReturn offers a completion candidate of the form: +// addErrCheck offers a completion candidate of the form: // // if err != nil { // return nil, err // } // +// In the case of test functions, it offers a completion candidate of the form: +// +// if err != nil { +// t.Fatal(err) +// } +// // The position must be in a function that returns an error, and the // statement preceding the position must be an assignment where the -// final LHS object is an error. addErrCheckAndReturn will synthesize +// final LHS object is an error. addErrCheck will synthesize // zero values as necessary to make the return statement valid. -func (c *completer) addErrCheckAndReturn() { +func (c *completer) addErrCheck() { if len(c.path) < 2 || c.enclosingFunc == nil || !c.opts.placeholders { return } var ( - errorType = types.Universe.Lookup("error").Type() - result = c.enclosingFunc.sig.Results() + errorType = types.Universe.Lookup("error").Type() + result = c.enclosingFunc.sig.Results() + testVar = getTestVar(c.enclosingFunc, c.pkg) + isTest = testVar != "" + doesNotReturnErr = result.Len() == 0 || !types.Identical(result.At(result.Len()-1).Type(), errorType) ) - // Make sure our enclosing function returns an error. - if result.Len() == 0 || !types.Identical(result.At(result.Len()-1).Type(), errorType) { + // Make sure our enclosing function is a Test func or returns an error. + if !isTest && doesNotReturnErr { return } @@ -205,15 +214,17 @@ func (c *completer) addErrCheckAndReturn() { } var ( - // errText is e.g. "err" in "foo, err := bar()". - errText = source.FormatNode(c.snapshot.FileSet(), lastAssignee) + // errVar is e.g. "err" in "foo, err := bar()". + errVar = source.FormatNode(c.snapshot.FileSet(), lastAssignee) // Whether we need to include the "if" keyword in our candidate. needsIf = true ) - // "_" isn't a real object. - if errText == "_" { + // If the returned error from the previous statement is "_", it is not a real object. + // If we don't have an error, and the function signature takes a testing.TB that is either ignored + // or an "_", then we also can't call t.Fatal(err). + if errVar == "_" { return } @@ -240,7 +251,7 @@ func (c *completer) addErrCheckAndReturn() { // if er<> // Make sure they are typing the error's name. - if c.matcher.Score(errText) <= 0 { + if c.matcher.Score(errVar) <= 0 { return } @@ -277,20 +288,26 @@ func (c *completer) addErrCheckAndReturn() { if needsIf { snip.WriteText("if ") } - snip.WriteText(fmt.Sprintf("%s != nil {\n\treturn ", errText)) + snip.WriteText(fmt.Sprintf("%s != nil {\n\t", errVar)) - for i := 0; i < result.Len()-1; i++ { - snip.WriteText(formatZeroValue(result.At(i).Type(), c.qf)) - snip.WriteText(", ") + var label string + if isTest { + snip.WriteText(fmt.Sprintf("%s.Fatal(%s)", testVar, errVar)) + label = fmt.Sprintf("%[1]s != nil { %[2]s.Fatal(%[1]s) }", errVar, testVar) + } else { + snip.WriteText("return ") + for i := 0; i < result.Len()-1; i++ { + snip.WriteText(formatZeroValue(result.At(i).Type(), c.qf)) + snip.WriteText(", ") + } + snip.WritePlaceholder(func(b *snippet.Builder) { + b.WriteText(errVar) + }) + label = fmt.Sprintf("%[1]s != nil { return %[1]s }", errVar) } - snip.WritePlaceholder(func(b *snippet.Builder) { - b.WriteText(errText) - }) - snip.WriteText("\n}") - label := fmt.Sprintf("%[1]s != nil { return %[1]s }", errText) if needsIf { label = "if " + label } @@ -303,3 +320,41 @@ func (c *completer) addErrCheckAndReturn() { snippet: &snip, }) } + +// getTestVar checks the function signature's input parameters and returns +// the name of the first parameter that implements "testing.TB". For example, +// func someFunc(t *testing.T) returns the string "t", func someFunc(b *testing.B) +// returns "b" etc. An empty string indicates that the function signature +// does not take a testing.TB parameter or does so but is ignored such +// as func someFunc(*testing.T). +func getTestVar(enclosingFunc *funcInfo, pkg source.Package) string { + if enclosingFunc == nil || enclosingFunc.sig == nil { + return "" + } + + sig := enclosingFunc.sig + for i := 0; i < sig.Params().Len(); i++ { + param := sig.Params().At(i) + if param.Name() == "_" { + continue + } + testingPkg, err := pkg.GetImport("testing") + if err != nil { + continue + } + tbObj := testingPkg.GetTypes().Scope().Lookup("TB") + if tbObj == nil { + continue + } + iface, ok := tbObj.Type().Underlying().(*types.Interface) + if !ok { + continue + } + if !types.Implements(param.Type(), iface) { + continue + } + return param.Name() + } + + return "" +} diff --git a/internal/lsp/testdata/statements/if_err_check_test.go b/internal/lsp/testdata/statements/if_err_check_test.go new file mode 100644 index 0000000000..6de5878798 --- /dev/null +++ b/internal/lsp/testdata/statements/if_err_check_test.go @@ -0,0 +1,20 @@ +package statements + +import ( + "os" + "testing" +) + +func TestErr(t *testing.T) { + /* if err != nil { t.Fatal(err) } */ //@item(stmtOneIfErrTFatal, "if err != nil { t.Fatal(err) }", "", "") + + _, err := os.Open("foo") + //@snippet("", stmtOneIfErrTFatal, "", "if err != nil {\n\tt.Fatal(err)\n\\}") +} + +func BenchmarkErr(b *testing.B) { + /* if err != nil { b.Fatal(err) } */ //@item(stmtOneIfErrBFatal, "if err != nil { b.Fatal(err) }", "", "") + + _, err := os.Open("foo") + //@snippet("", stmtOneIfErrBFatal, "", "if err != nil {\n\tb.Fatal(err)\n\\}") +} diff --git a/internal/lsp/testdata/summary.txt.golden b/internal/lsp/testdata/summary.txt.golden index e0b6366efd..8151426bda 100644 --- a/internal/lsp/testdata/summary.txt.golden +++ b/internal/lsp/testdata/summary.txt.golden @@ -2,7 +2,7 @@ CallHierarchyCount = 2 CodeLensCount = 5 CompletionsCount = 258 -CompletionSnippetCount = 92 +CompletionSnippetCount = 94 UnimportedCompletionsCount = 5 DeepCompletionsCount = 5 FuzzyCompletionsCount = 8