diff --git a/internal/lsp/source/highlight.go b/internal/lsp/source/highlight.go index b453157c32..07a2ac6161 100644 --- a/internal/lsp/source/highlight.go +++ b/internal/lsp/source/highlight.go @@ -9,6 +9,8 @@ import ( "fmt" "go/ast" "go/token" + "go/types" + "strings" "golang.org/x/tools/go/ast/astutil" "golang.org/x/tools/internal/lsp/protocol" @@ -54,18 +56,25 @@ func Highlight(ctx context.Context, snapshot Snapshot, fh FileHandle, pos protoc } switch path[0].(type) { - case *ast.ReturnStmt, *ast.FuncDecl, *ast.FuncType, *ast.BasicLit: - return highlightFuncControlFlow(ctx, snapshot, pkg, path) + case *ast.BasicLit: + if len(path) > 1 { + if _, ok := path[1].(*ast.ImportSpec); ok { + return highlightImportUses(ctx, snapshot.View(), pkg, path) + } + } + return highlightFuncControlFlow(ctx, snapshot.View(), pkg, path) + case *ast.ReturnStmt, *ast.FuncDecl, *ast.FuncType: + return highlightFuncControlFlow(ctx, snapshot.View(), pkg, path) case *ast.Ident: - return highlightIdentifiers(ctx, snapshot, pkg, path) + return highlightIdentifiers(ctx, snapshot.View(), pkg, path) case *ast.BranchStmt, *ast.ForStmt, *ast.RangeStmt: - return highlightLoopControlFlow(ctx, snapshot, pkg, path) + return highlightLoopControlFlow(ctx, snapshot.View(), pkg, path) } // If the cursor is in an unidentified area, return empty results. return nil, nil } -func highlightFuncControlFlow(ctx context.Context, snapshot Snapshot, pkg Package, path []ast.Node) ([]protocol.Range, error) { +func highlightFuncControlFlow(ctx context.Context, view View, pkg Package, path []ast.Node) ([]protocol.Range, error) { var enclosingFunc ast.Node var returnStmt *ast.ReturnStmt var resultsList *ast.FieldList @@ -137,7 +146,7 @@ Outer: result := make(map[protocol.Range]bool) // Highlight the correct argument in the function declaration return types. if resultsList != nil && -1 < index && index < len(resultsList.List) { - rng, err := nodeToProtocolRange(snapshot.View(), pkg, resultsList.List[index]) + rng, err := nodeToProtocolRange(view, pkg, resultsList.List[index]) if err != nil { log.Error(ctx, "Error getting range for node", err) } else { @@ -146,7 +155,7 @@ Outer: } // Add the "func" part of the func declaration. if highlightAllReturnsAndFunc { - funcStmt, err := posToMappedRange(snapshot.View(), pkg, enclosingFunc.Pos(), enclosingFunc.Pos()+token.Pos(len("func"))) + funcStmt, err := posToMappedRange(view, pkg, enclosingFunc.Pos(), enclosingFunc.Pos()+token.Pos(len("func"))) if err != nil { return nil, err } @@ -174,7 +183,7 @@ Outer: toAdd = n.Results[index] } if toAdd != nil { - rng, err := nodeToProtocolRange(snapshot.View(), pkg, toAdd) + rng, err := nodeToProtocolRange(view, pkg, toAdd) if err != nil { log.Error(ctx, "Error getting range for node", err) } else { @@ -188,7 +197,7 @@ Outer: return rangeMapToSlice(result), nil } -func highlightLoopControlFlow(ctx context.Context, snapshot Snapshot, pkg Package, path []ast.Node) ([]protocol.Range, error) { +func highlightLoopControlFlow(ctx context.Context, view View, pkg Package, path []ast.Node) ([]protocol.Range, error) { var loop ast.Node Outer: // Reverse walk the path till we get to the for loop. @@ -205,7 +214,7 @@ Outer: } result := make(map[protocol.Range]bool) // Add the for statement. - forStmt, err := posToMappedRange(snapshot.View(), pkg, loop.Pos(), loop.Pos()+token.Pos(len("for"))) + forStmt, err := posToMappedRange(view, pkg, loop.Pos(), loop.Pos()+token.Pos(len("for"))) if err != nil { return nil, err } @@ -223,7 +232,7 @@ Outer: } // Add all branch statements in same scope as the identified one. if n, ok := n.(*ast.BranchStmt); ok { - rng, err := nodeToProtocolRange(snapshot.View(), pkg, n) + rng, err := nodeToProtocolRange(view, pkg, n) if err != nil { log.Error(ctx, "Error getting range for node", err) return false @@ -235,14 +244,49 @@ Outer: return rangeMapToSlice(result), nil } -func highlightIdentifiers(ctx context.Context, snapshot Snapshot, pkg Package, path []ast.Node) ([]protocol.Range, error) { +func highlightImportUses(ctx context.Context, view View, pkg Package, path []ast.Node) ([]protocol.Range, error) { + result := make(map[protocol.Range]bool) + basicLit, ok := path[0].(*ast.BasicLit) + if !ok { + return nil, errors.Errorf("highlightImportUses called with an ast.Node of type %T", basicLit) + } + + ast.Inspect(path[len(path)-1], func(node ast.Node) bool { + if imp, ok := node.(*ast.ImportSpec); ok && imp.Path == basicLit { + if rng, err := nodeToProtocolRange(view, pkg, node); err == nil { + result[rng] = true + return false + } + } + n, ok := node.(*ast.Ident) + if !ok { + return true + } + obj, ok := pkg.GetTypesInfo().ObjectOf(n).(*types.PkgName) + if !ok { + return true + } + if !strings.Contains(basicLit.Value, obj.Name()) { + return true + } + if rng, err := nodeToProtocolRange(view, pkg, n); err == nil { + result[rng] = true + } else { + log.Error(ctx, "Error getting range for node", err) + } + return false + }) + return rangeMapToSlice(result), nil +} + +func highlightIdentifiers(ctx context.Context, view View, pkg Package, path []ast.Node) ([]protocol.Range, error) { result := make(map[protocol.Range]bool) id, ok := path[0].(*ast.Ident) if !ok { return nil, errors.Errorf("highlightIdentifiers called with an ast.Node of type %T", id) } // Check if ident is inside return or func decl. - if toAdd, err := highlightFuncControlFlow(ctx, snapshot, pkg, path); toAdd != nil && err == nil { + if toAdd, err := highlightFuncControlFlow(ctx, view, pkg, path); toAdd != nil && err == nil { for _, r := range toAdd { result[r] = true } @@ -251,7 +295,13 @@ func highlightIdentifiers(ctx context.Context, snapshot Snapshot, pkg Package, p // TODO: maybe check if ident is a reserved word, if true then don't continue and return results. idObj := pkg.GetTypesInfo().ObjectOf(id) + pkgObj, isImported := idObj.(*types.PkgName) ast.Inspect(path[len(path)-1], func(node ast.Node) bool { + if imp, ok := node.(*ast.ImportSpec); ok && isImported { + if rng, err := highlightImport(view, pkg, pkgObj, imp); rng != nil && err == nil { + result[*rng] = true + } + } n, ok := node.(*ast.Ident) if !ok { return true @@ -262,7 +312,7 @@ func highlightIdentifiers(ctx context.Context, snapshot Snapshot, pkg Package, p if nObj := pkg.GetTypesInfo().ObjectOf(n); nObj != idObj { return false } - if rng, err := nodeToProtocolRange(snapshot.View(), pkg, n); err == nil { + if rng, err := nodeToProtocolRange(view, pkg, n); err == nil { result[rng] = true } else { log.Error(ctx, "Error getting range for node", err) @@ -272,6 +322,20 @@ func highlightIdentifiers(ctx context.Context, snapshot Snapshot, pkg Package, p return rangeMapToSlice(result), nil } +func highlightImport(view View, pkg Package, obj *types.PkgName, imp *ast.ImportSpec) (*protocol.Range, error) { + if imp.Name != nil || imp.Path == nil { + return nil, nil + } + if !strings.Contains(imp.Path.Value, obj.Name()) { + return nil, nil + } + rng, err := nodeToProtocolRange(view, pkg, imp.Path) + if err != nil { + return nil, err + } + return &rng, nil +} + func rangeMapToSlice(rangeMap map[protocol.Range]bool) []protocol.Range { var list []protocol.Range for i := range rangeMap { diff --git a/internal/lsp/testdata/highlights/highlights.go b/internal/lsp/testdata/highlights/highlights.go index c8c37d211e..de67efec2c 100644 --- a/internal/lsp/testdata/highlights/highlights.go +++ b/internal/lsp/testdata/highlights/highlights.go @@ -1,7 +1,8 @@ package highlights import ( - "fmt" + "fmt" //@mark(fmtImp, "\"fmt\""),highlight(fmtImp, fmtImp, fmt1, fmt2, fmt3, fmt4) + h2 "net/http" //@mark(hImp, "h2"),highlight(hImp, hImp, hUse) "sort" "golang.org/x/tools/internal/lsp/protocol" @@ -18,8 +19,10 @@ func _() F { var foo = F{bar: 52} //@mark(fooDeclaration, "foo"),mark(bar2, "bar"),highlight(fooDeclaration, fooDeclaration, fooUse),highlight(bar2, barDeclaration, bar1, bar2, bar3) func Print() { //@mark(printFunc, "Print"),highlight(printFunc, printFunc, printTest) - fmt.Println(foo) //@mark(fooUse, "foo"),highlight(fooUse, fooDeclaration, fooUse) - fmt.Print("yo") //@mark(printSep, "Print"),highlight(printSep, printSep, print1, print2) + _ = h2.Client{} //@mark(hUse, "h2"),highlight(hUse, hImp, hUse) + + fmt.Println(foo) //@mark(fooUse, "foo"),highlight(fooUse, fooDeclaration, fooUse),mark(fmt1, "fmt"),highlight(fmt1, fmtImp, fmt1, fmt2, fmt3, fmt4) + fmt.Print("yo") //@mark(printSep, "Print"),highlight(printSep, printSep, print1, print2),mark(fmt2, "fmt"),highlight(fmt2, fmtImp, fmt1, fmt2, fmt3, fmt4) } func (x *F) Inc() { //@mark(xRightDecl, "x"),mark(xLeftDecl, " *"),highlight(xRightDecl, xRightDecl, xUse),highlight(xLeftDecl, xRightDecl, xUse) @@ -27,8 +30,8 @@ func (x *F) Inc() { //@mark(xRightDecl, "x"),mark(xLeftDecl, " *"),highlight(xRi } func testFunctions() { - fmt.Print("main start") //@mark(print1, "Print"),highlight(print1, printSep, print1, print2) - fmt.Print("ok") //@mark(print2, "Print"),highlight(print2, printSep, print1, print2) + fmt.Print("main start") //@mark(print1, "Print"),highlight(print1, printSep, print1, print2),mark(fmt3, "fmt"),highlight(fmt3, fmtImp, fmt1, fmt2, fmt3, fmt4) + fmt.Print("ok") //@mark(print2, "Print"),highlight(print2, printSep, print1, print2),mark(fmt4, "fmt"),highlight(fmt4, fmtImp, fmt1, fmt2, fmt3, fmt4) Print() //@mark(printTest, "Print"),highlight(printTest, printFunc, printTest) } diff --git a/internal/lsp/testdata/summary.txt.golden b/internal/lsp/testdata/summary.txt.golden index ee564e3b35..51c0290b17 100644 --- a/internal/lsp/testdata/summary.txt.golden +++ b/internal/lsp/testdata/summary.txt.golden @@ -13,7 +13,7 @@ ImportCount = 7 SuggestedFixCount = 1 DefinitionsCount = 43 TypeDefinitionsCount = 2 -HighlightsCount = 45 +HighlightsCount = 52 ReferencesCount = 8 RenamesCount = 22 PrepareRenamesCount = 8