diff --git a/internal/lsp/source/completion.go b/internal/lsp/source/completion.go index 8729b92848..a1cf59417f 100644 --- a/internal/lsp/source/completion.go +++ b/internal/lsp/source/completion.go @@ -363,6 +363,13 @@ type candidate struct { // For example, expandFuncCall=true yields "foo()", expandFuncCall=false yields "foo". expandFuncCall bool + // takeAddress is true if the completion should take a pointer to obj. + // For example, takeAddress=true yields "&foo", takeAddress=false yields "foo". + takeAddress bool + + // addressable is true if a pointer can be taken to the candidate. + addressable bool + // imp is the import that needs to be added to this package in order // for this candidate to be valid. nil if no import needed. imp *importInfo @@ -537,7 +544,8 @@ func Completion(ctx context.Context, snapshot Snapshot, fh FileHandle, pos proto return c.items, c.getSurrounding(), nil } -// populateCommentCompletions returns completions for an exported variable immediately preceeding comment +// populateCommentCompletions yields completions for an exported +// variable immediately preceding comment. func (c *completer) populateCommentCompletions(comment *ast.CommentGroup) { // Using the comment position find the line after @@ -653,10 +661,12 @@ func (c *completer) selector(sel *ast.SelectorExpr) error { func (c *completer) packageMembers(pkg *types.Package, imp *importInfo) { scope := pkg.Scope() for _, name := range scope.Names() { + obj := scope.Lookup(name) c.found(candidate{ - obj: scope.Lookup(name), - score: stdScore, - imp: imp, + obj: obj, + score: stdScore, + imp: imp, + addressable: isVar(obj), }) } } @@ -676,18 +686,20 @@ func (c *completer) methodsAndFields(typ types.Type, addressable bool, imp *impo for i := 0; i < mset.Len(); i++ { c.found(candidate{ - obj: mset.At(i).Obj(), - score: stdScore, - imp: imp, + obj: mset.At(i).Obj(), + score: stdScore, + imp: imp, + addressable: addressable || isPointer(typ), }) } // Add fields of T. for _, f := range fieldSelections(typ) { c.found(candidate{ - obj: f, - score: stdScore, - imp: imp, + obj: f, + score: stdScore, + imp: imp, + addressable: addressable || isPointer(typ), }) } return nil @@ -778,8 +790,9 @@ func (c *completer) lexical() error { if _, ok := seen[obj.Name()]; !ok { seen[obj.Name()] = struct{}{} c.found(candidate{ - obj: obj, - score: score, + obj: obj, + score: score, + addressable: isVar(obj), }) } } @@ -1116,11 +1129,11 @@ type typeModifier struct { type typeMod int const ( - star typeMod = iota // dereference operator for expressions, pointer indicator for types - reference // reference ("&") operator - chanRead // channel read ("<-") operator - slice // make a slice type ("[]" in "[]int") - array // make an array type ("[2]" in "[2]int") + star typeMod = iota // pointer indirection for expressions, pointer indicator for types + address // address operator ("&") + chanRead // channel read operator ("<-") + slice // make a slice type ("[]" in "[]int") + array // make an array type ("[2]" in "[2]int") ) // typeInference holds information we have inferred about a type that can be @@ -1347,7 +1360,7 @@ Nodes: case *ast.UnaryExpr: switch node.Op { case token.AND: - inf.modifiers = append(inf.modifiers, typeModifier{mod: reference}) + inf.modifiers = append(inf.modifiers, typeModifier{mod: address}) case token.ARROW: inf.modifiers = append(inf.modifiers, typeModifier{mod: chanRead}) } @@ -1362,22 +1375,36 @@ Nodes: } // applyTypeModifiers applies the list of type modifiers to a type. -func (ti typeInference) applyTypeModifiers(typ types.Type) types.Type { +// It returns nil if the modifiers could not be applied. +func (ti typeInference) applyTypeModifiers(typ types.Type, addressable bool) types.Type { for _, mod := range ti.modifiers { switch mod.mod { case star: - // For every "*" deref operator, remove a pointer layer from candidate type. - typ = deref(typ) - case reference: - // For every "&" ref operator, add another pointer layer to candidate type. - typ = types.NewPointer(typ) + // For every "*" indirection operator, remove a pointer layer + // from candidate type. + if ptr, ok := typ.Underlying().(*types.Pointer); ok { + typ = ptr.Elem() + } else { + return nil + } + case address: + // For every "&" address operator, add another pointer layer to + // candidate type, if the candidate is addressable. + if addressable { + typ = types.NewPointer(typ) + } else { + return nil + } case chanRead: // For every "<-" operator, remove a layer of channelness. if ch, ok := typ.(*types.Chan); ok { typ = ch.Elem() + } else { + return nil } } } + return typ } @@ -1544,10 +1571,8 @@ Nodes: } } -// matchingType reports whether a type matches the expected type. -func (c *completer) matchingType(T types.Type) bool { - fakeObj := types.NewVar(token.NoPos, c.pkg.GetTypes(), "", T) - return c.matchingCandidate(&candidate{obj: fakeObj}) +func (c *completer) fakeObj(T types.Type) *types.Var { + return types.NewVar(token.NoPos, c.pkg.GetTypes(), "", T) } // matchingCandidate reports whether a candidate matches our type @@ -1576,7 +1601,10 @@ func (c *completer) matchingCandidate(cand *candidate) bool { } // Take into account any type modifiers on the expected type. - candType = c.expectedType.applyTypeModifiers(candType) + candType = c.expectedType.applyTypeModifiers(candType, cand.addressable) + if candType == nil { + return false + } // Handle untyped values specially since AssignableTo gives false negatives // for them (see https://golang.org/issue/32146). @@ -1629,8 +1657,15 @@ func (c *completer) matchingCandidate(cand *candidate) bool { } } - if c.expectedType.convertibleTo != nil { - return types.ConvertibleTo(candType, c.expectedType.convertibleTo) + if c.expectedType.convertibleTo != nil && types.ConvertibleTo(candType, c.expectedType.convertibleTo) { + return true + } + + // Check if cand is addressable and a pointer to cand matches our type inference. + if cand.addressable && c.matchingCandidate(&candidate{obj: c.fakeObj(types.NewPointer(candType))}) { + // Mark the candidate so we know to prepend "&" when formatting. + cand.takeAddress = true + return true } return false diff --git a/internal/lsp/source/completion_format.go b/internal/lsp/source/completion_format.go index 7f965beedb..9147cdf605 100644 --- a/internal/lsp/source/completion_format.go +++ b/internal/lsp/source/completion_format.go @@ -130,6 +130,30 @@ func (c *completer) item(cand candidate) (CompletionItem, error) { } } + // Prepend "&" operator if our candidate needs address taken. + if cand.takeAddress { + var ( + sel *ast.SelectorExpr + ok bool + ) + if sel, ok = c.path[0].(*ast.SelectorExpr); !ok && len(c.path) > 1 { + sel, _ = c.path[1].(*ast.SelectorExpr) + } + + // If we are in a selector, add an edit to place "&" before selector node. + if sel != nil { + edits, err := referenceEdit(c.snapshot.View().Session().Cache().FileSet(), c.mapper, sel) + if err != nil { + log.Error(c.ctx, "error generating reference edit", err) + } else { + protocolEdits = append(protocolEdits, edits...) + } + } else { + // If there is no selector, just stick the "&" at the start. + insert = "&" + insert + } + } + detail = strings.TrimPrefix(detail, "untyped ") item := CompletionItem{ Label: label, diff --git a/internal/lsp/source/completion_literal.go b/internal/lsp/source/completion_literal.go index 858fc48d4e..816e1b3148 100644 --- a/internal/lsp/source/completion_literal.go +++ b/internal/lsp/source/completion_literal.go @@ -63,14 +63,13 @@ func (c *completer) literal(literalType types.Type, imp *importInfo) { } } - // Check if an object of type literalType or *literalType would - // match our expected type. - var isPointer bool - if !c.matchingType(literalType) { - isPointer = true - if !c.matchingType(types.NewPointer(literalType)) { - return - } + // Check if an object of type literalType would match our expected type. + cand := candidate{ + obj: c.fakeObj(literalType), + addressable: true, + } + if !c.matchingCandidate(&cand) { + return } var ( @@ -105,7 +104,7 @@ func (c *completer) literal(literalType types.Type, imp *importInfo) { // If prefix matches the type name, client may want a composite literal. if score := c.matcher.Score(matchName); score >= 0 { - if isPointer { + if cand.takeAddress { if sel != nil { // If we are in a selector we must place the "&" before the selector. // For example, "foo.B<>" must complete to "&foo.Bar{}", not @@ -146,7 +145,7 @@ func (c *completer) literal(literalType types.Type, imp *importInfo) { // If prefix matches "make", client may want a "make()" // invocation. We also include the type name to allow for more // flexible fuzzy matching. - if score := c.matcher.Score("make." + matchName); !isPointer && score >= 0 { + if score := c.matcher.Score("make." + matchName); !cand.takeAddress && score >= 0 { switch literalType.Underlying().(type) { case *types.Slice: // The second argument to "make()" for slices is required, so default to "0". @@ -159,7 +158,7 @@ func (c *completer) literal(literalType types.Type, imp *importInfo) { } // If prefix matches "func", client may want a function literal. - if score := c.matcher.Score("func"); !isPointer && score >= 0 && !isInterface(expType) { + if score := c.matcher.Score("func"); !cand.takeAddress && score >= 0 && !isInterface(expType) { switch t := literalType.Underlying().(type) { case *types.Signature: c.functionLiteral(t, float64(score)) diff --git a/internal/lsp/source/deep_completion.go b/internal/lsp/source/deep_completion.go index de030c71e5..d6297d06cb 100644 --- a/internal/lsp/source/deep_completion.go +++ b/internal/lsp/source/deep_completion.go @@ -188,9 +188,7 @@ func (c *completer) deepSearch(cand candidate) { case *types.PkgName: c.packageMembers(obj.Imported(), cand.imp) default: - // For now it is okay to assume obj is addressable since we don't search beyond - // function calls. - c.methodsAndFields(obj.Type(), true, cand.imp) + c.methodsAndFields(obj.Type(), cand.addressable, cand.imp) } // Pop the object off our search stack. diff --git a/internal/lsp/source/source_test.go b/internal/lsp/source/source_test.go index a2e391535d..c53a787b10 100644 --- a/internal/lsp/source/source_test.go +++ b/internal/lsp/source/source_test.go @@ -189,23 +189,12 @@ func (r *runner) FuzzyCompletion(t *testing.T, src span.Span, test tests.Complet for _, pos := range test.CompletionItems { want = append(want, tests.ToProtocolCompletionItem(*items[pos])) } - prefix, list := r.callCompletion(t, src, source.CompletionOptions{ + _, got := r.callCompletion(t, src, source.CompletionOptions{ FuzzyMatching: true, Deep: true, }) if !strings.Contains(string(src.URI()), "builtins") { - list = tests.FilterBuiltins(list) - } - var fuzzyMatcher *fuzzy.Matcher - if prefix != "" { - fuzzyMatcher = fuzzy.NewMatcher(prefix) - } - var got []protocol.CompletionItem - for _, item := range list { - if fuzzyMatcher != nil && fuzzyMatcher.Score(item.Label) <= 0 { - continue - } - got = append(got, item) + got = tests.FilterBuiltins(got) } if msg := tests.DiffCompletionItems(want, got); msg != "" { t.Errorf("%s: %s", src, msg) @@ -233,19 +222,11 @@ func (r *runner) RankCompletion(t *testing.T, src span.Span, test tests.Completi for _, pos := range test.CompletionItems { want = append(want, tests.ToProtocolCompletionItem(*items[pos])) } - prefix, list := r.callCompletion(t, src, source.CompletionOptions{ + _, got := r.callCompletion(t, src, source.CompletionOptions{ FuzzyMatching: true, Deep: true, Literal: true, }) - fuzzyMatcher := fuzzy.NewMatcher(prefix) - var got []protocol.CompletionItem - for _, item := range list { - if fuzzyMatcher.Score(item.Label) <= 0 { - continue - } - got = append(got, item) - } if msg := tests.CheckCompletionOrder(want, got, true); msg != "" { t.Errorf("%s: %s", src, msg) } diff --git a/internal/lsp/source/util.go b/internal/lsp/source/util.go index 7a648394b7..91678b67a1 100644 --- a/internal/lsp/source/util.go +++ b/internal/lsp/source/util.go @@ -387,6 +387,11 @@ func isPointer(T types.Type) bool { return ok } +func isVar(obj types.Object) bool { + _, ok := obj.(*types.Var) + return ok +} + // deref returns a pointer's element type, traversing as many levels as needed. // Otherwise it returns typ. func deref(typ types.Type) types.Type { diff --git a/internal/lsp/testdata/address/address.go b/internal/lsp/testdata/address/address.go new file mode 100644 index 0000000000..478ce72036 --- /dev/null +++ b/internal/lsp/testdata/address/address.go @@ -0,0 +1,53 @@ +package address + +func wantsPtr(*int) {} + +type foo struct{ c int } //@item(addrFieldC, "c", "int", "field") + +func _() { + var ( + a string //@item(addrA, "a", "string", "var") + b int //@item(addrB, "b", "int", "var") + ) + + wantsPtr() //@rank(")", addrB, addrA),snippet(")", addrB, "&b", "&b") + wantsPtr(&b) //@snippet(")", addrB, "b", "b") + + var s foo + s.c //@item(addrDeepC, "s.c", "int", "field") + wantsPtr() //@snippet(")", addrDeepC, "&s.c", "&s.c") + wantsPtr(s) //@snippet(")", addrDeepC, "&s.c", "&s.c") + wantsPtr(&s) //@snippet(")", addrDeepC, "s.c", "s.c") + + // don't add "&" in item (it gets added as an additional edit) + wantsPtr(&s.c) //@snippet(")", addrFieldC, "c", "c") +} + +func (f foo) ptr() *foo { return &f } + +func _() { + getFoo := func() foo { return foo{} } + + // not addressable + getFoo().c //@item(addrGetFooC, "getFoo().c", "int", "field") + + // addressable + getFoo().ptr().c //@item(addrGetFooPtrC, "getFoo().ptr().c", "int", "field") + + wantsPtr() //@rank(addrGetFooPtrC, addrGetFooC),snippet(")", addrGetFooPtrC, "&getFoo().ptr().c", "&getFoo().ptr().c") + wantsPtr(&g) //@rank(addrGetFooPtrC, addrGetFooC),snippet(")", addrGetFooPtrC, "getFoo().ptr().c", "getFoo().ptr().c") +} + +type nested struct { + f foo +} + +func _() { + getNested := func() nested { return nested{} } + + getNested().f.c //@item(addrNestedC, "getNested().f.c", "int", "field") + getNested().f.ptr().c //@item(addrNestedPtrC, "getNested().f.ptr().c", "int", "field") + + // addrNestedC is not addressable, so rank lower + wantsPtr(getNestedfc) //@fuzzy(")", addrNestedPtrC, addrNestedC) +} diff --git a/internal/lsp/testdata/channel/channel.go b/internal/lsp/testdata/channel/channel.go index dc559513bf..d6bd311e33 100644 --- a/internal/lsp/testdata/channel/channel.go +++ b/internal/lsp/testdata/channel/channel.go @@ -20,6 +20,6 @@ func _() { { var foo chan int //@item(channelFoo, "foo", "chan int", "var") wantsInt := func(int) {} //@item(channelWantsInt, "wantsInt", "func(int)", "var") - wantsInt(<-) //@complete(")", channelFoo, channelAB, channelWantsInt, channelAA) + wantsInt(<-) //@rank(")", channelFoo, channelAB) } } diff --git a/internal/lsp/testdata/deep/deep.go b/internal/lsp/testdata/deep/deep.go index 1bedc9416c..a7c659aa2f 100644 --- a/internal/lsp/testdata/deep/deep.go +++ b/internal/lsp/testdata/deep/deep.go @@ -1,7 +1,3 @@ -// Copyright 2019 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - package deep import "context" diff --git a/internal/lsp/testdata/summary.txt.golden b/internal/lsp/testdata/summary.txt.golden index ff727e2566..94c8cf6738 100644 --- a/internal/lsp/testdata/summary.txt.golden +++ b/internal/lsp/testdata/summary.txt.golden @@ -1,10 +1,10 @@ -- summary -- -CompletionsCount = 224 -CompletionSnippetCount = 53 +CompletionsCount = 223 +CompletionSnippetCount = 61 UnimportedCompletionsCount = 4 DeepCompletionsCount = 5 -FuzzyCompletionsCount = 7 -RankedCompletionsCount = 28 +FuzzyCompletionsCount = 8 +RankedCompletionsCount = 32 CaseSensitiveCompletionsCount = 4 DiagnosticsCount = 35 FoldingRangesCount = 2