diff --git a/internal/lsp/cmd/suggested_fix.go b/internal/lsp/cmd/suggested_fix.go index 5e8b1fa32d..d80e066459 100644 --- a/internal/lsp/cmd/suggested_fix.go +++ b/internal/lsp/cmd/suggested_fix.go @@ -46,8 +46,8 @@ gopls fix flags are: // - if -d is specified, prints out unified diffs of the changes; or // - otherwise, prints the new versions to stdout. func (s *suggestedfix) Run(ctx context.Context, args ...string) error { - if len(args) != 1 { - return tool.CommandLineErrorf("fix expects 1 argument") + if len(args) < 1 { + return tool.CommandLineErrorf("fix expects at least 1 argument") } conn, err := s.app.connect(ctx) if err != nil { @@ -68,12 +68,20 @@ func (s *suggestedfix) Run(ctx context.Context, args ...string) error { conn.Client.filesMu.Lock() defer conn.Client.filesMu.Unlock() + codeActionKinds := []protocol.CodeActionKind{protocol.QuickFix} + if len(args) > 1 { + codeActionKinds = []protocol.CodeActionKind{} + for _, k := range args[1:] { + codeActionKinds = append(codeActionKinds, protocol.CodeActionKind(k)) + } + } + p := protocol.CodeActionParams{ TextDocument: protocol.TextDocumentIdentifier{ URI: protocol.URIFromSpanURI(uri), }, Context: protocol.CodeActionContext{ - Only: []protocol.CodeActionKind{protocol.QuickFix}, + Only: codeActionKinds, Diagnostics: file.diagnostics, }, } @@ -86,9 +94,28 @@ func (s *suggestedfix) Run(ctx context.Context, args ...string) error { if !a.IsPreferred && !s.All { continue } - for _, c := range a.Edit.DocumentChanges { - if fileURI(c.TextDocument.URI) == uri { - edits = append(edits, c.Edits...) + if !from.HasPosition() { + for _, c := range a.Edit.DocumentChanges { + if fileURI(c.TextDocument.URI) == uri { + edits = append(edits, c.Edits...) + } + } + continue + } + // If the span passed in has a position, then we need to find + // the codeaction that has the same range as the passed in span. + for _, diag := range a.Diagnostics { + spn, err := file.mapper.RangeSpan(diag.Range) + if err != nil { + continue + } + if span.ComparePoint(from.Start(), spn.Start()) == 0 { + for _, c := range a.Edit.DocumentChanges { + if fileURI(c.TextDocument.URI) == uri { + edits = append(edits, c.Edits...) + } + } + break } } } diff --git a/internal/lsp/cmd/test/suggested_fix.go b/internal/lsp/cmd/test/suggested_fix.go index 1963fdb97d..0419e33acd 100644 --- a/internal/lsp/cmd/test/suggested_fix.go +++ b/internal/lsp/cmd/test/suggested_fix.go @@ -5,16 +5,19 @@ package cmdtest import ( + "fmt" "testing" "golang.org/x/tools/internal/lsp/tests" "golang.org/x/tools/internal/span" ) -func (r *runner) SuggestedFix(t *testing.T, spn span.Span) { +func (r *runner) SuggestedFix(t *testing.T, spn span.Span, actionKinds []string) { uri := spn.URI() filename := uri.Filename() - got, _ := r.NormalizeGoplsCmd(t, "fix", "-a", filename) + args := []string{"fix", "-a", fmt.Sprintf("%s", spn)} + args = append(args, actionKinds...) + got, _ := r.NormalizeGoplsCmd(t, args...) want := string(r.data.Golden("suggestedfix_"+tests.SpanName(spn), filename, func() ([]byte, error) { return []byte(got), nil })) diff --git a/internal/lsp/lsp_test.go b/internal/lsp/lsp_test.go index dee2c6df75..110b1b7924 100644 --- a/internal/lsp/lsp_test.go +++ b/internal/lsp/lsp_test.go @@ -361,7 +361,7 @@ func (r *runner) Import(t *testing.T, spn span.Span) { } } -func (r *runner) SuggestedFix(t *testing.T, spn span.Span) { +func (r *runner) SuggestedFix(t *testing.T, spn span.Span, actionKinds []string) { uri := spn.URI() view, err := r.server.session.ViewOf(uri) if err != nil { @@ -397,12 +397,16 @@ func (r *runner) SuggestedFix(t *testing.T, spn span.Span) { if diag == nil { t.Fatalf("could not get any suggested fixes for %v", spn) } + codeActionKinds := []protocol.CodeActionKind{} + for _, k := range actionKinds { + codeActionKinds = append(codeActionKinds, protocol.CodeActionKind(k)) + } actions, err := r.server.CodeAction(r.ctx, &protocol.CodeActionParams{ TextDocument: protocol.TextDocumentIdentifier{ URI: protocol.URIFromSpanURI(uri), }, Context: protocol.CodeActionContext{ - Only: []protocol.CodeActionKind{protocol.QuickFix}, + Only: codeActionKinds, Diagnostics: toProtocolDiagnostics([]source.Diagnostic{*diag}), }, }) diff --git a/internal/lsp/source/source_test.go b/internal/lsp/source/source_test.go index 06144afde5..1740ae719e 100644 --- a/internal/lsp/source/source_test.go +++ b/internal/lsp/source/source_test.go @@ -467,7 +467,7 @@ func (r *runner) Import(t *testing.T, spn span.Span) { } } -func (r *runner) SuggestedFix(t *testing.T, spn span.Span) {} +func (r *runner) SuggestedFix(t *testing.T, spn span.Span, actionKinds []string) {} func (r *runner) Definition(t *testing.T, spn span.Span, d tests.Definition) { _, srcRng, err := spanToRange(r.data, d.Src) diff --git a/internal/lsp/testdata/indirect/primarymod/go.mod b/internal/lsp/testdata/indirect/primarymod/go.mod index cfc9f72171..9c24e47707 100644 --- a/internal/lsp/testdata/indirect/primarymod/go.mod +++ b/internal/lsp/testdata/indirect/primarymod/go.mod @@ -1,5 +1,5 @@ module indirect go 1.12 -//@diag("// indirect", "go mod tidy", "example.com/extramodule should be a direct dependency.", "warning"),suggestedfix("// indirect") +//@diag("// indirect", "go mod tidy", "example.com/extramodule should be a direct dependency.", "warning"),suggestedfix("// indirect", "quickfix") require example.com/extramodule v1.0.0 // indirect diff --git a/internal/lsp/testdata/indirect/primarymod/go.mod.golden b/internal/lsp/testdata/indirect/primarymod/go.mod.golden index 5bc28790b1..3e707df79a 100644 --- a/internal/lsp/testdata/indirect/primarymod/go.mod.golden +++ b/internal/lsp/testdata/indirect/primarymod/go.mod.golden @@ -3,6 +3,6 @@ module indirect go 1.12 -//@diag("// indirect", "go mod tidy", "example.com/extramodule should be a direct dependency.", "warning"),suggestedfix("// indirect") +//@diag("// indirect", "go mod tidy", "example.com/extramodule should be a direct dependency.", "warning"),suggestedfix("// indirect", "quickfix") require example.com/extramodule v1.0.0 diff --git a/internal/lsp/testdata/lsp/primarymod/suggestedfix/has_suggested_fix.go b/internal/lsp/testdata/lsp/primarymod/suggestedfix/has_suggested_fix.go index ccd198c9d9..e06dce0a84 100644 --- a/internal/lsp/testdata/lsp/primarymod/suggestedfix/has_suggested_fix.go +++ b/internal/lsp/testdata/lsp/primarymod/suggestedfix/has_suggested_fix.go @@ -6,6 +6,6 @@ import ( func goodbye() { s := "hiiiiiii" - s = s //@suggestedfix("s = s") + s = s //@suggestedfix("s = s", "quickfix") log.Print(s) } diff --git a/internal/lsp/testdata/lsp/primarymod/suggestedfix/has_suggested_fix.go.golden b/internal/lsp/testdata/lsp/primarymod/suggestedfix/has_suggested_fix.go.golden index 4923ecc691..9ccaa19946 100644 --- a/internal/lsp/testdata/lsp/primarymod/suggestedfix/has_suggested_fix.go.golden +++ b/internal/lsp/testdata/lsp/primarymod/suggestedfix/has_suggested_fix.go.golden @@ -7,7 +7,7 @@ import ( func goodbye() { s := "hiiiiiii" - //@suggestedfix("s = s") + //@suggestedfix("s = s", "quickfix") log.Print(s) } diff --git a/internal/lsp/testdata/missingdep/primarymod/main.go b/internal/lsp/testdata/missingdep/primarymod/main.go index b22d1fd350..18be555eef 100644 --- a/internal/lsp/testdata/missingdep/primarymod/main.go +++ b/internal/lsp/testdata/missingdep/primarymod/main.go @@ -2,7 +2,7 @@ package missingdep import ( - "example.com/extramodule/pkg" //@diag("\"example.com/extramodule/pkg\"", "go mod tidy", "example.com/extramodule is not in your go.mod file.", "warning"),suggestedfix("\"example.com/extramodule/pkg\"") + "example.com/extramodule/pkg" //@diag("\"example.com/extramodule/pkg\"", "go mod tidy", "example.com/extramodule is not in your go.mod file.", "warning"),suggestedfix("\"example.com/extramodule/pkg\"", "quickfix") ) func Yo() { diff --git a/internal/lsp/testdata/missingtwodep/primarymod/main.go b/internal/lsp/testdata/missingtwodep/primarymod/main.go index b6cdcdc9e1..081305d806 100644 --- a/internal/lsp/testdata/missingtwodep/primarymod/main.go +++ b/internal/lsp/testdata/missingtwodep/primarymod/main.go @@ -2,9 +2,9 @@ package missingtwodep import ( - "example.com/anothermodule/hey" //@diag("\"example.com/anothermodule/hey\"", "go mod tidy", "example.com/anothermodule is not in your go.mod file.", "warning"),suggestedfix("\"example.com/anothermodule/hey\"") - "example.com/extramodule/pkg" //@diag("\"example.com/extramodule/pkg\"", "go mod tidy", "example.com/extramodule is not in your go.mod file.", "warning"),suggestedfix("\"example.com/extramodule/pkg\"") - "example.com/extramodule/yo" //@diag("\"example.com/extramodule/yo\"", "go mod tidy", "example.com/extramodule is not in your go.mod file.", "warning"),suggestedfix("\"example.com/extramodule/yo\"") + "example.com/anothermodule/hey" //@diag("\"example.com/anothermodule/hey\"", "go mod tidy", "example.com/anothermodule is not in your go.mod file.", "warning"),suggestedfix("\"example.com/anothermodule/hey\"", "quickfix") + "example.com/extramodule/pkg" //@diag("\"example.com/extramodule/pkg\"", "go mod tidy", "example.com/extramodule is not in your go.mod file.", "warning"),suggestedfix("\"example.com/extramodule/pkg\"", "quickfix") + "example.com/extramodule/yo" //@diag("\"example.com/extramodule/yo\"", "go mod tidy", "example.com/extramodule is not in your go.mod file.", "warning"),suggestedfix("\"example.com/extramodule/yo\"", "quickfix") ) func Yo() { diff --git a/internal/lsp/testdata/unused/primarymod/go.mod b/internal/lsp/testdata/unused/primarymod/go.mod index 6ff23f03a5..78035bdb6e 100644 --- a/internal/lsp/testdata/unused/primarymod/go.mod +++ b/internal/lsp/testdata/unused/primarymod/go.mod @@ -2,4 +2,4 @@ module unused go 1.12 -require example.com/extramodule v1.0.0 //@diag("require example.com/extramodule v1.0.0", "go mod tidy", "example.com/extramodule is not used in this module.", "warning"),suggestedfix("require example.com/extramodule v1.0.0") +require example.com/extramodule v1.0.0 //@diag("require example.com/extramodule v1.0.0", "go mod tidy", "example.com/extramodule is not used in this module.", "warning"),suggestedfix("require example.com/extramodule v1.0.0", "quickfix") diff --git a/internal/lsp/tests/tests.go b/internal/lsp/tests/tests.go index eb1a852b5d..36fefe94cb 100644 --- a/internal/lsp/tests/tests.go +++ b/internal/lsp/tests/tests.go @@ -56,7 +56,7 @@ type RankCompletions map[span.Span][]Completion type FoldingRanges []span.Span type Formats []span.Span type Imports []span.Span -type SuggestedFixes []span.Span +type SuggestedFixes map[span.Span][]string type Definitions map[span.Span]Definition type Implementations map[span.Span][]span.Span type Highlights map[span.Span][]span.Span @@ -127,7 +127,7 @@ type Tests interface { FoldingRanges(*testing.T, span.Span) Format(*testing.T, span.Span) Import(*testing.T, span.Span) - SuggestedFix(*testing.T, span.Span) + SuggestedFix(*testing.T, span.Span, []string) Definition(*testing.T, span.Span, Definition) Implementation(*testing.T, span.Span, []span.Span) Highlight(*testing.T, span.Span, []span.Span) @@ -280,6 +280,7 @@ func Load(t testing.TB, exporter packagestest.Exporter, dir string) []*Data { References: make(References), Renames: make(Renames), PrepareRenames: make(PrepareRenames), + SuggestedFixes: make(SuggestedFixes), Symbols: make(Symbols), symbolsChildren: make(SymbolsChildren), symbolInformation: make(SymbolInformation), @@ -588,14 +589,14 @@ func Run(t *testing.T, tests Tests, data *Data) { t.Run("SuggestedFix", func(t *testing.T) { t.Helper() - for _, spn := range data.SuggestedFixes { + for spn, actionKinds := range data.SuggestedFixes { // Check if we should skip this spn if the -modfile flag is not available. if shouldSkip(data, spn.URI()) { continue } t.Run(SpanName(spn), func(t *testing.T) { t.Helper() - tests.SuggestedFix(t, spn) + tests.SuggestedFix(t, spn, actionKinds) }) } }) @@ -1001,8 +1002,11 @@ func (data *Data) collectImports(spn span.Span) { data.Imports = append(data.Imports, spn) } -func (data *Data) collectSuggestedFixes(spn span.Span) { - data.SuggestedFixes = append(data.SuggestedFixes, spn) +func (data *Data) collectSuggestedFixes(spn span.Span, actionKind string) { + if _, ok := data.SuggestedFixes[spn]; !ok { + data.SuggestedFixes[spn] = []string{} + } + data.SuggestedFixes[spn] = append(data.SuggestedFixes[spn], actionKind) } func (data *Data) collectDefinitions(src, target span.Span) {