diff --git a/internal/analysisinternal/analysis.go b/internal/analysisinternal/analysis.go index 9f4c68a185..eb75680fdb 100644 --- a/internal/analysisinternal/analysis.go +++ b/internal/analysisinternal/analysis.go @@ -50,7 +50,7 @@ func ZeroValue(fset *token.FileSet, f *ast.File, pkg *types.Package, typ types.T default: panic("unknown basic type") } - case *types.Chan, *types.Interface, *types.Map, *types.Pointer, *types.Signature, *types.Slice: + case *types.Chan, *types.Interface, *types.Map, *types.Pointer, *types.Signature, *types.Slice, *types.Array: return ast.NewIdent("nil") case *types.Struct: texpr := TypeExpr(fset, f, pkg, typ) // typ because we want the name here. @@ -60,21 +60,23 @@ func ZeroValue(fset *token.FileSet, f *ast.File, pkg *types.Package, typ types.T return &ast.CompositeLit{ Type: texpr, } - case *types.Array: - texpr := TypeExpr(fset, f, pkg, u.Elem()) - if texpr == nil { - return nil - } - return &ast.CompositeLit{ - Type: &ast.ArrayType{ - Elt: texpr, - Len: &ast.BasicLit{Kind: token.INT, Value: fmt.Sprintf("%v", u.Len())}, - }, - } } return nil } +// IsZeroValue checks whether the given expression is a 'zero value' (as determined by output of +// analysisinternal.ZeroValue) +func IsZeroValue(expr ast.Expr) bool { + switch e := expr.(type) { + case *ast.BasicLit: + return e.Value == "0" || e.Value == `""` + case *ast.Ident: + return e.Name == "nil" || e.Name == "false" + default: + return false + } +} + func TypeExpr(fset *token.FileSet, f *ast.File, pkg *types.Package, typ types.Type) ast.Expr { switch t := typ.(type) { case *types.Basic: diff --git a/internal/lsp/analysis/fillreturns/fillreturns.go b/internal/lsp/analysis/fillreturns/fillreturns.go index 4fd88772c9..8e93b71c27 100644 --- a/internal/lsp/analysis/fillreturns/fillreturns.go +++ b/internal/lsp/analysis/fillreturns/fillreturns.go @@ -143,7 +143,8 @@ outer: fixed[i] = match remaining = append(remaining[:idx], remaining[idx+1:]...) } else { - zv := analysisinternal.ZeroValue(pass.Fset, file, pass.Pkg, info.TypeOf(result.Type)) + zv := analysisinternal.ZeroValue(pass.Fset, file, pass.Pkg, + info.TypeOf(result.Type)) if zv == nil { return nil, nil } @@ -151,8 +152,15 @@ outer: } } + // Remove any non-matching "zero values" from the leftover values. + var nonZeroRemaining []ast.Expr + for _, expr := range remaining { + if !analysisinternal.IsZeroValue(expr) { + nonZeroRemaining = append(nonZeroRemaining, expr) + } + } // Append leftover return values to end of new return statement. - fixed = append(fixed, remaining...) + fixed = append(fixed, nonZeroRemaining...) newRet := &ast.ReturnStmt{ Return: ret.Pos(), @@ -200,14 +208,11 @@ func FixesError(msg string) bool { if len(matches) < 3 { return false } - wantNum, err := strconv.Atoi(matches[1]) - if err != nil { + if _, err := strconv.Atoi(matches[1]); err != nil { return false } - gotNum, err := strconv.Atoi(matches[2]) - if err != nil { + if _, err := strconv.Atoi(matches[2]); err != nil { return false } - // Logic for handling more return values than expected is hard. - return wantNum >= gotNum + return true } diff --git a/internal/lsp/analysis/fillreturns/testdata/src/a/a.go b/internal/lsp/analysis/fillreturns/testdata/src/a/a.go index f80bbf65c5..74c87e06f7 100644 --- a/internal/lsp/analysis/fillreturns/testdata/src/a/a.go +++ b/internal/lsp/analysis/fillreturns/testdata/src/a/a.go @@ -111,3 +111,12 @@ func localFuncMultipleReturn() (string, int, error, string) { func multipleUnused() (int, string, string, string) { return 3, 4, 5 // want "wrong number of return values \\(want 4, got 3\\)" } + +func gotTooMany() int { + if true { + return 0, "" // want "wrong number of return values \\(want 1, got 2\\)" + } else { + return 1, 0, nil // want "wrong number of return values \\(want 1, got 3\\)" + } + return 0, 5, false // want "wrong number of return values \\(want 1, got 3\\)" +} diff --git a/internal/lsp/analysis/fillreturns/testdata/src/a/a.go.golden b/internal/lsp/analysis/fillreturns/testdata/src/a/a.go.golden index 15a82b1be1..d4a847172d 100644 --- a/internal/lsp/analysis/fillreturns/testdata/src/a/a.go.golden +++ b/internal/lsp/analysis/fillreturns/testdata/src/a/a.go.golden @@ -64,7 +64,7 @@ func basic() (uint8, uint16, uint32, uint64, int8, int16, int32, int64, float32, } func complex() (*int, []int, [2]int, map[int]int) { - return nil, nil, [2]int{}, nil // want "wrong number of return values \\(want 4, got 0\\)" + return nil, nil, nil, nil // want "wrong number of return values \\(want 4, got 0\\)" } func structsAndInterfaces() (T, url.URL, T1, I, I1, io.Reader, Client, ast2.Stmt) { @@ -111,3 +111,12 @@ func localFuncMultipleReturn() (string, int, error, string) { func multipleUnused() (int, string, string, string) { return 3, "", "", "", 4, 5 // want "wrong number of return values \\(want 4, got 3\\)" } + +func gotTooMany() int { + if true { + return 0 // want "wrong number of return values \\(want 1, got 2\\)" + } else { + return 1 // want "wrong number of return values \\(want 1, got 3\\)" + } + return 0, 5 // want "wrong number of return values \\(want 1, got 3\\)" +}