diff --git a/internal/lsp/source/highlight.go b/internal/lsp/source/highlight.go index dc791be2b4..42fe30603e 100644 --- a/internal/lsp/source/highlight.go +++ b/internal/lsp/source/highlight.go @@ -54,7 +54,7 @@ func Highlight(ctx context.Context, snapshot Snapshot, fh FileHandle, pos protoc } } - switch path[0].(type) { + switch node := path[0].(type) { case *ast.BasicLit: if len(path) > 1 { if _, ok := path[1].(*ast.ImportSpec); ok { @@ -66,9 +66,24 @@ func Highlight(ctx context.Context, snapshot Snapshot, fh FileHandle, pos protoc return highlightFuncControlFlow(ctx, snapshot.View(), pkg, path) case *ast.Ident: return highlightIdentifiers(ctx, snapshot.View(), pkg, path) - case *ast.BranchStmt, *ast.ForStmt, *ast.RangeStmt: + case *ast.ForStmt, *ast.RangeStmt: return highlightLoopControlFlow(ctx, snapshot.View(), pkg, path) + case *ast.BranchStmt: + // BREAK can exit a loop, switch or select, while CONTINUE exit a loop so + // these need to be handled separately. They can also be embedded in any + // other loop/switch/select if they have a label. TODO: add support for + // GOTO and FALLTHROUGH as well. + if node.Label != nil { + return highlightLabeledFlow(ctx, snapshot.View(), pkg, node) + } + switch node.Tok { + case token.BREAK: + return highlightUnlabeledBreakFlow(ctx, snapshot.View(), pkg, path) + case token.CONTINUE: + return highlightLoopControlFlow(ctx, snapshot.View(), pkg, path) + } } + // If the cursor is in an unidentified area, return empty results. return nil, nil } @@ -78,6 +93,7 @@ func highlightFuncControlFlow(ctx context.Context, view View, pkg Package, path var returnStmt *ast.ReturnStmt var resultsList *ast.FieldList inReturnList := false + Outer: // Reverse walk the path till we get to the func block. for i, n := range path { @@ -195,21 +211,70 @@ Outer: return rangeMapToSlice(result), nil } -func highlightLoopControlFlow(ctx context.Context, view View, pkg Package, path []ast.Node) ([]protocol.Range, error) { - var loop ast.Node -Outer: - // Reverse walk the path till we get to the for loop. +func highlightUnlabeledBreakFlow(ctx context.Context, view View, pkg Package, path []ast.Node) ([]protocol.Range, error) { + // Reverse walk the path until we find closest loop, select or switch. for _, n := range path { switch n.(type) { case *ast.ForStmt, *ast.RangeStmt: - loop = n - break Outer + return highlightLoopControlFlow(ctx, view, pkg, path) + // TODO: add highlight when breaking a select or switch. + case *ast.SelectStmt, *ast.SwitchStmt: + return nil, nil + } + } + return nil, nil +} + +func highlightLabeledFlow(ctx context.Context, view View, pkg Package, node *ast.BranchStmt) ([]protocol.Range, error) { + obj := node.Label.Obj + if obj == nil || obj.Decl == nil { + return nil, nil + } + + label, ok := obj.Decl.(*ast.LabeledStmt) + if !ok { + return nil, nil + } + + switch label.Stmt.(type) { + case *ast.ForStmt, *ast.RangeStmt: + return highlightLoopControlFlow(ctx, view, pkg, []ast.Node{label.Stmt, label}) + } + + return nil, nil +} + +func highlightLoopControlFlow(ctx context.Context, view View, pkg Package, path []ast.Node) ([]protocol.Range, error) { + labelFor := func(path []ast.Node) *ast.Ident { + if len(path) > 1 { + if n, ok := path[1].(*ast.LabeledStmt); ok { + return n.Label + } + } + return nil + } + + var loop ast.Node + var loopLabel *ast.Ident + stmtLabel := labelFor(path) +Outer: + // Reverse walk the path till we get to the for loop. + for i := range path { + switch n := path[i].(type) { + case *ast.ForStmt, *ast.RangeStmt: + loopLabel = labelFor(path[i:]) + + if stmtLabel == nil || loopLabel == stmtLabel { + loop = n + break Outer + } } } // Cursor is not in a for loop. if loop == nil { return nil, nil } + result := make(map[protocol.Range]bool) // Add the for statement. forStmt, err := posToMappedRange(view, pkg, loop.Pos(), loop.Pos()+token.Pos(len("for"))) @@ -222,14 +287,39 @@ Outer: } result[rng] = true + // Traverse AST to find branch statements within the same for-loop. + ast.Inspect(loop, func(n ast.Node) bool { + switch n.(type) { + case *ast.ForStmt, *ast.RangeStmt: + return loop == n + case *ast.SwitchStmt, *ast.SelectStmt: + return false + } + + b, ok := n.(*ast.BranchStmt) + if !ok { + return true + } + + if b.Label == nil || labelDecl(b.Label) == loopLabel { + rng, err := nodeToProtocolRange(view, pkg, b) + if err != nil { + event.Error(ctx, "Error getting range for node", err) + return false + } + result[rng] = true + } + return true + }) + + // Find continue statements in the same loop or switches/selects. ast.Inspect(loop, func(n ast.Node) bool { - // Don't traverse any other for loops. switch n.(type) { case *ast.ForStmt, *ast.RangeStmt: return loop == n } - // Add all branch statements in same scope as the identified one. - if n, ok := n.(*ast.BranchStmt); ok { + + if n, ok := n.(*ast.BranchStmt); ok && n.Tok == token.CONTINUE { rng, err := nodeToProtocolRange(view, pkg, n) if err != nil { event.Error(ctx, "Error getting range for node", err) @@ -239,9 +329,53 @@ Outer: } return true }) + + // We don't need to check other for loops if we aren't looking for labeled statements. + if loopLabel == nil { + return rangeMapToSlice(result), nil + } + + // Find labeled branch statements in any loop + ast.Inspect(loop, func(n ast.Node) bool { + b, ok := n.(*ast.BranchStmt) + if !ok { + return true + } + + // Statment with labels that matches the loop. + if b.Label != nil && labelDecl(b.Label) == loopLabel { + rng, err := nodeToProtocolRange(view, pkg, b) + if err != nil { + event.Error(ctx, "Error getting range for node", err) + return false + } + result[rng] = true + } + + return true + }) + return rangeMapToSlice(result), nil } +func labelDecl(n *ast.Ident) *ast.Ident { + if n == nil { + return nil + } + if n.Obj == nil { + return nil + } + if n.Obj.Decl == nil { + return nil + } + stmt, ok := n.Obj.Decl.(*ast.LabeledStmt) + if !ok { + return nil + } + + return stmt.Label +} + func highlightImportUses(ctx context.Context, view View, pkg Package, path []ast.Node) ([]protocol.Range, error) { result := make(map[protocol.Range]bool) basicLit, ok := path[0].(*ast.BasicLit) diff --git a/internal/lsp/testdata/lsp/primarymod/highlights/highlights.go b/internal/lsp/testdata/lsp/primarymod/highlights/highlights.go index db09b5650b..1bcc9e285f 100644 --- a/internal/lsp/testdata/lsp/primarymod/highlights/highlights.go +++ b/internal/lsp/testdata/lsp/primarymod/highlights/highlights.go @@ -72,6 +72,24 @@ func testForLoops() { continue //@mark(cont4, "continue"),highlight(cont4, forDecl4, brk4, cont4) } } + +Outer: + for i := 0; i < 10; i++ { //@mark(forDecl5, "for"),highlight(forDecl5, forDecl5, brk5, brk6, brk8) + break //@mark(brk5, "break"),highlight(brk5, forDecl5, brk5, brk6, brk8) + for { //@mark(forDecl6, "for"),highlight(forDecl6, forDecl6, cont5) + if i == 1 { + break Outer //@mark(brk6, "break Outer"),highlight(brk6, forDecl5, brk5, brk6, brk8) + } + switch i { //@mark(switch1, "switch"),highlight(switch1) + case 5: + break //@mark(brk7, "break"),highlight(brk7) + case 6: + continue //@mark(cont5, "continue"),highlight(cont5, forDecl6, cont5) + case 7: + break Outer //@mark(brk8, "break Outer"),highlight(brk8, forDecl5, brk5, brk6, brk8) + } + } + } } func testReturn() bool { //@mark(func1, "func"),mark(bool1, "bool"),highlight(func1, func1, fullRet11, fullRet12),highlight(bool1, bool1, false1, bool2, true1) diff --git a/internal/lsp/testdata/lsp/summary.txt.golden b/internal/lsp/testdata/lsp/summary.txt.golden index 2990fec677..7469015615 100644 --- a/internal/lsp/testdata/lsp/summary.txt.golden +++ b/internal/lsp/testdata/lsp/summary.txt.golden @@ -14,7 +14,7 @@ ImportCount = 8 SuggestedFixCount = 6 DefinitionsCount = 53 TypeDefinitionsCount = 2 -HighlightsCount = 52 +HighlightsCount = 60 ReferencesCount = 11 RenamesCount = 24 PrepareRenamesCount = 7