diff --git a/internal/lsp/cmd/test/references.go b/internal/lsp/cmd/test/references.go index 3039f53483..d48a0a90aa 100644 --- a/internal/lsp/cmd/test/references.go +++ b/internal/lsp/cmd/test/references.go @@ -27,7 +27,7 @@ func (r *runner) References(t *testing.T, spn span.Span, itemList []span.Span) { uri := spn.URI() filename := uri.Filename() target := filename + fmt.Sprintf(":%v:%v", spn.Start().Line(), spn.Start().Column()) - got, _ := r.NormalizeGoplsCmd(t, "references", target) + got, _ := r.NormalizeGoplsCmd(t, "references", "-d", target) if expect != got { t.Errorf("references failed for %s expected:\n%s\ngot:\n%s", target, expect, got) } diff --git a/internal/lsp/lsp_test.go b/internal/lsp/lsp_test.go index b2414ad857..c53a1b83a8 100644 --- a/internal/lsp/lsp_test.go +++ b/internal/lsp/lsp_test.go @@ -541,7 +541,6 @@ func (r *runner) References(t *testing.T, src span.Span, itemList []span.Span) { if err != nil { t.Fatalf("failed for %v: %v", src, err) } - want := make(map[protocol.Location]bool) for _, pos := range itemList { m, err := r.data.Mapper(pos.URI()) @@ -559,6 +558,9 @@ func (r *runner) References(t *testing.T, src span.Span, itemList []span.Span) { TextDocument: protocol.TextDocumentIdentifier{URI: loc.URI}, Position: loc.Range.Start, }, + Context: protocol.ReferenceContext{ + IncludeDeclaration: true, + }, } got, err := r.server.References(r.ctx, params) if err != nil { diff --git a/internal/lsp/references.go b/internal/lsp/references.go index 9d831b5c81..2dc37f2bdc 100644 --- a/internal/lsp/references.go +++ b/internal/lsp/references.go @@ -56,27 +56,18 @@ func (s *Server) references(ctx context.Context, params *protocol.ReferenceParam Range: refRange, }) } - // The declaration of this identifier may not be in the - // scope that we search for references, so make sure - // it is added to the beginning of the list if IncludeDeclaration - // was specified. + // Only add the identifier's declaration if the client requests it. if params.Context.IncludeDeclaration { - decSpan, err := ident.Declaration.Span() + rng, err := ident.Declaration.Range() if err != nil { return nil, err } - if !seen[decSpan] { - rng, err := ident.Declaration.Range() - if err != nil { - return nil, err - } - locations = append([]protocol.Location{ - { - URI: protocol.NewURI(ident.Declaration.URI()), - Range: rng, - }, - }, locations...) - } + locations = append([]protocol.Location{ + { + URI: protocol.NewURI(ident.Declaration.URI()), + Range: rng, + }, + }, locations...) } return locations, nil } diff --git a/internal/lsp/source/identifier.go b/internal/lsp/source/identifier.go index a0720c07dc..a942f5e5a5 100644 --- a/internal/lsp/source/identifier.go +++ b/internal/lsp/source/identifier.go @@ -45,6 +45,17 @@ type Declaration struct { wasImplicit bool } +func (i *IdentifierInfo) DeclarationReferenceInfo() *ReferenceInfo { + return &ReferenceInfo{ + Name: i.Declaration.obj.Name(), + mappedRange: i.Declaration.mappedRange, + obj: i.Declaration.obj, + ident: i.ident, + pkg: i.pkg, + isDeclaration: true, + } +} + // Identifier returns identifier information for a position // in a file, accounting for a potentially incomplete selector. func Identifier(ctx context.Context, snapshot Snapshot, f File, pos protocol.Position) (*IdentifierInfo, error) { diff --git a/internal/lsp/source/references.go b/internal/lsp/source/references.go index 5539f5fc42..6163425b8d 100644 --- a/internal/lsp/source/references.go +++ b/internal/lsp/source/references.go @@ -32,8 +32,6 @@ func (i *IdentifierInfo) References(ctx context.Context) ([]*ReferenceInfo, erro ctx, done := trace.StartSpan(ctx, "source.References") defer done() - var references []*ReferenceInfo - // If the object declaration is nil, assume it is an import spec and do not look for references. if i.Declaration.obj == nil { return nil, errors.Errorf("no references for an import spec") @@ -42,36 +40,6 @@ func (i *IdentifierInfo) References(ctx context.Context) ([]*ReferenceInfo, erro if info == nil { return nil, errors.Errorf("package %s has no types info", i.pkg.PkgPath()) } - if i.Declaration.wasImplicit { - // The definition is implicit, so we must add it separately. - // This occurs when the variable is declared in a type switch statement - // or is an implicit package name. Both implicits are local to a file. - references = append(references, &ReferenceInfo{ - Name: i.Declaration.obj.Name(), - mappedRange: i.Declaration.mappedRange, - obj: i.Declaration.obj, - pkg: i.pkg, - isDeclaration: true, - }) - } - for ident, obj := range info.Defs { - if obj == nil || !sameObj(obj, i.Declaration.obj) { - continue - } - rng, err := posToMappedRange(i.Snapshot.View(), i.pkg, ident.Pos(), ident.End()) - if err != nil { - return nil, err - } - // Add the declarations at the beginning of the references list. - references = append([]*ReferenceInfo{{ - Name: ident.Name, - ident: ident, - obj: obj, - pkg: i.pkg, - isDeclaration: true, - mappedRange: rng, - }}, references...) - } var searchpkgs []Package if i.Declaration.obj.Exported() { // Only search all packages if the identifier is exported. @@ -91,9 +59,11 @@ func (i *IdentifierInfo) References(ctx context.Context) ([]*ReferenceInfo, erro } // Add the package in which the identifier is declared. searchpkgs = append(searchpkgs, i.pkg) + + var references []*ReferenceInfo for _, pkg := range searchpkgs { for ident, obj := range pkg.GetTypesInfo().Uses { - if obj == nil || !(sameObj(obj, i.Declaration.obj)) { + if !sameObj(obj, i.Declaration.obj) { continue } rng, err := posToMappedRange(i.Snapshot.View(), pkg, ident.Pos(), ident.End()) @@ -117,6 +87,9 @@ func (i *IdentifierInfo) References(ctx context.Context) ([]*ReferenceInfo, erro // and their objectpath and package are the same; or if they don't // have object paths and they have the same Pos and Name. func sameObj(obj, declObj types.Object) bool { + if obj == nil || declObj == nil { + return false + } // TODO(suzmue): support the case where an identifier may have two different // declaration positions. if obj.Pkg() == nil || declObj.Pkg() == nil { diff --git a/internal/lsp/source/rename.go b/internal/lsp/source/rename.go index 29a7d8c43a..986a2865ae 100644 --- a/internal/lsp/source/rename.go +++ b/internal/lsp/source/rename.go @@ -120,6 +120,9 @@ func (i *IdentifierInfo) Rename(ctx context.Context, newName string) (map[span.U return nil, err } + // Make sure to add the declaration of the identifier. + refs = append(refs, i.DeclarationReferenceInfo()) + r := renamer{ ctx: ctx, fset: i.Snapshot.View().Session().Cache().FileSet(), diff --git a/internal/lsp/source/source_test.go b/internal/lsp/source/source_test.go index 48dda36034..de1408d8cd 100644 --- a/internal/lsp/source/source_test.go +++ b/internal/lsp/source/source_test.go @@ -653,17 +653,16 @@ func (r *runner) References(t *testing.T, src span.Span, itemList []span.Span) { if err != nil { t.Fatalf("failed for %v: %v", src, err) } - want := make(map[span.Span]bool) for _, pos := range itemList { want[pos] = true } - refs, err := ident.References(ctx) if err != nil { t.Fatalf("failed for %v: %v", src, err) } - + // Add the item's declaration, since References omits it. + refs = append([]*source.ReferenceInfo{ident.DeclarationReferenceInfo()}, refs...) got := make(map[span.Span]bool) for _, refInfo := range refs { refSpan, err := refInfo.Span() @@ -672,11 +671,9 @@ func (r *runner) References(t *testing.T, src span.Span, itemList []span.Span) { } got[refSpan] = true } - if len(got) != len(want) { t.Errorf("references failed: different lengths got %v want %v", len(got), len(want)) } - for spn := range got { if !want[spn] { t.Errorf("references failed: incorrect references got %v want locations %v", got, want)