diff --git a/gopls/internal/regtest/misc/references_test.go b/gopls/internal/regtest/misc/references_test.go index 768251680f..de2e9b97fd 100644 --- a/gopls/internal/regtest/misc/references_test.go +++ b/gopls/internal/regtest/misc/references_test.go @@ -5,6 +5,8 @@ package misc import ( + "fmt" + "strings" "testing" . "golang.org/x/tools/internal/lsp/regtest" @@ -81,3 +83,89 @@ func _() { } }) } + +func TestPackageReferences(t *testing.T) { + tests := []struct { + packageName string + wantRefCount int + wantFiles []string + }{ + { + "lib1", + 3, + []string{ + "main.go", + "lib1/a.go", + "lib1/b.go", + }, + }, + { + "lib2", + 2, + []string{ + "main.go", + "lib2/a.go", + }, + }, + } + + const files = ` +-- go.mod -- +module mod.com + +go 1.18 +-- lib1/a.go -- +package lib1 + +const A = 1 + +-- lib1/b.go -- +package lib1 + +const B = 1 + +-- lib2/a.go -- +package lib2 + +const C = 1 + +-- main.go -- +package main + +import ( + "mod.com/lib1" + "mod.com/lib2" +) + +func main() { + println("Hello") +} +` + Run(t, files, func(t *testing.T, env *Env) { + for _, test := range tests { + f := fmt.Sprintf("%s/a.go", test.packageName) + env.OpenFile(f) + pos := env.RegexpSearch(f, test.packageName) + refs := env.References(fmt.Sprintf("%s/a.go", test.packageName), pos) + if len(refs) != test.wantRefCount { + t.Fatalf("got %v reference(s), want %d", len(refs), test.wantRefCount) + } + var refURIs []string + for _, ref := range refs { + refURIs = append(refURIs, string(ref.URI)) + } + for _, base := range test.wantFiles { + hasBase := false + for _, ref := range refURIs { + if strings.HasSuffix(ref, base) { + hasBase = true + break + } + } + if !hasBase { + t.Fatalf("got [%v], want reference ends with \"%v\"", strings.Join(refURIs, ","), base) + } + } + } + }) +} diff --git a/internal/lsp/source/references.go b/internal/lsp/source/references.go index 3541600b20..85bf41a21b 100644 --- a/internal/lsp/source/references.go +++ b/internal/lsp/source/references.go @@ -9,12 +9,15 @@ import ( "errors" "fmt" "go/ast" + "go/token" "go/types" "sort" + "strconv" "golang.org/x/tools/internal/event" "golang.org/x/tools/internal/lsp/protocol" + "golang.org/x/tools/internal/lsp/safetoken" "golang.org/x/tools/internal/span" ) @@ -34,6 +37,63 @@ func References(ctx context.Context, s Snapshot, f FileHandle, pp protocol.Posit ctx, done := event.Start(ctx, "source.References") defer done() + // Find position of the package name declaration + pgf, err := s.ParseGo(ctx, f, ParseFull) + if err != nil { + return nil, err + } + + cursorOffset, err := pgf.Mapper.Offset(pp) + if err != nil { + return nil, err + } + + packageNameStart, err := safetoken.Offset(pgf.Tok, pgf.File.Name.Pos()) + if err != nil { + return nil, err + } + + packageNameEnd, err := safetoken.Offset(pgf.Tok, pgf.File.Name.End()) + if err != nil { + return nil, err + } + + if packageNameStart <= cursorOffset && cursorOffset < packageNameEnd { + renamingPkg, err := s.PackageForFile(ctx, f.URI(), TypecheckAll, NarrowestPackage) + if err != nil { + return nil, err + } + + // Find external references to the package. + rdeps, err := s.GetReverseDependencies(ctx, renamingPkg.ID()) + if err != nil { + return nil, err + } + var refs []*ReferenceInfo + for _, dep := range rdeps { + for _, f := range dep.CompiledGoFiles() { + for _, imp := range f.File.Imports { + if path, err := strconv.Unquote(imp.Path.Value); err == nil && path == renamingPkg.PkgPath() { + refs = append(refs, &ReferenceInfo{ + Name: pgf.File.Name.Name, + MappedRange: NewMappedRange(s.FileSet(), f.Mapper, imp.Pos(), imp.End()), + }) + } + } + } + } + + // Find internal references to the package within the package itself + for _, f := range renamingPkg.CompiledGoFiles() { + refs = append(refs, &ReferenceInfo{ + Name: pgf.File.Name.Name, + MappedRange: NewMappedRange(s.FileSet(), f.Mapper, f.File.Name.Pos(), f.File.Name.End()), + }) + } + + return refs, nil + } + qualifiedObjs, err := qualifiedObjsAtProtocolPos(ctx, s, f.URI(), pp) // Don't return references for builtin types. if errors.Is(err, errBuiltin) {