diff --git a/internal/lsp/cache/parse.go b/internal/lsp/cache/parse.go index 2f2f247286..541306245b 100644 --- a/internal/lsp/cache/parse.go +++ b/internal/lsp/cache/parse.go @@ -11,6 +11,7 @@ import ( "go/parser" "go/scanner" "go/token" + "reflect" "golang.org/x/tools/internal/lsp/source" "golang.org/x/tools/internal/lsp/telemetry" @@ -171,21 +172,30 @@ func isEllipsisArray(n ast.Expr) bool { // fix inspects the AST and potentially modifies any *ast.BadStmts so that it can be // type-checked more effectively. func fix(ctx context.Context, file *ast.File, tok *token.File, src []byte) error { - var parent ast.Node - var err error + var ( + ancestors []ast.Node + err error + ) ast.Inspect(file, func(n ast.Node) bool { if n == nil { + if len(ancestors) > 0 { + ancestors = ancestors[:len(ancestors)-1] + } return false } switch n := n.(type) { case *ast.BadStmt: + var parent ast.Node + if len(ancestors) > 0 { + parent = ancestors[len(ancestors)-1] + } err = parseDeferOrGoStmt(n, parent, tok, src) // don't shadow err if err != nil { err = errors.Errorf("unable to parse defer or go from *ast.BadStmt: %v", err) } return false default: - parent = n + ancestors = append(ancestors, n) return true } }) @@ -204,9 +214,10 @@ func parseDeferOrGoStmt(bad *ast.BadStmt, parent ast.Node, tok *token.File, src s := &scanner.Scanner{} s.Init(tok, src, nil, 0) - var pos token.Pos - var tkn token.Token - var lit string + var ( + pos token.Pos + tkn token.Token + ) for { if tkn == token.EOF { return errors.Errorf("reached the end of the file") @@ -214,15 +225,16 @@ func parseDeferOrGoStmt(bad *ast.BadStmt, parent ast.Node, tok *token.File, src if pos >= bad.From { break } - pos, tkn, lit = s.Scan() + pos, tkn, _ = s.Scan() } + var stmt ast.Stmt - switch lit { - case "defer": + switch tkn { + case token.DEFER: stmt = &ast.DeferStmt{ Defer: pos, } - case "go": + case token.GO: stmt = &ast.GoStmt{ Go: pos, } @@ -230,38 +242,99 @@ func parseDeferOrGoStmt(bad *ast.BadStmt, parent ast.Node, tok *token.File, src return errors.Errorf("no defer or go statement found") } - // The expression after the "defer" or "go" starts at this position. - from, _, _ := s.Scan() - var to, curr token.Pos + var ( + from, to, last token.Pos + lastToken token.Token + braceDepth int + phantomSelectors []token.Pos + ) FindTo: for { - curr, tkn, _ = s.Scan() - // TODO(rstambler): This still needs more handling to work correctly. - // We encounter a specific issue with code that looks like this: - // - // defer x.<> - // y := 1 - // - // In this scenario, we parse it as "defer x.y", which then fails to - // type-check, and we don't get completions as expected. - switch tkn { - case token.COMMENT, token.EOF, token.SEMICOLON, token.DEFINE: - break FindTo + to, tkn, _ = s.Scan() + + if from == token.NoPos { + from = to + } + + switch tkn { + case token.EOF: + break FindTo + case token.SEMICOLON: + // If we aren't in nested braces, end of statement means + // end of expression. + if braceDepth == 0 { + break FindTo + } + case token.LBRACE: + braceDepth++ + } + + // This handles the common dangling selector case. For example in + // + // defer fmt. + // y := 1 + // + // we notice the dangling period and end our expression. + // + // If the previous token was a "." and we are looking at a "}", + // the period is likely a dangling selector and needs a phantom + // "_". Likewise if the currnet token is on a different line than + // the period, the period is likely a dangling selector. + if lastToken == token.PERIOD && (tkn == token.RBRACE || tok.Line(to) > tok.Line(last)) { + // Insert phantom "_" selector after the dangling ".". + phantomSelectors = append(phantomSelectors, last+1) + // If we aren't in a block then end the expression after the ".". + if braceDepth == 0 { + to = last + 1 + break + } + } + + lastToken = tkn + last = to + + switch tkn { + case token.RBRACE: + braceDepth-- + if braceDepth <= 0 { + if braceDepth == 0 { + // +1 to include the "}" itself. + to += 1 + } + break FindTo + } } - // to is the end of expression that should become the Fun part of the call. - to = curr } + if !from.IsValid() || tok.Offset(from) >= len(src) { return errors.Errorf("invalid from position") } - if !to.IsValid() || tok.Offset(to)+1 >= len(src) { - return errors.Errorf("invalid to position") + + if !to.IsValid() || tok.Offset(to) >= len(src) { + return errors.Errorf("invalid to position %d", to) } - exprstr := string(src[tok.Offset(from) : tok.Offset(to)+1]) - expr, err := parser.ParseExpr(exprstr) + + // Insert any phantom selectors needed to prevent dangling "." from messing + // up the AST. + exprBytes := make([]byte, 0, int(to-from)+len(phantomSelectors)) + for i, b := range src[tok.Offset(from):tok.Offset(to)] { + if len(phantomSelectors) > 0 && from+token.Pos(i) == phantomSelectors[0] { + exprBytes = append(exprBytes, '_') + phantomSelectors = phantomSelectors[1:] + } + exprBytes = append(exprBytes, b) + } + + if len(phantomSelectors) > 0 { + exprBytes = append(exprBytes, '_') + } + + exprStr := string(exprBytes) + expr, err := parser.ParseExpr(exprStr) if expr == nil { - return errors.Errorf("no expr in %s: %v", exprstr, err) + return errors.Errorf("no expr in %s: %v", exprStr, err) } + // parser.ParseExpr returns undefined positions. // Adjust them for the current file. offsetPositions(expr, from-1) @@ -290,16 +363,33 @@ FindTo: return nil } +var tokenPosType = reflect.TypeOf(token.NoPos) + // offsetPositions applies an offset to the positions in an ast.Node. -// TODO(rstambler): Add more cases here as they become necessary. func offsetPositions(expr ast.Expr, offset token.Pos) { ast.Inspect(expr, func(n ast.Node) bool { - switch n := n.(type) { - case *ast.Ident: - n.NamePos += offset + if n == nil { return false - default: - return true } + + v := reflect.ValueOf(n).Elem() + + switch v.Kind() { + case reflect.Struct: + for i := 0; i < v.NumField(); i++ { + f := v.Field(i) + if f.Type() != tokenPosType { + continue + } + + if !f.CanSet() { + continue + } + + f.SetInt(int64(f.Interface().(token.Pos) + offset)) + } + } + + return true }) } diff --git a/internal/lsp/source/completion_snippet.go b/internal/lsp/source/completion_snippet.go index fd21acac7b..90cf02a33a 100644 --- a/internal/lsp/source/completion_snippet.go +++ b/internal/lsp/source/completion_snippet.go @@ -61,12 +61,15 @@ func (c *completer) functionCallSnippets(name string, params []string) (*snippet if len(c.path) > 1 { switch n := c.path[1].(type) { case *ast.CallExpr: - if n.Fun == c.path[0] { + // The Lparen != Rparen check detects fudged CallExprs we + // inserted when fixing the AST. In this case, we do still need + // to insert the calling "()" parens. + if n.Fun == c.path[0] && n.Lparen != n.Rparen { return nil, nil } case *ast.SelectorExpr: if len(c.path) > 2 { - if call, ok := c.path[2].(*ast.CallExpr); ok && call.Fun == c.path[1] { + if call, ok := c.path[2].(*ast.CallExpr); ok && call.Fun == c.path[1] && call.Lparen != call.Rparen { return nil, nil } } diff --git a/internal/lsp/testdata/badstmt/badstmt.go b/internal/lsp/testdata/badstmt/badstmt.go deleted file mode 100644 index 2ff6e09073..0000000000 --- a/internal/lsp/testdata/badstmt/badstmt.go +++ /dev/null @@ -1,10 +0,0 @@ -package badstmt - -import ( - "golang.org/x/tools/internal/lsp/foo" -) - -func _() { - defer foo.F //@complete("F", Foo, IntFoo, StructFoo),diag(" //", "LSP", "function must be invoked in defer statement") - go foo.F //@complete("F", Foo, IntFoo, StructFoo) -} \ No newline at end of file diff --git a/internal/lsp/testdata/badstmt/badstmt.go.in b/internal/lsp/testdata/badstmt/badstmt.go.in new file mode 100644 index 0000000000..3aae7db679 --- /dev/null +++ b/internal/lsp/testdata/badstmt/badstmt.go.in @@ -0,0 +1,18 @@ +package badstmt + +import ( + "golang.org/x/tools/internal/lsp/foo" +) + +func _() { + defer foo.F //@complete(" //", Foo),diag(" //", "LSP", "function must be invoked in defer statement") + y := 1 + defer foo.F //@complete(" //", Foo) +} + +func _() { + defer func() { + foo.F //@complete(" //", Foo),snippet(" //", Foo, "Foo()", "Foo()") + foo. //@complete(" //", Foo, IntFoo, StructFoo) + } +} diff --git a/internal/lsp/testdata/badstmt/badstmt_2.go.in b/internal/lsp/testdata/badstmt/badstmt_2.go.in new file mode 100644 index 0000000000..294701d963 --- /dev/null +++ b/internal/lsp/testdata/badstmt/badstmt_2.go.in @@ -0,0 +1,9 @@ +package badstmt + +import ( + "golang.org/x/tools/internal/lsp/foo" +) + +func _() { + defer func() { foo. } //@complete(" }", Foo, IntFoo, StructFoo) +} diff --git a/internal/lsp/testdata/badstmt/badstmt_3.go.in b/internal/lsp/testdata/badstmt/badstmt_3.go.in new file mode 100644 index 0000000000..656ad76738 --- /dev/null +++ b/internal/lsp/testdata/badstmt/badstmt_3.go.in @@ -0,0 +1,9 @@ +package badstmt + +import ( + "golang.org/x/tools/internal/lsp/foo" +) + +func _() { + go foo. //@complete(" //", Foo, IntFoo, StructFoo) +} diff --git a/internal/lsp/tests/tests.go b/internal/lsp/tests/tests.go index 29b5b58aa4..ca33dcaadb 100644 --- a/internal/lsp/tests/tests.go +++ b/internal/lsp/tests/tests.go @@ -29,8 +29,8 @@ import ( // We hardcode the expected number of test cases to ensure that all tests // are being executed. If a test is added, this number must be changed. const ( - ExpectedCompletionsCount = 155 - ExpectedCompletionSnippetCount = 15 + ExpectedCompletionsCount = 159 + ExpectedCompletionSnippetCount = 16 ExpectedDiagnosticsCount = 21 ExpectedFormatCount = 6 ExpectedImportCount = 2