From f80fb1dfa15c7ff8069faff74eb603bf0f63164c Mon Sep 17 00:00:00 2001 From: Muir Manders Date: Tue, 17 Dec 2019 21:06:31 -0800 Subject: [PATCH] internal/lsp: refactor find-references and rename The main goal is to push the package variant logic from internal/lsp into internal/lsp/source so all users of internal/lsp/source benefit. "references" and "rename" now have top-level source.References() and source.Rename() entry points (as opposed to hanging off source.Identifier()). I expanded objectsAtProtocolPos() to know about implicit objects (type switch and import spec), and to handle *ast.ImportSpec generically. This gets rid of special case handling of *types.PkgName in various places. The biggest practical benefit, though, is that "references" no longer needs to compute the objectpath for every types.Object comparison it does, instead using direct types.Object equality. This speeds up "references" and "rename" a lot. Two other notable improvements that fell out of not using source.Identifier()'s logic: - Finding references on an embedded field now shows references to the field, not the type being embedded. - Finding references on an imported object now works correctly (previously it searched the importing package's dependents rather than the imported package's dependents). Finally, I refactored findIdentifier() to use pathEnclosingObjNode() instead of astutil.PathEnclosingInterval. Now we only need a single call to get the path because pathEnclosingObjNode() has the "try pos || try pos-1" logic built in. Change-Id: I667be9bed6ad83912404b90257c5c1485b3a7025 Reviewed-on: https://go-review.googlesource.com/c/tools/+/211999 Run-TryBot: Muir Manders TryBot-Result: Gobot Gobot Reviewed-by: Rebecca Stambler --- internal/lsp/references.go | 59 +----- internal/lsp/rename.go | 15 +- internal/lsp/source/identifier.go | 144 +++++++------- internal/lsp/source/implementation.go | 124 +++++++++--- internal/lsp/source/references.go | 139 +++++++------- internal/lsp/source/rename.go | 180 ++++++------------ internal/lsp/source/source_test.go | 24 +-- internal/lsp/testdata/references/refs.go | 12 +- internal/lsp/testdata/references/refs_test.go | 10 + internal/lsp/testdata/summary.txt.golden | 2 +- 10 files changed, 337 insertions(+), 372 deletions(-) create mode 100644 internal/lsp/testdata/references/refs_test.go diff --git a/internal/lsp/references.go b/internal/lsp/references.go index 47120173f6..d91fc16b7f 100644 --- a/internal/lsp/references.go +++ b/internal/lsp/references.go @@ -27,62 +27,23 @@ func (s *Server) references(ctx context.Context, params *protocol.ReferenceParam if fh.Identity().Kind != source.Go { return nil, nil } - phs, err := snapshot.PackageHandles(ctx, fh) + + references, err := source.References(ctx, view.Snapshot(), fh, params.Position, params.Context.IncludeDeclaration) if err != nil { - return nil, nil + return nil, err } - // Get the location of each reference to return as the result. - var ( - locations []protocol.Location - seen = make(map[span.Span]bool) - lastIdent *source.IdentifierInfo - ) - for _, ph := range phs { - ident, err := source.Identifier(ctx, snapshot, fh, params.Position, source.SpecificPackageHandle(ph.ID())) + var locations []protocol.Location + for _, ref := range references { + refRange, err := ref.Range() if err != nil { return nil, err } - lastIdent = ident - - references, err := ident.References(ctx) - if err != nil { - return nil, err - } - - for _, ref := range references { - refSpan, err := ref.Span() - if err != nil { - return nil, err - } - if seen[refSpan] { - continue // already added this location - } - seen[refSpan] = true - refRange, err := ref.Range() - if err != nil { - return nil, err - } - locations = append(locations, protocol.Location{ - URI: protocol.NewURI(ref.URI()), - Range: refRange, - }) - } - } - - // Only add the identifier's declaration if the client requests it. - if params.Context.IncludeDeclaration && lastIdent != nil { - rng, err := lastIdent.Declaration.Range() - if err != nil { - return nil, err - } - locations = append([]protocol.Location{ - { - URI: protocol.NewURI(lastIdent.Declaration.URI()), - Range: rng, - }, - }, locations...) + locations = append(locations, protocol.Location{ + URI: protocol.NewURI(ref.URI()), + Range: refRange, + }) } return locations, nil diff --git a/internal/lsp/rename.go b/internal/lsp/rename.go index 7fa10389f2..04f00546f5 100644 --- a/internal/lsp/rename.go +++ b/internal/lsp/rename.go @@ -26,14 +26,12 @@ func (s *Server) rename(ctx context.Context, params *protocol.RenameParams) (*pr if fh.Identity().Kind != source.Go { return nil, nil } - ident, err := source.Identifier(ctx, snapshot, fh, params.Position, source.WidestPackageHandle) - if err != nil { - return nil, nil - } - edits, err := ident.Rename(ctx, params.NewName) + + edits, err := source.Rename(ctx, snapshot, fh, params.Position, params.NewName) if err != nil { return nil, err } + var docChanges []protocol.TextDocumentEdit for uri, e := range edits { fh, err := snapshot.GetFile(uri) @@ -61,13 +59,10 @@ func (s *Server) prepareRename(ctx context.Context, params *protocol.PrepareRena if fh.Identity().Kind != source.Go { return nil, nil } - ident, err := source.Identifier(ctx, snapshot, fh, params.Position, source.WidestPackageHandle) - if err != nil { - return nil, nil // ignore errors - } + // Do not return errors here, as it adds clutter. // Returning a nil result means there is not a valid rename. - item, err := ident.PrepareRename(ctx) + item, err := source.PrepareRename(ctx, snapshot, fh, params.Position) if err != nil { return nil, nil // ignore errors } diff --git a/internal/lsp/source/identifier.go b/internal/lsp/source/identifier.go index 987bae2d65..26943049f3 100644 --- a/internal/lsp/source/identifier.go +++ b/internal/lsp/source/identifier.go @@ -46,17 +46,6 @@ type Declaration struct { obj types.Object } -func (i *IdentifierInfo) DeclarationReferenceInfo() *ReferenceInfo { - return &ReferenceInfo{ - Name: i.Declaration.obj.Name(), - mappedRange: i.Declaration.mappedRange, - obj: i.Declaration.obj, - ident: i.ident, - pkg: i.pkg, - isDeclaration: true, - } -} - // Identifier returns identifier information for a position // in a file, accounting for a potentially incomplete selector. func Identifier(ctx context.Context, snapshot Snapshot, fh FileHandle, pos protocol.Position, selectPackage PackagePolicy) (*IdentifierInfo, error) { @@ -84,46 +73,31 @@ func Identifier(ctx context.Context, snapshot Snapshot, fh FileHandle, pos proto var ErrNoIdentFound = errors.New("no identifier found") -func findIdentifier(ctx context.Context, snapshot Snapshot, pkg Package, file *ast.File, pos token.Pos) (*IdentifierInfo, error) { - if result, err := identifier(ctx, snapshot, pkg, file, pos); err != nil || result != nil { - return result, err - } - // If the position is not an identifier but immediately follows - // an identifier or selector period (as is common when - // requesting a completion), use the path to the preceding node. - ident, err := identifier(ctx, snapshot, pkg, file, pos-1) - if ident == nil && err == nil { - err = ErrNoIdentFound - } - return ident, err -} - -// identifier checks a single position for a potential identifier. -func identifier(ctx context.Context, s Snapshot, pkg Package, file *ast.File, pos token.Pos) (*IdentifierInfo, error) { - var err error - +func findIdentifier(ctx context.Context, s Snapshot, pkg Package, file *ast.File, pos token.Pos) (*IdentifierInfo, error) { // Handle import specs separately, as there is no formal position for a package declaration. if result, err := importSpec(s, pkg, file, pos); result != nil || err != nil { return result, err } - path, _ := astutil.PathEnclosingInterval(file, pos, pos) + path := pathEnclosingObjNode(file, pos) if path == nil { - return nil, errors.Errorf("can't find node enclosing position") + return nil, ErrNoIdentFound } view := s.View() + + ident, _ := path[0].(*ast.Ident) + if ident == nil { + return nil, ErrNoIdentFound + } + result := &IdentifierInfo{ Snapshot: s, qf: qualifier(file, pkg.GetTypes(), pkg.GetTypesInfo()), pkg: pkg, - ident: searchForIdent(path[0]), + ident: ident, enclosing: searchForEnclosing(pkg, path), } - // No identifier at the given position. - if result.ident == nil { - return nil, nil - } var wasEmbeddedField bool for _, n := range path[1:] { if field, ok := n.(*ast.Field); ok { @@ -131,7 +105,9 @@ func identifier(ctx context.Context, s Snapshot, pkg Package, file *ast.File, po break } } + result.Name = result.ident.Name + var err error if result.mappedRange, err = posToMappedRange(view, pkg, result.ident.Pos(), result.ident.End()); err != nil { return nil, err } @@ -139,7 +115,7 @@ func identifier(ctx context.Context, s Snapshot, pkg Package, file *ast.File, po if result.Declaration.obj == nil { // If there was no types.Object for the declaration, there might be an implicit local variable // declaration in a type switch. - if objs := typeSwitchVar(pkg.GetTypesInfo(), path); len(objs) > 0 { + if objs := typeSwitchImplicits(pkg, path); len(objs) > 0 { // There is no types.Object for the declaration of an implicit local variable, // but all of the types.Objects associated with the usages of this variable can be // used to connect it back to the declaration. @@ -202,18 +178,6 @@ func identifier(ctx context.Context, s Snapshot, pkg Package, file *ast.File, po return result, nil } -func searchForIdent(n ast.Node) *ast.Ident { - switch node := n.(type) { - case *ast.Ident: - return node - case *ast.SelectorExpr: - return node.Sel - case *ast.StarExpr: - return searchForIdent(node.X) - } - return nil -} - func searchForEnclosing(pkg Package, path []ast.Node) types.Type { for _, n := range path { switch n := n.(type) { @@ -326,31 +290,69 @@ func importSpec(s Snapshot, pkg Package, file *ast.File, pos token.Pos) (*Identi return result, nil } -// typeSwitchVar handles the special case of a local variable implicitly defined in a type switch. -// In such cases, the definition of the implicit variable will not be recorded in the *types.Info.Defs map, -// but rather in the *types.Info.Implicits map. -func typeSwitchVar(info *types.Info, path []ast.Node) []types.Object { - if len(path) < 3 { - return nil - } - // Check for [Ident AssignStmt TypeSwitchStmt...] - if _, ok := path[0].(*ast.Ident); !ok { - return nil - } - if _, ok := path[1].(*ast.AssignStmt); !ok { - return nil - } - sw, ok := path[2].(*ast.TypeSwitchStmt) - if !ok { +// typeSwitchImplicits returns all the implicit type switch objects +// that correspond to the leaf *ast.Ident. +func typeSwitchImplicits(pkg Package, path []ast.Node) []types.Object { + ident, _ := path[0].(*ast.Ident) + if ident == nil { return nil } - var res []types.Object - for _, stmt := range sw.Body.List { - obj := info.Implicits[stmt.(*ast.CaseClause)] - if obj != nil { - res = append(res, obj) + var ( + ts *ast.TypeSwitchStmt + assign *ast.AssignStmt + cc *ast.CaseClause + obj = pkg.GetTypesInfo().ObjectOf(ident) + ) + + // Walk our ancestors to determine if our leaf ident refers to a + // type switch variable, e.g. the "a" from "switch a := b.(type)". +Outer: + for i := 1; i < len(path); i++ { + switch n := path[i].(type) { + case *ast.AssignStmt: + // Check if ident is the "a" in "a := foo.(type)". The "a" in + // this case has no types.Object, so check for ident equality. + if len(n.Lhs) == 1 && n.Lhs[0] == ident { + assign = n + } + case *ast.CaseClause: + // Check if ident is a use of "a" within a case clause. Each + // case clause implicitly maps "a" to a different types.Object, + // so check if ident's object is the case clause's implicit + // object. + if obj != nil && pkg.GetTypesInfo().Implicits[n] == obj { + cc = n + } + case *ast.TypeSwitchStmt: + // Look for the type switch that owns our previously found + // *ast.AssignStmt or *ast.CaseClause. + + if n.Assign == assign { + ts = n + break Outer + } + + for _, stmt := range n.Body.List { + if stmt == cc { + ts = n + break Outer + } + } } } - return res + + if ts == nil { + return nil + } + + // Our leaf ident refers to a type switch variable. Fan out to the + // type switch's implicit case clause objects. + var objs []types.Object + for _, cc := range ts.Body.List { + if ccObj := pkg.GetTypesInfo().Implicits[cc]; ccObj != nil { + objs = append(objs, ccObj) + } + } + return objs } diff --git a/internal/lsp/source/implementation.go b/internal/lsp/source/implementation.go index fb896f8ab0..b2db818af1 100644 --- a/internal/lsp/source/implementation.go +++ b/internal/lsp/source/implementation.go @@ -48,25 +48,25 @@ func Implementation(ctx context.Context, s Snapshot, f FileHandle, pp protocol.P var ErrNotAnInterface = errors.New("not an interface or interface method") -func implementations(ctx context.Context, s Snapshot, f FileHandle, pp protocol.Position) ([]implementation, error) { +func implementations(ctx context.Context, s Snapshot, f FileHandle, pp protocol.Position) ([]qualifiedObject, error) { var ( - impls []implementation + impls []qualifiedObject seen = make(map[token.Position]bool) fset = s.View().Session().Cache().FileSet() ) - objs, err := objectsAtProtocolPos(ctx, s, f, pp) + qos, err := qualifiedObjsAtProtocolPos(ctx, s, f, pp) if err != nil { return nil, err } - for _, obj := range objs { + for _, qo := range qos { var ( T *types.Interface method *types.Func ) - switch obj := obj.(type) { + switch obj := qo.obj.(type) { case *types.Func: method = obj if recv := obj.Type().(*types.Signature).Recv(); recv != nil { @@ -141,7 +141,7 @@ func implementations(ctx context.Context, s Snapshot, f FileHandle, pp protocol. seen[pos] = true - impls = append(impls, implementation{ + impls = append(impls, qualifiedObject{ obj: obj, pkg: pkgs[obj.Pkg()], }) @@ -151,23 +151,29 @@ func implementations(ctx context.Context, s Snapshot, f FileHandle, pp protocol. return impls, nil } -type implementation struct { - // obj is the implementation, either a *types.TypeName or *types.Func. +type qualifiedObject struct { obj types.Object // pkg is the Package that contains obj's definition. pkg Package + + // node is the *ast.Ident or *ast.ImportSpec we followed to find obj, if any. + node ast.Node + + // sourcePkg is the Package that contains node, if any. + sourcePkg Package } -// objectsAtProtocolPos returns all the type.Objects referenced at the given position. -// An object will be returned for every package that the file belongs to. -func objectsAtProtocolPos(ctx context.Context, s Snapshot, f FileHandle, pp protocol.Position) ([]types.Object, error) { +// qualifiedObjsAtProtocolPos returns info for all the type.Objects +// referenced at the given position. An object will be returned for +// every package that the file belongs to. +func qualifiedObjsAtProtocolPos(ctx context.Context, s Snapshot, f FileHandle, pp protocol.Position) ([]qualifiedObject, error) { phs, err := s.PackageHandles(ctx, f) if err != nil { return nil, err } - var objs []types.Object + var qualifiedObjs []qualifiedObject // Check all the packages that the file belongs to. for _, ph := range phs { @@ -181,22 +187,52 @@ func objectsAtProtocolPos(ctx context.Context, s Snapshot, f FileHandle, pp prot return nil, err } - path := pathEnclosingIdent(astFile, pos) - if len(path) == 0 { + path := pathEnclosingObjNode(astFile, pos) + if path == nil { return nil, ErrNoIdentFound } - ident := path[len(path)-1].(*ast.Ident) - - obj := pkg.GetTypesInfo().ObjectOf(ident) - if obj == nil { - return nil, fmt.Errorf("no object for %q", ident.Name) + var objs []types.Object + switch leaf := path[0].(type) { + case *ast.Ident: + // If leaf represents an implicit type switch object or the type + // switch "assign" variable, expand to all of the type switch's + // implicit objects. + if implicits := typeSwitchImplicits(pkg, path); len(implicits) > 0 { + objs = append(objs, implicits...) + } else { + obj := pkg.GetTypesInfo().ObjectOf(leaf) + if obj == nil { + return nil, fmt.Errorf("no object for %q", leaf.Name) + } + objs = append(objs, obj) + } + case *ast.ImportSpec: + // Look up the implicit *types.PkgName. + obj := pkg.GetTypesInfo().Implicits[leaf] + if obj == nil { + return nil, fmt.Errorf("no object for import %q", importPath(leaf)) + } + objs = append(objs, obj) } - objs = append(objs, obj) + pkgs := make(map[*types.Package]Package) + pkgs[pkg.GetTypes()] = pkg + for _, imp := range pkg.Imports() { + pkgs[imp.GetTypes()] = imp + } + + for _, obj := range objs { + qualifiedObjs = append(qualifiedObjs, qualifiedObject{ + obj: obj, + pkg: pkgs[obj.Pkg()], + sourcePkg: pkg, + node: path[0], + }) + } } - return objs, nil + return qualifiedObjs, nil } func getASTFile(pkg Package, f FileHandle, pos protocol.Position) (*ast.File, token.Pos, error) { @@ -223,10 +259,11 @@ func getASTFile(pkg Package, f FileHandle, pos protocol.Position) (*ast.File, to return file, rng.Start, nil } -// pathEnclosingIdent returns the ast path to the node that contains pos. -// It is similar to astutil.PathEnclosingInterval, but simpler, and it -// matches *ast.Ident nodes if pos is equal to node.End(). -func pathEnclosingIdent(f *ast.File, pos token.Pos) []ast.Node { +// pathEnclosingObjNode returns the AST path to the object-defining +// node associated with pos. "Object-defining" means either an +// *ast.Ident mapped directly to a types.Object or an ast.Node mapped +// implicitly to a types.Object. +func pathEnclosingObjNode(f *ast.File, pos token.Pos) []ast.Node { var ( path []ast.Node found bool @@ -242,15 +279,48 @@ func pathEnclosingIdent(f *ast.File, pos token.Pos) []ast.Node { return false } + path = append(path, n) + switch n := n.(type) { case *ast.Ident: + // Include the position directly after identifier. This handles + // the common case where the cursor is right after the + // identifier the user is currently typing. Previously we + // handled this by calling astutil.PathEnclosingInterval twice, + // once for "pos" and once for "pos-1". found = n.Pos() <= pos && pos <= n.End() + case *ast.ImportSpec: + if n.Path.Pos() <= pos && pos < n.Path.End() { + found = true + // If import spec has a name, add name to path even though + // position isn't in the name. + if n.Name != nil { + path = append(path, n.Name) + } + } + case *ast.StarExpr: + // Follow star expressions to the inner identifer. + if pos == n.Star { + pos = n.X.Pos() + } + case *ast.SelectorExpr: + // If pos is on the ".", move it into the selector. + if pos == n.X.End() { + pos = n.Sel.Pos() + } } - path = append(path, n) - return !found }) + if len(path) == 0 { + return nil + } + + // Reverse path so leaf is first element. + for i := 0; i < len(path)/2; i++ { + path[i], path[len(path)-1-i] = path[len(path)-1-i], path[i] + } + return path } diff --git a/internal/lsp/source/references.go b/internal/lsp/source/references.go index ec1acfb0c7..8db48ddda9 100644 --- a/internal/lsp/source/references.go +++ b/internal/lsp/source/references.go @@ -7,11 +7,11 @@ package source import ( "context" "go/ast" + "go/token" "go/types" - "golang.org/x/tools/go/types/objectpath" + "golang.org/x/tools/internal/lsp/protocol" "golang.org/x/tools/internal/telemetry/trace" - errors "golang.org/x/xerrors" ) // ReferenceInfo holds information about reference to an identifier in Go source. @@ -26,79 +26,86 @@ type ReferenceInfo struct { // References returns a list of references for a given identifier within the packages // containing i.File. Declarations appear first in the result. -func (i *IdentifierInfo) References(ctx context.Context) ([]*ReferenceInfo, error) { +func References(ctx context.Context, s Snapshot, f FileHandle, pp protocol.Position, includeDeclaration bool) ([]*ReferenceInfo, error) { ctx, done := trace.StartSpan(ctx, "source.References") defer done() - // If the object declaration is nil, assume it is an import spec and do not look for references. - if i.Declaration.obj == nil { - return nil, errors.Errorf("no references for an import spec") + qualifiedObjs, err := qualifiedObjsAtProtocolPos(ctx, s, f, pp) + if err != nil { + return nil, err } - info := i.pkg.GetTypesInfo() - if info == nil { - return nil, errors.Errorf("package %s has no types info", i.pkg.PkgPath()) + + var ( + references []*ReferenceInfo + seen = make(map[token.Position]bool) + fset = s.View().Session().Cache().FileSet() + ) + + for _, qo := range qualifiedObjs { + var searchPkgs []Package + + // Only search dependents if the object is exported. + if qo.obj.Exported() { + reverseDeps, err := s.GetReverseDependencies(ctx, qo.pkg.ID()) + if err != nil { + return nil, err + } + + for _, ph := range reverseDeps { + pkg, err := ph.Check(ctx) + if err != nil { + return nil, err + } + searchPkgs = append(searchPkgs, pkg) + } + } + + // Add the package in which the identifier is declared. + searchPkgs = append(searchPkgs, qo.pkg) + + for _, pkg := range searchPkgs { + for ident, obj := range pkg.GetTypesInfo().Uses { + if obj != qo.obj { + continue + } + + pos := fset.Position(ident.Pos()) + if seen[pos] { + continue + } + seen[pos] = true + + rng, err := posToMappedRange(s.View(), pkg, ident.Pos(), ident.End()) + if err != nil { + return nil, err + } + references = append(references, &ReferenceInfo{ + Name: ident.Name, + ident: ident, + pkg: pkg, + obj: obj, + mappedRange: rng, + }) + } + } } - var searchpkgs []Package - if i.Declaration.obj.Exported() { - // Only search all packages if the identifier is exported. - reverseDeps, err := i.Snapshot.GetReverseDependencies(ctx, i.pkg.ID()) + + if includeDeclaration { + rng, err := objToMappedRange(s.View(), qualifiedObjs[0].pkg, qualifiedObjs[0].obj) if err != nil { return nil, err } - for _, ph := range reverseDeps { - pkg, err := ph.Check(ctx) - if err != nil { - return nil, err - } - searchpkgs = append(searchpkgs, pkg) - } - } - // Add the package in which the identifier is declared. - searchpkgs = append(searchpkgs, i.pkg) - var references []*ReferenceInfo - for _, pkg := range searchpkgs { - for ident, obj := range pkg.GetTypesInfo().Uses { - if !sameObj(obj, i.Declaration.obj) { - continue - } - rng, err := posToMappedRange(i.Snapshot.View(), pkg, ident.Pos(), ident.End()) - if err != nil { - return nil, err - } - references = append(references, &ReferenceInfo{ - Name: ident.Name, - ident: ident, - pkg: i.pkg, - obj: obj, - mappedRange: rng, - }) - } + ident, _ := qualifiedObjs[0].node.(*ast.Ident) + references = append(references, &ReferenceInfo{ + mappedRange: rng, + Name: qualifiedObjs[0].obj.Name(), + ident: ident, + obj: qualifiedObjs[0].obj, + pkg: qualifiedObjs[0].pkg, + isDeclaration: true, + }) } + return references, nil } - -// sameObj returns true if obj is the same as declObj. -// Objects are the same if either they have they have objectpaths -// and their objectpath and package are the same; or if they don't -// have object paths and they have the same Pos and Name. -func sameObj(obj, declObj types.Object) bool { - if obj == nil || declObj == nil { - return false - } - // TODO(suzmue): support the case where an identifier may have two different - // declaration positions. - if obj.Pkg() == nil || declObj.Pkg() == nil { - if obj.Pkg() != declObj.Pkg() { - return false - } - } else if obj.Pkg().Path() != declObj.Pkg().Path() { - return false - } - objPath, operr := objectpath.For(obj) - declObjPath, doperr := objectpath.For(declObj) - if operr != nil || doperr != nil { - return obj.Pos() == declObj.Pos() && obj.Name() == declObj.Name() - } - return objPath == declObjPath -} diff --git a/internal/lsp/source/rename.go b/internal/lsp/source/rename.go index 7187f14e41..01e0ff865e 100644 --- a/internal/lsp/source/rename.go +++ b/internal/lsp/source/rename.go @@ -7,7 +7,6 @@ package source import ( "bytes" "context" - "fmt" "go/ast" "go/format" "go/token" @@ -42,93 +41,84 @@ type PrepareItem struct { Text string } -func (i *IdentifierInfo) PrepareRename(ctx context.Context) (*PrepareItem, error) { +func PrepareRename(ctx context.Context, s Snapshot, f FileHandle, pp protocol.Position) (*PrepareItem, error) { ctx, done := trace.StartSpan(ctx, "source.PrepareRename") defer done() - // TODO(rstambler): We should handle this in a better way. - // If the object declaration is nil, assume it is an import spec. - if i.Declaration.obj == nil { - // Find the corresponding package name for this import spec - // and rename that instead. - ident, err := i.getPkgName(ctx) - if err != nil { - return nil, err - } - rng, err := ident.mappedRange.Range() - if err != nil { - return nil, err - } - // We're not really renaming the import path. - rng.End = rng.Start - return &PrepareItem{ - Range: rng, - Text: ident.Name, - }, nil - } - - // Do not rename builtin identifiers. - if i.Declaration.obj.Parent() == types.Universe { - return nil, errors.Errorf("cannot rename builtin %q", i.Name) - } - rng, err := i.mappedRange.Range() + qos, err := qualifiedObjsAtProtocolPos(ctx, s, f, pp) if err != nil { return nil, err } + + // Do not rename builtin identifiers. + if qos[0].obj.Parent() == types.Universe { + return nil, errors.Errorf("cannot rename builtin %q", qos[0].obj.Name()) + } + + mr, err := posToMappedRange(s.View(), qos[0].sourcePkg, qos[0].node.Pos(), qos[0].node.End()) + if err != nil { + return nil, err + } + + rng, err := mr.Range() + if err != nil { + return nil, err + } + + if _, isImport := qos[0].node.(*ast.ImportSpec); isImport { + // We're not really renaming the import path. + rng.End = rng.Start + } + return &PrepareItem{ Range: rng, - Text: i.Name, + Text: qos[0].obj.Name(), }, nil } // Rename returns a map of TextEdits for each file modified when renaming a given identifier within a package. -func (i *IdentifierInfo) Rename(ctx context.Context, newName string) (map[span.URI][]protocol.TextEdit, error) { +func Rename(ctx context.Context, s Snapshot, f FileHandle, pp protocol.Position, newName string) (map[span.URI][]protocol.TextEdit, error) { ctx, done := trace.StartSpan(ctx, "source.Rename") defer done() - // TODO(rstambler): We should handle this in a better way. - // If the object declaration is nil, assume it is an import spec. - if i.Declaration.obj == nil { - // Find the corresponding package name for this import spec - // and rename that instead. - ident, err := i.getPkgName(ctx) - if err != nil { - return nil, err - } - return ident.Rename(ctx, newName) - } - if i.Name == newName { - return nil, errors.Errorf("old and new names are the same: %s", newName) - } - if !isValidIdentifier(newName) { - return nil, errors.Errorf("invalid identifier to rename: %q", i.Name) - } - // Do not rename builtin identifiers. - if i.Declaration.obj.Parent() == types.Universe { - return nil, errors.Errorf("cannot rename builtin %q", i.Name) - } - if i.pkg == nil || i.pkg.IsIllTyped() { - return nil, errors.Errorf("package for %s is ill typed", i.URI()) - } - // Do not rename identifiers declared in another package. - if i.pkg.GetTypes() != i.Declaration.obj.Pkg() { - return nil, errors.Errorf("failed to rename because %q is declared in package %q", i.Name, i.Declaration.obj.Pkg().Name()) - } - - refs, err := i.References(ctx) + qos, err := qualifiedObjsAtProtocolPos(ctx, s, f, pp) if err != nil { return nil, err } - // Make sure to add the declaration of the identifier. - refs = append(refs, i.DeclarationReferenceInfo()) + obj := qos[0].obj + pkg := qos[0].pkg + sourcePkg := qos[0].sourcePkg + + if obj.Name() == newName { + return nil, errors.Errorf("old and new names are the same: %s", newName) + } + if !isValidIdentifier(newName) { + return nil, errors.Errorf("invalid identifier to rename: %q", newName) + } + // Do not rename builtin identifiers. + if obj.Parent() == types.Universe { + return nil, errors.Errorf("cannot rename builtin %q", obj.Name()) + } + if pkg == nil || pkg.IsIllTyped() { + return nil, errors.Errorf("package for %s is ill typed", f.Identity().URI) + } + // Do not rename identifiers declared in another package. + if pkg != sourcePkg { + return nil, errors.Errorf("rename failed because %q is declared in a different package (%s)", obj.Name(), pkg.PkgPath()) + } + + refs, err := References(ctx, s, f, pp, true) + if err != nil { + return nil, err + } r := renamer{ ctx: ctx, - fset: i.Snapshot.View().Session().Cache().FileSet(), + fset: s.View().Session().Cache().FileSet(), refs: refs, objsToUpdate: make(map[types.Object]bool), - from: i.Name, + from: obj.Name(), to: newName, packages: make(map[*types.Package]Package), } @@ -155,7 +145,7 @@ func (i *IdentifierInfo) Rename(ctx context.Context, newName string) (map[span.U for uri, edits := range changes { // These edits should really be associated with FileHandles for maximal correctness. // For now, this is good enough. - fh, err := i.Snapshot.GetFile(uri) + fh, err := s.GetFile(uri) if err != nil { return nil, err } @@ -180,66 +170,6 @@ func (i *IdentifierInfo) Rename(ctx context.Context, newName string) (map[span.U return result, nil } -// getPkgName gets the pkg name associated with an identifier representing -// the import path in an import spec. -func (i *IdentifierInfo) getPkgName(ctx context.Context) (*IdentifierInfo, error) { - ph, err := i.pkg.File(i.URI()) - if err != nil { - return nil, fmt.Errorf("finding file for identifier %v: %v", i.Name, err) - } - file, _, _, err := ph.Cached() - if err != nil { - return nil, err - } - var namePos token.Pos - for _, spec := range file.Imports { - if spec.Path.Pos() == i.spanRange.Start { - namePos = spec.Pos() - break - } - } - if !namePos.IsValid() { - return nil, errors.Errorf("import spec not found for %q", i.Name) - } - // Look for the object defined at NamePos. - for _, obj := range i.pkg.GetTypesInfo().Defs { - pkgName, ok := obj.(*types.PkgName) - if ok && pkgName.Pos() == namePos { - return getPkgNameIdentifier(ctx, i, pkgName) - } - } - for _, obj := range i.pkg.GetTypesInfo().Implicits { - pkgName, ok := obj.(*types.PkgName) - if ok && pkgName.Pos() == namePos { - return getPkgNameIdentifier(ctx, i, pkgName) - } - } - return nil, errors.Errorf("no package name for %q", i.Name) -} - -// getPkgNameIdentifier returns an IdentifierInfo representing pkgName. -// pkgName must be in the same package and file as ident. -func getPkgNameIdentifier(ctx context.Context, ident *IdentifierInfo, pkgName *types.PkgName) (*IdentifierInfo, error) { - decl := Declaration{ - obj: pkgName, - } - var err error - if decl.mappedRange, err = objToMappedRange(ident.Snapshot.View(), ident.pkg, decl.obj); err != nil { - return nil, err - } - if decl.node, err = objToNode(ident.Snapshot.View(), ident.pkg, decl.obj); err != nil { - return nil, err - } - return &IdentifierInfo{ - Snapshot: ident.Snapshot, - Name: pkgName.Name(), - mappedRange: decl.mappedRange, - Declaration: decl, - pkg: ident.pkg, - qf: ident.qf, - }, nil -} - // Rename all references to the identifier. func (r *renamer) update() (map[span.URI][]diff.TextEdit, error) { result := make(map[span.URI][]diff.TextEdit) diff --git a/internal/lsp/source/source_test.go b/internal/lsp/source/source_test.go index d649673e70..401ecd4cf9 100644 --- a/internal/lsp/source/source_test.go +++ b/internal/lsp/source/source_test.go @@ -626,20 +626,14 @@ func (r *runner) References(t *testing.T, src span.Span, itemList []span.Span) { if err != nil { t.Fatal(err) } - ident, err := source.Identifier(ctx, r.view.Snapshot(), fh, srcRng.Start, source.WidestPackageHandle) - if err != nil { - t.Fatalf("failed for %v: %v", src, err) - } want := make(map[span.Span]bool) for _, pos := range itemList { want[pos] = true } - refs, err := ident.References(ctx) + refs, err := source.References(ctx, r.view.Snapshot(), fh, srcRng.Start, true) if err != nil { t.Fatalf("failed for %v: %v", src, err) } - // Add the item's declaration, since References omits it. - refs = append([]*source.ReferenceInfo{ident.DeclarationReferenceInfo()}, refs...) got := make(map[span.Span]bool) for _, refInfo := range refs { refSpan, err := refInfo.Span() @@ -670,12 +664,7 @@ func (r *runner) Rename(t *testing.T, spn span.Span, newText string) { if err != nil { t.Fatal(err) } - ident, err := source.Identifier(r.ctx, r.view.Snapshot(), fh, srcRng.Start, source.WidestPackageHandle) - if err != nil { - t.Error(err) - return - } - changes, err := ident.Rename(r.ctx, newText) + changes, err := source.Rename(r.ctx, r.view.Snapshot(), fh, srcRng.Start, newText) if err != nil { renamed := string(r.data.Golden(tag, spn.URI().Filename(), func() ([]byte, error) { return []byte(err.Error()), nil @@ -757,14 +746,7 @@ func (r *runner) PrepareRename(t *testing.T, src span.Span, want *source.Prepare if err != nil { t.Fatal(err) } - ident, err := source.Identifier(r.ctx, r.view.Snapshot(), fh, srcRng.Start, source.WidestPackageHandle) - if err != nil { - if want.Text != "" { // expected an ident. - t.Errorf("prepare rename failed for %v: got error: %v", src, err) - } - return - } - item, err := ident.PrepareRename(r.ctx) + item, err := source.PrepareRename(r.ctx, r.view.Snapshot(), fh, srcRng.Start) if err != nil { if want.Text != "" { // expected an ident. t.Errorf("prepare rename failed for %v: got error: %v", src, err) diff --git a/internal/lsp/testdata/references/refs.go b/internal/lsp/testdata/references/refs.go index e5a51fd97b..644723992c 100644 --- a/internal/lsp/testdata/references/refs.go +++ b/internal/lsp/testdata/references/refs.go @@ -1,6 +1,6 @@ package refs -type i int //@mark(typeI, "i"),refs("i", typeI, argI, returnI) +type i int //@mark(typeI, "i"),refs("i", typeI, argI, returnI, embeddedI) func _(_ i) []bool { //@mark(argI, "i") return nil @@ -12,10 +12,18 @@ func _(_ []byte) i { //@mark(returnI, "i") var q string //@mark(declQ, "q"),refs("q", declQ, assignQ, bobQ) -var Q string //@mark(declExpQ, "Q"), refs("Q", declExpQ, assignExpQ, bobExpQ) +var Q string //@mark(declExpQ, "Q"),refs("Q", declExpQ, assignExpQ, bobExpQ) func _() { q = "hello" //@mark(assignQ, "q") bob := func(_ string) {} bob(q) //@mark(bobQ, "q") } + +type e struct { + i //@mark(embeddedI, "i"),refs("i", embeddedI, embeddedIUse) +} + +func _() { + _ = e{}.i //@mark(embeddedIUse, "i") +} diff --git a/internal/lsp/testdata/references/refs_test.go b/internal/lsp/testdata/references/refs_test.go new file mode 100644 index 0000000000..08c0db1f05 --- /dev/null +++ b/internal/lsp/testdata/references/refs_test.go @@ -0,0 +1,10 @@ +package references + +import ( + "testing" +) + +// This test exists to bring the test package into existence. + +func TestReferences(t *testing.T) { +} diff --git a/internal/lsp/testdata/summary.txt.golden b/internal/lsp/testdata/summary.txt.golden index 132779e20e..4c60c56b1c 100644 --- a/internal/lsp/testdata/summary.txt.golden +++ b/internal/lsp/testdata/summary.txt.golden @@ -14,7 +14,7 @@ SuggestedFixCount = 1 DefinitionsCount = 43 TypeDefinitionsCount = 2 HighlightsCount = 45 -ReferencesCount = 7 +ReferencesCount = 8 RenamesCount = 22 PrepareRenamesCount = 8 SymbolsCount = 1