diff --git a/internal/lsp/cmd/highlight.go b/internal/lsp/cmd/highlight.go index c44ebc9c13..373bc97b89 100644 --- a/internal/lsp/cmd/highlight.go +++ b/internal/lsp/cmd/highlight.go @@ -8,6 +8,7 @@ import ( "context" "flag" "fmt" + "sort" "golang.org/x/tools/internal/lsp/protocol" "golang.org/x/tools/internal/span" @@ -68,14 +69,22 @@ func (r *highlight) Run(ctx context.Context, args ...string) error { return err } + var results []span.Span for _, h := range highlights { l := protocol.Location{Range: h.Range} s, err := file.mapper.Span(l) if err != nil { return err } + results = append(results, s) + } + // Sort results to make tests deterministic since DocumentHighlight uses a map. + sort.SliceStable(results, func(i, j int) bool { + return span.Compare(results[i], results[j]) == -1 + }) + + for _, s := range results { fmt.Println(s) } - return nil } diff --git a/internal/lsp/lsp_test.go b/internal/lsp/lsp_test.go index 686c1cee05..0fd2258357 100644 --- a/internal/lsp/lsp_test.go +++ b/internal/lsp/lsp_test.go @@ -495,11 +495,11 @@ func (r *runner) Implementation(t *testing.T, spn span.Span, m tests.Implementat } func (r *runner) Highlight(t *testing.T, src span.Span, locations []span.Span) { - m, err := r.data.Mapper(locations[0].URI()) + m, err := r.data.Mapper(src.URI()) if err != nil { t.Fatal(err) } - loc, err := m.Location(locations[0]) + loc, err := m.Location(src) if err != nil { t.Fatalf("failed for %v: %v", locations[0], err) } @@ -517,11 +517,23 @@ func (r *runner) Highlight(t *testing.T, src span.Span, locations []span.Span) { if len(highlights) != len(locations) { t.Fatalf("got %d highlights for highlight at %v:%v:%v, expected %d", len(highlights), src.URI().Filename(), src.Start().Line(), src.Start().Column(), len(locations)) } + // Check to make sure highlights have a valid range. + var results []span.Span for i := range highlights { - if h, err := m.RangeSpan(highlights[i].Range); err != nil { + h, err := m.RangeSpan(highlights[i].Range) + if err != nil { t.Fatalf("failed for %v: %v", highlights[i], err) - } else if h != locations[i] { - t.Errorf("want %v, got %v\n", locations[i], h) + } + results = append(results, h) + } + // Sort results to make tests deterministic since DocumentHighlight uses a map. + sort.SliceStable(results, func(i, j int) bool { + return span.Compare(results[i], results[j]) == -1 + }) + // Check to make sure all the expected highlights are found. + for i := range results { + if results[i] != locations[i] { + t.Errorf("want %v, got %v\n", locations[i], results[i]) } } } diff --git a/internal/lsp/source/highlight.go b/internal/lsp/source/highlight.go index 38ba428e6e..f5c28c224a 100644 --- a/internal/lsp/source/highlight.go +++ b/internal/lsp/source/highlight.go @@ -7,6 +7,7 @@ package source import ( "context" "go/ast" + "go/token" "golang.org/x/tools/go/ast/astutil" "golang.org/x/tools/internal/lsp/protocol" @@ -52,22 +53,129 @@ func Highlight(ctx context.Context, snapshot Snapshot, f File, pos protocol.Posi if len(path) == 0 { return nil, errors.Errorf("no enclosing position found for %v:%v", int(pos.Line), int(pos.Character)) } - switch path[0].(type) { + case *ast.ReturnStmt, *ast.FuncDecl, *ast.FuncType, *ast.BasicLit: + return highlightFuncControlFlow(ctx, snapshot, m, path) case *ast.Ident: return highlightIdentifiers(ctx, snapshot, m, path, pkg) case *ast.BranchStmt, *ast.ForStmt, *ast.RangeStmt: - return highlightControlFlow(ctx, snapshot, m, path) + return highlightLoopControlFlow(ctx, snapshot, m, path) } - // If the cursor is in an unidentified area, return empty results. return nil, nil } -func highlightControlFlow(ctx context.Context, snapshot Snapshot, m *protocol.ColumnMapper, path []ast.Node) ([]protocol.Range, error) { - // Reverse walk the path till we get to the for loop. +func highlightFuncControlFlow(ctx context.Context, snapshot Snapshot, m *protocol.ColumnMapper, path []ast.Node) ([]protocol.Range, error) { + var enclosingFunc ast.Node + var returnStmt *ast.ReturnStmt + var resultsList *ast.FieldList + inReturnList := false +Outer: + // Reverse walk the path till we get to the func block. + for _, n := range path { + switch node := n.(type) { + case *ast.Field: + inReturnList = true + case *ast.FuncLit: + enclosingFunc = n + resultsList = node.Type.Results + break Outer + case *ast.FuncDecl: + enclosingFunc = n + resultsList = node.Type.Results + break Outer + case *ast.ReturnStmt: + returnStmt = node + // If the cursor is not directly in a *ast.ReturnStmt, then + // we need to know if it is within one of the values that is being returned. + inReturnList = inReturnList || path[0] != returnStmt + } + } + // If the cursor is on a "return" or "func" keyword, we should highlight all of the exit + // points of the function, including the "return" and "func" keywords. + highlightAllReturnsAndFunc := path[0] == returnStmt || path[0] == enclosingFunc + switch path[0].(type) { + case *ast.Ident, *ast.BasicLit: + // Cursor is in an identifier and not in a return statement or in the results list. + if returnStmt == nil && !inReturnList { + return nil, nil + } + case *ast.FuncType: + highlightAllReturnsAndFunc = true + } + // The user's cursor may be within the return statement of a function, + // or within the result section of a function's signature. + // index := -1 + var nodes []ast.Node + if returnStmt != nil { + for _, n := range returnStmt.Results { + nodes = append(nodes, n) + } + } else if resultsList != nil { + for _, n := range resultsList.List { + nodes = append(nodes, n) + } + } + _, index := nodeAtPos(nodes, path[0].Pos()) + + result := make(map[protocol.Range]bool) + // Highlight the correct argument in the function declaration return types. + if resultsList != nil && -1 < index && index < len(resultsList.List) { + rng, err := nodeToProtocolRange(ctx, snapshot.View(), m, resultsList.List[index]) + if err != nil { + log.Error(ctx, "Error getting range for node", err) + } else { + result[rng] = true + } + } + // Add the "func" part of the func declaration. + if highlightAllReturnsAndFunc { + funcStmt, err := posToRange(snapshot.View(), m, enclosingFunc.Pos(), enclosingFunc.Pos()+token.Pos(len("func"))) + if err != nil { + return nil, err + } + rng, err := funcStmt.Range() + if err != nil { + return nil, err + } + result[rng] = true + } + // Traverse the AST to highlight the other relevant return statements in the function. + ast.Inspect(enclosingFunc, func(n ast.Node) bool { + // Don't traverse any other functions. + switch n.(type) { + case *ast.FuncDecl, *ast.FuncLit: + return enclosingFunc == n + } + if n, ok := n.(*ast.ReturnStmt); ok { + var toAdd ast.Node + // Add the entire return statement, applies when highlight the word "return" or "func". + if highlightAllReturnsAndFunc { + toAdd = n + } + // Add the relevant field within the entire return statement. + if -1 < index && index < len(n.Results) { + toAdd = n.Results[index] + } + if toAdd != nil { + rng, err := nodeToProtocolRange(ctx, snapshot.View(), m, toAdd) + if err != nil { + log.Error(ctx, "Error getting range for node", err) + } else { + result[rng] = true + } + return false + } + } + return true + }) + return rangeMapToSlice(result), nil +} + +func highlightLoopControlFlow(ctx context.Context, snapshot Snapshot, m *protocol.ColumnMapper, path []ast.Node) ([]protocol.Range, error) { var loop ast.Node Outer: + // Reverse walk the path till we get to the for loop. for _, n := range path { switch n.(type) { case *ast.ForStmt, *ast.RangeStmt: @@ -75,15 +183,13 @@ Outer: break Outer } } + // Cursor is not in a for loop. if loop == nil { - // Cursor is not in a for loop. return nil, nil } - - var result []protocol.Range - + result := make(map[protocol.Range]bool) // Add the for statement. - forStmt, err := posToRange(snapshot.View(), m, loop.Pos(), loop.Pos()+3) + forStmt, err := posToRange(snapshot.View(), m, loop.Pos(), loop.Pos()+token.Pos(len("for"))) if err != nil { return nil, err } @@ -91,7 +197,7 @@ Outer: if err != nil { return nil, err } - result = append(result, rng) + result[rng] = true ast.Inspect(loop, func(n ast.Node) bool { // Don't traverse any other for loops. @@ -99,27 +205,34 @@ Outer: case *ast.ForStmt, *ast.RangeStmt: return loop == n } - + // Add all branch statements in same scope as the identified one. if n, ok := n.(*ast.BranchStmt); ok { - // Add all branch statements in same scope as the identified one. rng, err := nodeToProtocolRange(ctx, snapshot.View(), m, n) if err != nil { log.Error(ctx, "Error getting range for node", err) return false } - result = append(result, rng) + result[rng] = true } return true }) - return result, nil + return rangeMapToSlice(result), nil } func highlightIdentifiers(ctx context.Context, snapshot Snapshot, m *protocol.ColumnMapper, path []ast.Node, pkg Package) ([]protocol.Range, error) { - var result []protocol.Range + result := make(map[protocol.Range]bool) id, ok := path[0].(*ast.Ident) if !ok { return nil, errors.Errorf("highlightIdentifiers called with an ast.Node of type %T", id) } + // Check if ident is inside return or func decl. + if toAdd, err := highlightFuncControlFlow(ctx, snapshot, m, path); toAdd != nil && err == nil { + for _, r := range toAdd { + result[r] = true + } + } + + // TODO: maybe check if ident is a reserved word, if true then don't continue and return results. idObj := pkg.GetTypesInfo().ObjectOf(id) ast.Inspect(path[len(path)-1], func(node ast.Node) bool { @@ -133,13 +246,20 @@ func highlightIdentifiers(ctx context.Context, snapshot Snapshot, m *protocol.Co if nObj := pkg.GetTypesInfo().ObjectOf(n); nObj != idObj { return false } - if rng, err := nodeToProtocolRange(ctx, snapshot.View(), m, n); err == nil { - result = append(result, rng) + result[rng] = true } else { log.Error(ctx, "Error getting range for node", err) } return false }) - return result, nil + return rangeMapToSlice(result), nil +} + +func rangeMapToSlice(rangeMap map[protocol.Range]bool) []protocol.Range { + var list []protocol.Range + for i := range rangeMap { + list = append(list, i) + } + return list } diff --git a/internal/lsp/source/source_test.go b/internal/lsp/source/source_test.go index cda30b4db7..98daac223c 100644 --- a/internal/lsp/source/source_test.go +++ b/internal/lsp/source/source_test.go @@ -617,13 +617,23 @@ func (r *runner) Highlight(t *testing.T, src span.Span, locations []span.Span) { if len(highlights) != len(locations) { t.Errorf("got %d highlights for highlight at %v:%v:%v, expected %d", len(highlights), src.URI().Filename(), src.Start().Line(), src.Start().Column(), len(locations)) } - for i, got := range highlights { - want, err := m.Range(locations[i]) + // Check to make sure highlights have a valid range. + var results []span.Span + for i := range highlights { + h, err := m.RangeSpan(highlights[i]) if err != nil { - t.Fatal(err) + t.Fatalf("failed for %v: %v", highlights[i], err) } - if got != want { - t.Errorf("want %v, got %v\n", want, got) + results = append(results, h) + } + // Sort results to make tests deterministic since DocumentHighlight uses a map. + sort.SliceStable(results, func(i, j int) bool { + return span.Compare(results[i], results[j]) == -1 + }) + // Check to make sure all the expected highlights are found. + for i := range results { + if results[i] != locations[i] { + t.Errorf("want %v, got %v\n", locations[i], results[i]) } } } diff --git a/internal/lsp/source/util.go b/internal/lsp/source/util.go index 6ff1f5d766..16c1960b93 100644 --- a/internal/lsp/source/util.go +++ b/internal/lsp/source/util.go @@ -222,6 +222,19 @@ func (k FileKind) String() string { } } +// Returns the index and the node whose position is contained inside the node list. +func nodeAtPos(nodes []ast.Node, pos token.Pos) (ast.Node, int) { + if nodes == nil { + return nil, -1 + } + for i, node := range nodes { + if node.Pos() <= pos && pos <= node.End() { + return node, i + } + } + return nil, -1 +} + // indexExprAtPos returns the index of the expression containing pos. func indexExprAtPos(pos token.Pos, args []ast.Expr) int { for i, expr := range args { diff --git a/internal/lsp/testdata/highlights/highlights.go b/internal/lsp/testdata/highlights/highlights.go index fdcc174ddd..ff435307d2 100644 --- a/internal/lsp/testdata/highlights/highlights.go +++ b/internal/lsp/testdata/highlights/highlights.go @@ -2,6 +2,7 @@ package highlights import ( "fmt" + "sort" "golang.org/x/tools/internal/lsp/protocol" ) @@ -25,7 +26,7 @@ func testFunctions() { Print() //@mark(printTest, "Print"),highlight(printTest, printFunc, printTest) } -func toProtocolHighlight(rngs []protocol.Range) []protocol.DocumentHighlight { //@mark(doc1, "DocumentHighlight"),highlight(doc1, doc1, doc2, doc3) +func toProtocolHighlight(rngs []protocol.Range) []protocol.DocumentHighlight { //@mark(doc1, "DocumentHighlight"),mark(docRet1, "[]protocol.DocumentHighlight"),highlight(doc1, docRet1, doc1, doc2, doc3, result) result := make([]protocol.DocumentHighlight, 0, len(rngs)) //@mark(doc2, "DocumentHighlight"),highlight(doc2, doc1, doc2, doc3) kind := protocol.Text for _, rng := range rngs { @@ -34,7 +35,7 @@ func toProtocolHighlight(rngs []protocol.Range) []protocol.DocumentHighlight { / Range: rng, }) } - return result + return result //@mark(result, "result") } func testForLoops() { @@ -58,7 +59,6 @@ func testForLoops() { } arr := []int{} - for i := range arr { //@mark(forDecl4, "for"),highlight(forDecl4, forDecl4, brk4, cont4) if i > 8 { break //@mark(brk4, "break"),highlight(brk4, forDecl4, brk4, cont4) @@ -68,3 +68,30 @@ func testForLoops() { } } } + +func testReturn() bool { //@mark(func1, "func"),mark(bool1, "bool"),highlight(func1, func1, fullRet11, fullRet12),highlight(bool1, bool1, false1, bool2, true1) + if 1 < 2 { + return false //@mark(ret11, "return"),mark(fullRet11, "return false"),mark(false1, "false"),highlight(ret11, func1, fullRet11, fullRet12) + } + candidates := []int{} + sort.SliceStable(candidates, func(i, j int) bool { //@mark(func2, "func"),mark(bool2, "bool"),highlight(func2, func2, fullRet2) + return candidates[i] > candidates[j] //@mark(ret2, "return"),mark(fullRet2, "return candidates[i] > candidates[j]"),highlight(ret2, func2, fullRet2) + }) + return true //@mark(ret12, "return"),mark(fullRet12, "return true"),mark(true1, "true"),highlight(ret12, func1, fullRet11, fullRet12) +} + +func testReturnFields() float64 { //@mark(retVal1, "float64"),highlight(retVal1, retVal1, retVal11, retVal21) + if 1 < 2 { + return 20.1 //@mark(retVal11, "20.1"),highlight(retVal11, retVal1, retVal11, retVal21) + } + z := 4.3 //@mark(zDecl, "z") + return z //@mark(retVal21, "z"),highlight(retVal21, retVal1, retVal11, zDecl, retVal21) +} + +func testReturnMultipleFields() (float32, string) { //@mark(retVal31, "float32"),mark(retVal32, "string"),highlight(retVal31, retVal31, retVal41, retVal51),highlight(retVal32, retVal32, retVal42, retVal52) + y := "im a var" //@mark(yDecl, "y"), + if 1 < 2 { + return 20.1, y //@mark(retVal41, "20.1"),mark(retVal42, "y"),highlight(retVal41, retVal31, retVal41, retVal51),highlight(retVal42, retVal32, yDecl, retVal42, retVal52) + } + return 4.9, "test" //@mark(retVal51, "4.9"),mark(retVal52, "\"test\""),highlight(retVal51, retVal31, retVal41, retVal51),highlight(retVal52, retVal32, retVal42, retVal52) +} diff --git a/internal/lsp/testdata/summary.txt.golden b/internal/lsp/testdata/summary.txt.golden index 7637c339a0..fdfde62f62 100644 --- a/internal/lsp/testdata/summary.txt.golden +++ b/internal/lsp/testdata/summary.txt.golden @@ -13,7 +13,7 @@ ImportCount = 7 SuggestedFixCount = 1 DefinitionsCount = 38 TypeDefinitionsCount = 2 -HighlightsCount = 22 +HighlightsCount = 37 ReferencesCount = 7 RenamesCount = 22 PrepareRenamesCount = 8