diff --git a/internal/lsp/source/format.go b/internal/lsp/source/format.go index 5823a6ef0c..09a903fd5b 100644 --- a/internal/lsp/source/format.go +++ b/internal/lsp/source/format.go @@ -192,8 +192,9 @@ func computeFixEdits(view View, ph ParseGoHandle, options *imports.Options, orig return ToProtocolEdits(origMapper, edits) } -// return the prefix of the src through the last imports, or if there are -// no imports, through the package statement (and a subsequent comment group) +// importPrefix returns the prefix of the given file content through the final +// import statement. If there are no imports, the prefix is the package +// statement and any comment groups below it. func importPrefix(src []byte) string { fset := token.NewFileSet() // do as little parsing as possible @@ -201,28 +202,54 @@ func importPrefix(src []byte) string { if err != nil { // This can happen if 'package' is misspelled return "" } - myStart := fset.File(f.Pos()).Base() // 1, but the generality costs little - pkgEnd := int(f.Name.NamePos) + len(f.Name.Name) + tok := fset.File(f.Pos()) var importEnd int for _, d := range f.Decls { if x, ok := d.(*ast.GenDecl); ok && x.Tok == token.IMPORT { - e := int(d.End()) - myStart - if e > importEnd { + if e := tok.Offset(d.End()); e > importEnd { importEnd = e } } } - if importEnd == 0 { - importEnd = pkgEnd - if importEnd > len(src) { - importEnd-- // pkgEnd is off by 1 because Pos is 1-based + + maybeAdjustToLineEnd := func(pos token.Pos, isCommentNode bool) int { + offset := tok.Offset(pos) + + // Don't go past the end of the file. + if offset > len(src) { + offset = len(src) } + // The go/ast package does not account for different line endings, and + // specifically, in the text of a comment, it will strip out \r\n line + // endings in favor of \n. To account for these differences, we try to + // return a position on the next line whenever possible. + switch line := tok.Line(tok.Pos(offset)); { + case line < tok.LineCount(): + nextLineOffset := tok.Offset(tok.LineStart(line + 1)) + // If we found a position that is at the end of a line, move the + // offset to the start of the next line. + if offset+1 == nextLineOffset { + offset = nextLineOffset + } + case isCommentNode, offset+1 == tok.Size(): + // If the last line of the file is a comment, or we are at the end + // of the file, the prefix is the entire file. + offset = len(src) + } + return offset + } + if importEnd == 0 { + pkgEnd := f.Name.End() + importEnd = maybeAdjustToLineEnd(pkgEnd, false) } for _, c := range f.Comments { - if int(c.End()) > importEnd { - importEnd = int(c.End()) + if end := tok.Offset(c.End()); end > importEnd { + importEnd = maybeAdjustToLineEnd(c.End(), true) } } + if importEnd > len(src) { + importEnd = len(src) + } return string(src[:importEnd]) } diff --git a/internal/lsp/source/format_test.go b/internal/lsp/source/format_test.go index f929812bce..50308064a9 100644 --- a/internal/lsp/source/format_test.go +++ b/internal/lsp/source/format_test.go @@ -1,15 +1,17 @@ package source import ( + "fmt" "testing" + + "golang.org/x/tools/internal/lsp/diff" + "golang.org/x/tools/internal/lsp/diff/myers" ) -type data struct { - input, want string -} - func TestImportPrefix(t *testing.T) { - var tdata = []data{ + for i, tt := range []struct { + input, want string + }{ {"package foo", "package foo"}, {"package foo\n", "package foo\n"}, {"package foo\n\nfunc f(){}\n", "package foo\n"}, @@ -19,13 +21,29 @@ func TestImportPrefix(t *testing.T) { {"// hi \n\npackage foo //xx\nfunc _(){}\n", "// hi \n\npackage foo //xx\n"}, {"package foo //hi\n", "package foo //hi\n"}, {"//hi\npackage foo\n//a\n\n//b\n", "//hi\npackage foo\n//a\n\n//b\n"}, - {"package a\n\nimport (\n \"fmt\"\n)\n//hi\n", - "package a\n\nimport (\n \"fmt\"\n)\n//hi\n"}, - } - for i, x := range tdata { - got := importPrefix([]byte(x.input)) - if got != x.want { - t.Errorf("%d: got\n%q, wanted\n%q for %q", i, got, x.want, x.input) + { + "package a\n\nimport (\n \"fmt\"\n)\n//hi\n", + "package a\n\nimport (\n \"fmt\"\n)\n//hi\n", + }, + {`package a /*hi*/`, `package a /*hi*/`}, + {"package main\r\n\r\nimport \"go/types\"\r\n\r\n/*\r\n\r\n */\r\n", "package main\r\n\r\nimport \"go/types\"\r\n\r\n/*\r\n\r\n */\r\n"}, + {"package x; import \"os\"; func f() {}\n\n", "package x; import \"os\""}, + {"package x; func f() {fmt.Println()}\n\n", "package x"}, + } { + got := importPrefix([]byte(tt.input)) + if got != tt.want { + t.Errorf("%d: failed for %q:\n%s", i, tt.input, diffStr(tt.want, got)) } } } + +func diffStr(want, got string) string { + if want == got { + return "" + } + // Add newlines to avoid newline messages in diff. + want += "\n" + got += "\n" + d := myers.ComputeEdits("", want, got) + return fmt.Sprintf("%q", diff.ToUnified("want", "got", want, d)) +}