diff --git a/internal/lsp/source/folding_range.go b/internal/lsp/source/folding_range.go index 22da1eaf3e..96158e1700 100644 --- a/internal/lsp/source/folding_range.go +++ b/internal/lsp/source/folding_range.go @@ -16,9 +16,10 @@ type FoldingRangeInfo struct { } // FoldingRange gets all of the folding range for f. -func FoldingRange(ctx context.Context, view View, f GoFile, lineFoldingOnly bool) (ranges []FoldingRangeInfo, err error) { +func FoldingRange(ctx context.Context, view View, f GoFile, lineFoldingOnly bool) (ranges []*FoldingRangeInfo, err error) { // TODO(suzmue): consider limiting the number of folding ranges returned, and // implement a way to prioritize folding ranges in that case. + fset := f.FileSet() file, err := f.GetAST(ctx, ParseFull) if err != nil { return nil, err @@ -27,40 +28,15 @@ func FoldingRange(ctx context.Context, view View, f GoFile, lineFoldingOnly bool // Get folding ranges for comments separately as they are not walked by ast.Inspect. ranges = append(ranges, commentsFoldingRange(f.FileSet(), file)...) - visit := func(n ast.Node) bool { - var kind protocol.FoldingRangeKind - var start, end token.Pos - switch n := n.(type) { - case *ast.BlockStmt: - // Fold from position of "{" to position of "}". - start, end = n.Lbrace+1, n.Rbrace - case *ast.CaseClause: - // Fold from position of ":" to end. - start, end = n.Colon+1, n.End() - case *ast.CallExpr: - // Fold from position of "(" to position of ")". - start, end = n.Lparen+1, n.Rparen - case *ast.FieldList: - // Fold from position of opening parenthesis/brace, to position of - // closing parenthesis/brace. - start, end = n.Opening+1, n.Closing - case *ast.GenDecl: - // If this is an import declaration, set the kind to be protocol.Imports. - if n.Tok == token.IMPORT { - kind = protocol.Imports - } - // Fold from position of "(" to position of ")". - start, end = n.Lparen+1, n.Rparen - } + foldingFunc := foldingRange + if lineFoldingOnly { + foldingFunc = lineFoldingRange + } - if start.IsValid() && end.IsValid() { - if lineFoldingOnly && f.FileSet().Position(start).Line == f.FileSet().Position(end).Line { - return true - } - ranges = append(ranges, FoldingRangeInfo{ - Range: span.NewRange(f.FileSet(), start, end), - Kind: kind, - }) + visit := func(n ast.Node) bool { + rng := foldingFunc(fset, n) + if rng != nil { + ranges = append(ranges, rng) } return true } @@ -79,17 +55,135 @@ func FoldingRange(ctx context.Context, view View, f GoFile, lineFoldingOnly bool return ranges, nil } +// foldingRange calculates the folding range for n. +func foldingRange(fset *token.FileSet, n ast.Node) *FoldingRangeInfo { + var kind protocol.FoldingRangeKind + var start, end token.Pos + switch n := n.(type) { + case *ast.BlockStmt: + // Fold from position of "{" to position of "}". + start, end = n.Lbrace+1, n.Rbrace + case *ast.CaseClause: + // Fold from position of ":" to end. + start, end = n.Colon+1, n.End() + case *ast.CallExpr: + // Fold from position of "(" to position of ")". + start, end = n.Lparen+1, n.Rparen + case *ast.FieldList: + // Fold from position of opening parenthesis/brace, to position of + // closing parenthesis/brace. + start, end = n.Opening+1, n.Closing + case *ast.GenDecl: + // If this is an import declaration, set the kind to be protocol.Imports. + if n.Tok == token.IMPORT { + kind = protocol.Imports + } + start, end = n.Lparen+1, n.Rparen + } + + if !start.IsValid() || !end.IsValid() { + return nil + } + return &FoldingRangeInfo{ + Range: span.NewRange(fset, start, end), + Kind: kind, + } +} + +// lineFoldingRange calculates the line folding range for n. +func lineFoldingRange(fset *token.FileSet, n ast.Node) *FoldingRangeInfo { + // TODO(suzmue): include trailing empty lines before the closing + // parenthesis/brace. + var kind protocol.FoldingRangeKind + var start, end token.Pos + switch n := n.(type) { + case *ast.BlockStmt: + // Fold lines between "{" and "}". + if !n.Lbrace.IsValid() || !n.Rbrace.IsValid() { + break + } + nStmts := len(n.List) + if nStmts == 0 { + break + } + // Don't want to fold if the start is on the same line as the brace. + if fset.Position(n.Lbrace).Line == fset.Position(n.List[0].Pos()).Line { + break + } + // Don't want to fold if the end is on the same line as the brace. + if fset.Position(n.Rbrace).Line == fset.Position(n.List[nStmts-1].End()).Line { + break + } + start, end = n.Lbrace+1, n.List[nStmts-1].End() + case *ast.CaseClause: + // Fold from position of ":" to end. + start, end = n.Colon+1, n.End() + case *ast.FieldList: + // Fold lines between opening parenthesis/brace and closing parenthesis/brace. + if !n.Opening.IsValid() || !n.Closing.IsValid() { + break + } + nFields := len(n.List) + if nFields == 0 { + break + } + // Don't want to fold if the start is on the same line as the parenthesis/brace. + if fset.Position(n.Opening).Line == fset.Position(n.List[nFields-1].End()).Line { + break + } + // Don't want to fold if the end is on the same line as the parenthesis/brace. + if fset.Position(n.Closing).Line == fset.Position(n.List[nFields-1].End()).Line { + break + } + start, end = n.Opening+1, n.List[nFields-1].End() + case *ast.GenDecl: + // If this is an import declaration, set the kind to be protocol.Imports. + if n.Tok == token.IMPORT { + kind = protocol.Imports + } + // Fold from position of "(" to position of ")". + if !n.Lparen.IsValid() || !n.Rparen.IsValid() { + break + } + nSpecs := len(n.Specs) + if nSpecs == 0 { + break + } + // Don't want to fold if the end is on the same line as the parenthesis/brace. + if fset.Position(n.Lparen).Line == fset.Position(n.Specs[0].Pos()).Line { + break + } + // Don't want to fold if the end is on the same line as the parenthesis/brace. + if fset.Position(n.Rparen).Line == fset.Position(n.Specs[nSpecs-1].End()).Line { + break + } + start, end = n.Lparen+1, n.Specs[nSpecs-1].End() + } + + // Check that folding positions are valid. + if !start.IsValid() || !end.IsValid() { + return nil + } + // Do not fold if the start and end lines are the same. + if fset.Position(start).Line == fset.Position(end).Line { + return nil + } + return &FoldingRangeInfo{ + Range: span.NewRange(fset, start, end), + Kind: kind, + } +} + // commentsFoldingRange returns the folding ranges for all comment blocks in file. // The folding range starts at the end of the first comment, and ends at the end of the // comment block and has kind protocol.Comment. -func commentsFoldingRange(fset *token.FileSet, file *ast.File) []FoldingRangeInfo { - var comments []FoldingRangeInfo +func commentsFoldingRange(fset *token.FileSet, file *ast.File) (comments []*FoldingRangeInfo) { for _, commentGrp := range file.Comments { // Don't fold single comments. if len(commentGrp.List) <= 1 { continue } - comments = append(comments, FoldingRangeInfo{ + comments = append(comments, &FoldingRangeInfo{ // Fold from the end of the first line comment to the end of the comment block. Range: span.NewRange(fset, commentGrp.List[0].End(), commentGrp.End()), Kind: protocol.Comment, @@ -98,7 +192,7 @@ func commentsFoldingRange(fset *token.FileSet, file *ast.File) []FoldingRangeInf return comments } -func ToProtocolFoldingRanges(m *protocol.ColumnMapper, ranges []FoldingRangeInfo) ([]protocol.FoldingRange, error) { +func ToProtocolFoldingRanges(m *protocol.ColumnMapper, ranges []*FoldingRangeInfo) ([]protocol.FoldingRange, error) { var res []protocol.FoldingRange for _, r := range ranges { spn, err := r.Range.Span() diff --git a/internal/lsp/source/source_test.go b/internal/lsp/source/source_test.go index a80b919356..a85bcbce82 100644 --- a/internal/lsp/source/source_test.go +++ b/internal/lsp/source/source_test.go @@ -300,7 +300,7 @@ func (r *runner) FoldingRange(t *testing.T, data tests.FoldingRanges) { } } -func (r *runner) foldingRanges(t *testing.T, prefix string, uri span.URI, data string, ranges []source.FoldingRangeInfo) { +func (r *runner) foldingRanges(t *testing.T, prefix string, uri span.URI, data string, ranges []*source.FoldingRangeInfo) { t.Helper() // Fold all ranges. nonOverlapping := nonOverlappingRanges(ranges) @@ -323,7 +323,7 @@ func (r *runner) foldingRanges(t *testing.T, prefix string, uri span.URI, data s // Filter by kind. kinds := []protocol.FoldingRangeKind{protocol.Imports, protocol.Comment} for _, kind := range kinds { - var kindOnly []source.FoldingRangeInfo + var kindOnly []*source.FoldingRangeInfo for _, fRng := range ranges { if fRng.Kind == kind { kindOnly = append(kindOnly, fRng) @@ -350,7 +350,7 @@ func (r *runner) foldingRanges(t *testing.T, prefix string, uri span.URI, data s } } -func nonOverlappingRanges(ranges []source.FoldingRangeInfo) (res [][]source.FoldingRangeInfo) { +func nonOverlappingRanges(ranges []*source.FoldingRangeInfo) (res [][]*source.FoldingRangeInfo) { for _, fRng := range ranges { setNum := len(res) for i := 0; i < len(res); i++ { @@ -367,19 +367,19 @@ func nonOverlappingRanges(ranges []source.FoldingRangeInfo) (res [][]source.Fold } } if setNum == len(res) { - res = append(res, []source.FoldingRangeInfo{}) + res = append(res, []*source.FoldingRangeInfo{}) } res[setNum] = append(res[setNum], fRng) } return res } -func conflict(a, b source.FoldingRangeInfo) bool { +func conflict(a, b *source.FoldingRangeInfo) bool { // a start position is <= b start positions return a.Range.Start <= b.Range.Start && a.Range.End > b.Range.Start } -func foldRanges(contents string, ranges []source.FoldingRangeInfo) (string, error) { +func foldRanges(contents string, ranges []*source.FoldingRangeInfo) (string, error) { foldedText := "<>" res := contents // Apply the folds from the end of the file forward diff --git a/internal/lsp/testdata/folding/a.go b/internal/lsp/testdata/folding/a.go index e691ee00e0..388a3735a7 100644 --- a/internal/lsp/testdata/folding/a.go +++ b/internal/lsp/testdata/folding/a.go @@ -14,15 +14,17 @@ func bar() string { case true: if true { fmt.Println("true") + } else { + fmt.Println("false") } case false: fmt.Println("false") default: fmt.Println("default") } - + // This is a multiline comment + // that is not a doc comment. return ` this string is not indented` - } diff --git a/internal/lsp/testdata/folding/a.go.golden b/internal/lsp/testdata/folding/a.go.golden index 15ad786f44..15feb6d4a7 100644 --- a/internal/lsp/testdata/folding/a.go.golden +++ b/internal/lsp/testdata/folding/a.go.golden @@ -22,11 +22,10 @@ import _ "os" // With a multiline doc comment. func bar() string { switch {<>} - + // This is a multiline comment<> return ` this string is not indented` - } -- foldingRange-2 -- @@ -47,11 +46,11 @@ func bar() string { case false:<> default:<> } - + // This is a multiline comment + // that is not a doc comment. return ` this string is not indented` - } -- foldingRange-3 -- @@ -69,17 +68,17 @@ import _ "os" func bar() string { switch { case true: - if true {<>} + if true {<>} else {<>} case false: fmt.Println(<>) default: fmt.Println(<>) } - + // This is a multiline comment + // that is not a doc comment. return ` this string is not indented` - } -- foldingRange-4 -- @@ -99,17 +98,19 @@ func bar() string { case true: if true { fmt.Println(<>) + } else { + fmt.Println(<>) } case false: fmt.Println("false") default: fmt.Println("default") } - + // This is a multiline comment + // that is not a doc comment. return ` this string is not indented` - } -- foldingRange-comment-0 -- @@ -128,17 +129,18 @@ func bar() string { case true: if true { fmt.Println("true") + } else { + fmt.Println("false") } case false: fmt.Println("false") default: fmt.Println("default") } - + // This is a multiline comment<> return ` this string is not indented` - } -- foldingRange-imports-0 -- @@ -155,28 +157,32 @@ func bar() string { case true: if true { fmt.Println("true") + } else { + fmt.Println("false") } case false: fmt.Println("false") default: fmt.Println("default") } - + // This is a multiline comment + // that is not a doc comment. return ` this string is not indented` - } -- foldingRange-lineFolding-0 -- package folding //@fold("package") -import (<>) +import (<> +) import _ "os" // bar is a function.<> -func bar() string {<>} +func bar() string {<> +} -- foldingRange-lineFolding-1 -- package folding //@fold("package") @@ -191,12 +197,12 @@ import _ "os" // bar is a function. // With a multiline doc comment. func bar() string { - switch {<>} - + switch {<> + } + // This is a multiline comment<> return ` this string is not indented` - } -- foldingRange-lineFolding-2 -- @@ -217,11 +223,11 @@ func bar() string { case false:<> default:<> } - + // This is a multiline comment + // that is not a doc comment. return ` this string is not indented` - } -- foldingRange-lineFolding-3 -- @@ -239,17 +245,19 @@ import _ "os" func bar() string { switch { case true: - if true {<>} + if true {<> + } else {<> + } case false: fmt.Println("false") default: fmt.Println("default") } - + // This is a multiline comment + // that is not a doc comment. return ` this string is not indented` - } -- foldingRange-lineFolding-comment-0 -- @@ -268,23 +276,25 @@ func bar() string { case true: if true { fmt.Println("true") + } else { + fmt.Println("false") } case false: fmt.Println("false") default: fmt.Println("default") } - + // This is a multiline comment<> return ` this string is not indented` - } -- foldingRange-lineFolding-imports-0 -- package folding //@fold("package") -import (<>) +import (<> +) import _ "os" @@ -295,16 +305,18 @@ func bar() string { case true: if true { fmt.Println("true") + } else { + fmt.Println("false") } case false: fmt.Println("false") default: fmt.Println("default") } - + // This is a multiline comment + // that is not a doc comment. return ` this string is not indented` - } diff --git a/internal/lsp/testdata/folding/bad.go.golden b/internal/lsp/testdata/folding/bad.go.golden new file mode 100644 index 0000000000..a502d50847 --- /dev/null +++ b/internal/lsp/testdata/folding/bad.go.golden @@ -0,0 +1,119 @@ +-- foldingRange-0 -- +package folding //@fold("package") + +import (<>) + +import (<>) + +// badBar is a function. +func badBar(<>) string {<>} + +-- foldingRange-1 -- +package folding //@fold("package") + +import ( "fmt" + _ "log" +) + +import ( + _ "os" ) + +// badBar is a function. +func badBar() string { x := true + if x {<>} else {<>} + return +} + +-- foldingRange-2 -- +package folding //@fold("package") + +import ( "fmt" + _ "log" +) + +import ( + _ "os" ) + +// badBar is a function. +func badBar() string { x := true + if x { + // This is the only foldable thing in this file when lineFoldingOnly + fmt.Println(<>) + } else { + fmt.Println(<>) } + return +} + +-- foldingRange-comment-0 -- +package folding //@fold("package") + +import ( "fmt" + _ "log" +) + +import ( + _ "os" ) + +// badBar is a function.<> +func badBar() string { x := true + if x { + fmt.Println("true") + } else { + fmt.Println("false") } + return +} + +-- foldingRange-imports-0 -- +package folding //@fold("package") + +import (<>) + +import (<>) + +// badBar is a function. +func badBar() string { x := true + if x { + // This is the only foldable thing in this file when lineFoldingOnly + fmt.Println("true") + } else { + fmt.Println("false") } + return +} + +-- foldingRange-lineFolding-0 -- +package folding //@fold("package") + +import ( "fmt" + _ "log" +) + +import ( + _ "os" ) + +// badBar is a function. +func badBar() string { x := true + if x {<> + } else { + fmt.Println("false") } + return +} + +-- foldingRange-lineFolding-comment-0 -- +package folding //@fold("package") + +import ( "fmt" + _ "log" +) + +import ( + _ "os" ) + +// badBar is a function.<> +func badBar() string { x := true + if x { + fmt.Println("true") + } else { + fmt.Println("false") } + return +} + diff --git a/internal/lsp/testdata/folding/bad.go.in b/internal/lsp/testdata/folding/bad.go.in new file mode 100644 index 0000000000..84fcb740f4 --- /dev/null +++ b/internal/lsp/testdata/folding/bad.go.in @@ -0,0 +1,18 @@ +package folding //@fold("package") + +import ( "fmt" + _ "log" +) + +import ( + _ "os" ) + +// badBar is a function. +func badBar() string { x := true + if x { + // This is the only foldable thing in this file when lineFoldingOnly + fmt.Println("true") + } else { + fmt.Println("false") } + return +} diff --git a/internal/lsp/tests/tests.go b/internal/lsp/tests/tests.go index 956a9226e8..3c3d33a1da 100644 --- a/internal/lsp/tests/tests.go +++ b/internal/lsp/tests/tests.go @@ -36,7 +36,7 @@ const ( ExpectedImportCount = 2 ExpectedDefinitionsCount = 39 ExpectedTypeDefinitionsCount = 2 - ExpectedFoldingRangesCount = 1 + ExpectedFoldingRangesCount = 2 ExpectedHighlightsCount = 2 ExpectedReferencesCount = 6 ExpectedRenamesCount = 20