From 46d1522a5d8300cc04a811920df78ad0485885bc Mon Sep 17 00:00:00 2001 From: Suzy Mueller Date: Tue, 22 Jun 2021 05:14:28 -0400 Subject: [PATCH] internal/lsp: add extract to method code action "Extract method" allows users to take a code fragment and move it to a separate method. This is available if the enclosing function is a method. Change-Id: Ib824f6b79b13ca73532223283a050946c90a47e7 Reviewed-on: https://go-review.googlesource.com/c/tools/+/330070 Trust: Suzy Mueller Run-TryBot: Suzy Mueller gopls-CI: kokoro TryBot-Result: Go Bot Reviewed-by: Rebecca Stambler --- internal/lsp/cmd/test/cmdtest.go | 4 + internal/lsp/code_action.go | 15 +- internal/lsp/lsp_test.go | 60 +- internal/lsp/source/extract.go | 124 ++- internal/lsp/source/fix.go | 2 + internal/lsp/source/source_test.go | 1 + .../extract/extract_function/extract_scope.go | 6 +- .../extract_function/extract_scope.go.golden | 10 +- .../extract/extract_method/extract_basic.go | 24 + .../extract_method/extract_basic.go.golden | 728 ++++++++++++++++++ internal/lsp/testdata/summary.txt.golden | 3 +- internal/lsp/tests/tests.go | 26 + 12 files changed, 964 insertions(+), 39 deletions(-) create mode 100644 internal/lsp/testdata/extract/extract_method/extract_basic.go create mode 100644 internal/lsp/testdata/extract/extract_method/extract_basic.go.golden diff --git a/internal/lsp/cmd/test/cmdtest.go b/internal/lsp/cmd/test/cmdtest.go index b63a92aece..2e9272611d 100644 --- a/internal/lsp/cmd/test/cmdtest.go +++ b/internal/lsp/cmd/test/cmdtest.go @@ -100,6 +100,10 @@ func (r *runner) FunctionExtraction(t *testing.T, start span.Span, end span.Span //TODO: function extraction not supported on command line } +func (r *runner) MethodExtraction(t *testing.T, start span.Span, end span.Span) { + //TODO: function extraction not supported on command line +} + func (r *runner) AddImport(t *testing.T, uri span.URI, expectedImport string) { //TODO: import addition not supported on command line } diff --git a/internal/lsp/code_action.go b/internal/lsp/code_action.go index 1c5ad4d636..b58e954030 100644 --- a/internal/lsp/code_action.go +++ b/internal/lsp/code_action.go @@ -289,8 +289,8 @@ func extractionFixes(ctx context.Context, snapshot source.Snapshot, pkg source.P } puri := protocol.URIFromSpanURI(uri) var commands []protocol.Command - if _, ok, _ := source.CanExtractFunction(snapshot.FileSet(), srng, pgf.Src, pgf.File); ok { - cmd, err := command.NewApplyFixCommand("Extract to function", command.ApplyFixArgs{ + if _, ok, methodOk, _ := source.CanExtractFunction(snapshot.FileSet(), srng, pgf.Src, pgf.File); ok { + cmd, err := command.NewApplyFixCommand("Extract function", command.ApplyFixArgs{ URI: puri, Fix: source.ExtractFunction, Range: rng, @@ -299,6 +299,17 @@ func extractionFixes(ctx context.Context, snapshot source.Snapshot, pkg source.P return nil, err } commands = append(commands, cmd) + if methodOk { + cmd, err := command.NewApplyFixCommand("Extract method", command.ApplyFixArgs{ + URI: puri, + Fix: source.ExtractMethod, + Range: rng, + }) + if err != nil { + return nil, err + } + commands = append(commands, cmd) + } } if _, _, ok, _ := source.CanExtractVariable(srng, pgf.File); ok { cmd, err := command.NewApplyFixCommand("Extract variable", command.ApplyFixArgs{ diff --git a/internal/lsp/lsp_test.go b/internal/lsp/lsp_test.go index 68c83f653d..f095489c79 100644 --- a/internal/lsp/lsp_test.go +++ b/internal/lsp/lsp_test.go @@ -583,7 +583,7 @@ func (r *runner) FunctionExtraction(t *testing.T, start span.Span, end span.Span if err != nil { t.Fatal(err) } - actions, err := r.server.CodeAction(r.ctx, &protocol.CodeActionParams{ + actionsRaw, err := r.server.CodeAction(r.ctx, &protocol.CodeActionParams{ TextDocument: protocol.TextDocumentIdentifier{ URI: protocol.URIFromSpanURI(uri), }, @@ -595,6 +595,12 @@ func (r *runner) FunctionExtraction(t *testing.T, start span.Span, end span.Span if err != nil { t.Fatal(err) } + var actions []protocol.CodeAction + for _, action := range actionsRaw { + if action.Command.Title == "Extract function" { + actions = append(actions, action) + } + } // Hack: We assume that we only get one code action per range. // TODO(rstambler): Support multiple code actions per test. if len(actions) == 0 || len(actions) > 1 { @@ -618,6 +624,58 @@ func (r *runner) FunctionExtraction(t *testing.T, start span.Span, end span.Span } } +func (r *runner) MethodExtraction(t *testing.T, start span.Span, end span.Span) { + uri := start.URI() + m, err := r.data.Mapper(uri) + if err != nil { + t.Fatal(err) + } + spn := span.New(start.URI(), start.Start(), end.End()) + rng, err := m.Range(spn) + if err != nil { + t.Fatal(err) + } + actionsRaw, err := r.server.CodeAction(r.ctx, &protocol.CodeActionParams{ + TextDocument: protocol.TextDocumentIdentifier{ + URI: protocol.URIFromSpanURI(uri), + }, + Range: rng, + Context: protocol.CodeActionContext{ + Only: []protocol.CodeActionKind{"refactor.extract"}, + }, + }) + if err != nil { + t.Fatal(err) + } + var actions []protocol.CodeAction + for _, action := range actionsRaw { + if action.Command.Title == "Extract method" { + actions = append(actions, action) + } + } + // Hack: We assume that we only get one matching code action per range. + // TODO(rstambler): Support multiple code actions per test. + if len(actions) == 0 || len(actions) > 1 { + t.Fatalf("unexpected number of code actions, want 1, got %v", len(actions)) + } + _, err = r.server.ExecuteCommand(r.ctx, &protocol.ExecuteCommandParams{ + Command: actions[0].Command.Command, + Arguments: actions[0].Command.Arguments, + }) + if err != nil { + t.Fatal(err) + } + res := <-r.editRecv + for u, got := range res { + want := string(r.data.Golden("methodextraction_"+tests.SpanName(spn), u.Filename(), func() ([]byte, error) { + return []byte(got), nil + })) + if want != got { + t.Errorf("method extraction failed for %s:\n%s", u.Filename(), tests.Diff(t, want, got)) + } + } +} + func (r *runner) Definition(t *testing.T, spn span.Span, d tests.Definition) { sm, err := r.data.Mapper(d.Src.URI()) if err != nil { diff --git a/internal/lsp/source/extract.go b/internal/lsp/source/extract.go index 6450ba3612..4f0de5938e 100644 --- a/internal/lsp/source/extract.go +++ b/internal/lsp/source/extract.go @@ -139,11 +139,17 @@ func calculateIndentation(content []byte, tok *token.File, insertBeforeStmt ast. // Possible collisions include other function and variable names. Returns the next index to check for prefix. func generateAvailableIdentifier(pos token.Pos, file *ast.File, path []ast.Node, info *types.Info, prefix string, idx int) (string, int) { scopes := CollectScopes(info, path, pos) + return generateIdentifier(idx, prefix, func(name string) bool { + return file.Scope.Lookup(name) != nil || !isValidName(name, scopes) + }) +} + +func generateIdentifier(idx int, prefix string, hasCollision func(string) bool) (string, int) { name := prefix if idx != 0 { name += fmt.Sprintf("%d", idx) } - for file.Scope.Lookup(name) != nil || !isValidName(name, scopes) { + for hasCollision(name) { idx++ name = fmt.Sprintf("%v%d", prefix, idx) } @@ -177,28 +183,42 @@ type returnVariable struct { zeroVal ast.Expr } +// extractMethod refactors the selected block of code into a new method. +func extractMethod(fset *token.FileSet, rng span.Range, src []byte, file *ast.File, pkg *types.Package, info *types.Info) (*analysis.SuggestedFix, error) { + return extractFunctionMethod(fset, rng, src, file, pkg, info, true) +} + // extractFunction refactors the selected block of code into a new function. +func extractFunction(fset *token.FileSet, rng span.Range, src []byte, file *ast.File, pkg *types.Package, info *types.Info) (*analysis.SuggestedFix, error) { + return extractFunctionMethod(fset, rng, src, file, pkg, info, false) +} + +// extractFunctionMethod refactors the selected block of code into a new function/method. // It also replaces the selected block of code with a call to the extracted // function. First, we manually adjust the selection range. We remove trailing // and leading whitespace characters to ensure the range is precisely bounded // by AST nodes. Next, we determine the variables that will be the parameters -// and return values of the extracted function. Lastly, we construct the call -// of the function and insert this call as well as the extracted function into +// and return values of the extracted function/method. Lastly, we construct the call +// of the function/method and insert this call as well as the extracted function/method into // their proper locations. -func extractFunction(fset *token.FileSet, rng span.Range, src []byte, file *ast.File, pkg *types.Package, info *types.Info) (*analysis.SuggestedFix, error) { - p, ok, err := CanExtractFunction(fset, rng, src, file) - if !ok { - return nil, fmt.Errorf("extractFunction: cannot extract %s: %v", +func extractFunctionMethod(fset *token.FileSet, rng span.Range, src []byte, file *ast.File, pkg *types.Package, info *types.Info, isMethod bool) (*analysis.SuggestedFix, error) { + errorPrefix := "extractFunction" + if isMethod { + errorPrefix = "extractMethod" + } + p, ok, methodOk, err := CanExtractFunction(fset, rng, src, file) + if (!ok && !isMethod) || (!methodOk && isMethod) { + return nil, fmt.Errorf("%s: cannot extract %s: %v", errorPrefix, fset.Position(rng.Start), err) } tok, path, rng, outer, start := p.tok, p.path, p.rng, p.outer, p.start fileScope := info.Scopes[file] if fileScope == nil { - return nil, fmt.Errorf("extractFunction: file scope is empty") + return nil, fmt.Errorf("%s: file scope is empty", errorPrefix) } pkgScope := fileScope.Parent() if pkgScope == nil { - return nil, fmt.Errorf("extractFunction: package scope is empty") + return nil, fmt.Errorf("%s: package scope is empty", errorPrefix) } // A return statement is non-nested if its parent node is equal to the parent node @@ -235,6 +255,25 @@ func extractFunction(fset *token.FileSet, rng span.Range, src []byte, file *ast. return nil, err } + var ( + receiverUsed bool + receiver *ast.Field + receiverName string + receiverObj types.Object + ) + if isMethod { + if outer == nil || outer.Recv == nil || len(outer.Recv.List) == 0 { + return nil, fmt.Errorf("%s: cannot extract need method receiver", errorPrefix) + } + receiver = outer.Recv.List[0] + if len(receiver.Names) == 0 || receiver.Names[0] == nil { + return nil, fmt.Errorf("%s: cannot extract need method receiver name", errorPrefix) + } + recvName := receiver.Names[0] + receiverName = recvName.Name + receiverObj = info.ObjectOf(recvName) + } + var ( params, returns []ast.Expr // used when calling the extracted function paramTypes, returnTypes []*ast.Field // used in the signature of the extracted function @@ -308,6 +347,11 @@ func extractFunction(fset *token.FileSet, rng span.Range, src []byte, file *ast. // extracted function. (1) it must be free (isFree), and (2) its first // use within the selection cannot be its own definition (isDefined). if v.free && !v.defined { + // Skip the selector for a method. + if isMethod && v.obj == receiverObj { + receiverUsed = true + continue + } params = append(params, identifier) paramTypes = append(paramTypes, &ast.Field{ Names: []*ast.Ident{identifier}, @@ -471,9 +515,17 @@ func extractFunction(fset *token.FileSet, rng span.Range, src []byte, file *ast. if canDefine { sym = token.DEFINE } - funName, _ := generateAvailableIdentifier(rng.Start, file, path, info, "newFunction", 0) + var name, funName string + if isMethod { + name = "newMethod" + // TODO(suzmue): generate a name that does not conflict for "newMethod". + funName = name + } else { + name = "newFunction" + funName, _ = generateAvailableIdentifier(rng.Start, file, path, info, name, 0) + } extractedFunCall := generateFuncCall(hasNonNestedReturn, hasReturnValues, params, - append(returns, getNames(retVars)...), funName, sym) + append(returns, getNames(retVars)...), funName, sym, receiverName) // Build the extracted function. newFunc := &ast.FuncDecl{ @@ -484,6 +536,18 @@ func extractFunction(fset *token.FileSet, rng span.Range, src []byte, file *ast. }, Body: extractedBlock, } + if isMethod { + var names []*ast.Ident + if receiverUsed { + names = append(names, ast.NewIdent(receiverName)) + } + newFunc.Recv = &ast.FieldList{ + List: []*ast.Field{{ + Names: names, + Type: receiver.Type, + }}, + } + } // Create variable declarations for any identifiers that need to be initialized prior to // calling the extracted function. We do not manually initialize variables if every return @@ -844,24 +908,24 @@ type fnExtractParams struct { // CanExtractFunction reports whether the code in the given range can be // extracted to a function. -func CanExtractFunction(fset *token.FileSet, rng span.Range, src []byte, file *ast.File) (*fnExtractParams, bool, error) { +func CanExtractFunction(fset *token.FileSet, rng span.Range, src []byte, file *ast.File) (*fnExtractParams, bool, bool, error) { if rng.Start == rng.End { - return nil, false, fmt.Errorf("start and end are equal") + return nil, false, false, fmt.Errorf("start and end are equal") } tok := fset.File(file.Pos()) if tok == nil { - return nil, false, fmt.Errorf("no file for pos %v", fset.Position(file.Pos())) + return nil, false, false, fmt.Errorf("no file for pos %v", fset.Position(file.Pos())) } rng = adjustRangeForWhitespace(rng, tok, src) path, _ := astutil.PathEnclosingInterval(file, rng.Start, rng.End) if len(path) == 0 { - return nil, false, fmt.Errorf("no path enclosing interval") + return nil, false, false, fmt.Errorf("no path enclosing interval") } // Node that encloses the selection must be a statement. // TODO: Support function extraction for an expression. _, ok := path[0].(ast.Stmt) if !ok { - return nil, false, fmt.Errorf("node is not a statement") + return nil, false, false, fmt.Errorf("node is not a statement") } // Find the function declaration that encloses the selection. @@ -873,7 +937,7 @@ func CanExtractFunction(fset *token.FileSet, rng span.Range, src []byte, file *a } } if outer == nil { - return nil, false, fmt.Errorf("no enclosing function") + return nil, false, false, fmt.Errorf("no enclosing function") } // Find the nodes at the start and end of the selection. @@ -893,7 +957,7 @@ func CanExtractFunction(fset *token.FileSet, rng span.Range, src []byte, file *a return n.Pos() <= rng.End }) if start == nil || end == nil { - return nil, false, fmt.Errorf("range does not map to AST nodes") + return nil, false, false, fmt.Errorf("range does not map to AST nodes") } return &fnExtractParams{ tok: tok, @@ -901,7 +965,7 @@ func CanExtractFunction(fset *token.FileSet, rng span.Range, src []byte, file *a rng: rng, outer: outer, start: start, - }, true, nil + }, true, outer.Recv != nil, nil } // objUsed checks if the object is used within the range. It returns the first @@ -1089,13 +1153,22 @@ func adjustReturnStatements(returnTypes []*ast.Field, seenVars map[types.Object] // generateFuncCall constructs a call expression for the extracted function, described by the // given parameters and return variables. -func generateFuncCall(hasNonNestedReturn, hasReturnVals bool, params, returns []ast.Expr, name string, token token.Token) ast.Node { +func generateFuncCall(hasNonNestedReturn, hasReturnVals bool, params, returns []ast.Expr, name string, token token.Token, selector string) ast.Node { var replace ast.Node - if hasReturnVals { - callExpr := &ast.CallExpr{ - Fun: ast.NewIdent(name), + callExpr := &ast.CallExpr{ + Fun: ast.NewIdent(name), + Args: params, + } + if selector != "" { + callExpr = &ast.CallExpr{ + Fun: &ast.SelectorExpr{ + X: ast.NewIdent(selector), + Sel: ast.NewIdent(name), + }, Args: params, } + } + if hasReturnVals { if hasNonNestedReturn { // Create a return statement that returns the result of the function call. replace = &ast.ReturnStmt{ @@ -1111,10 +1184,7 @@ func generateFuncCall(hasNonNestedReturn, hasReturnVals bool, params, returns [] } } } else { - replace = &ast.CallExpr{ - Fun: ast.NewIdent(name), - Args: params, - } + replace = callExpr } return replace } diff --git a/internal/lsp/source/fix.go b/internal/lsp/source/fix.go index 6a012396cc..3308aee0c2 100644 --- a/internal/lsp/source/fix.go +++ b/internal/lsp/source/fix.go @@ -32,6 +32,7 @@ const ( UndeclaredName = "undeclared_name" ExtractVariable = "extract_variable" ExtractFunction = "extract_function" + ExtractMethod = "extract_method" ) // suggestedFixes maps a suggested fix command id to its handler. @@ -40,6 +41,7 @@ var suggestedFixes = map[string]SuggestedFixFunc{ UndeclaredName: undeclaredname.SuggestedFix, ExtractVariable: extractVariable, ExtractFunction: extractFunction, + ExtractMethod: extractMethod, } func SuggestedFixFromCommand(cmd protocol.Command, kind protocol.CodeActionKind) SuggestedFix { diff --git a/internal/lsp/source/source_test.go b/internal/lsp/source/source_test.go index c09b2feada..f1ab3ff4c2 100644 --- a/internal/lsp/source/source_test.go +++ b/internal/lsp/source/source_test.go @@ -935,6 +935,7 @@ func (r *runner) Link(t *testing.T, uri span.URI, wantLinks []tests.Link) {} func (r *runner) SuggestedFix(t *testing.T, spn span.Span, actionKinds []string, expectedActions int) { } func (r *runner) FunctionExtraction(t *testing.T, start span.Span, end span.Span) {} +func (r *runner) MethodExtraction(t *testing.T, start span.Span, end span.Span) {} func (r *runner) CodeLens(t *testing.T, uri span.URI, want []protocol.CodeLens) {} func (r *runner) AddImport(t *testing.T, uri span.URI, expectedImport string) {} diff --git a/internal/lsp/testdata/extract/extract_function/extract_scope.go b/internal/lsp/testdata/extract/extract_function/extract_scope.go index 73d74192e2..6cc141fd11 100644 --- a/internal/lsp/testdata/extract/extract_function/extract_scope.go +++ b/internal/lsp/testdata/extract/extract_function/extract_scope.go @@ -1,10 +1,10 @@ package extract func _() { - fn0 := 1 - a := fn0 //@extractfunc("a", "fn0") + newFunction := 1 + a := newFunction //@extractfunc("a", "newFunction") } -func fn1() int { +func newFunction1() int { return 1 } diff --git a/internal/lsp/testdata/extract/extract_function/extract_scope.go.golden b/internal/lsp/testdata/extract/extract_function/extract_scope.go.golden index 1bb4e61fe4..a4803b4fe3 100644 --- a/internal/lsp/testdata/extract/extract_function/extract_scope.go.golden +++ b/internal/lsp/testdata/extract/extract_function/extract_scope.go.golden @@ -2,15 +2,15 @@ package extract func _() { - fn0 := 1 - newFunction(fn0) //@extractfunc("a", "fn0") + newFunction := 1 + newFunction2(newFunction) //@extractfunc("a", "newFunction") } -func newFunction(fn0 int) { - a := fn0 +func newFunction2(newFunction int) { + a := newFunction } -func fn1() int { +func newFunction1() int { return 1 } diff --git a/internal/lsp/testdata/extract/extract_method/extract_basic.go b/internal/lsp/testdata/extract/extract_method/extract_basic.go new file mode 100644 index 0000000000..c9a8d9dce3 --- /dev/null +++ b/internal/lsp/testdata/extract/extract_method/extract_basic.go @@ -0,0 +1,24 @@ +package extract + +type A struct { + x int + y int +} + +func (a *A) XLessThanYP() bool { + return a.x < a.y //@extractmethod("return", "a.y"),extractfunc("return", "a.y") +} + +func (a *A) AddP() int { + sum := a.x + a.y //@extractmethod("sum", "a.y"),extractfunc("sum", "a.y") + return sum //@extractmethod("return", "sum"),extractfunc("return", "sum") +} + +func (a A) XLessThanY() bool { + return a.x < a.y //@extractmethod("return", "a.y"),extractfunc("return", "a.y") +} + +func (a A) Add() int { + sum := a.x + a.y //@extractmethod("sum", "a.y"),extractfunc("sum", "a.y") + return sum //@extractmethod("return", "sum"),extractfunc("return", "sum") +} diff --git a/internal/lsp/testdata/extract/extract_method/extract_basic.go.golden b/internal/lsp/testdata/extract/extract_method/extract_basic.go.golden new file mode 100644 index 0000000000..eab22a673c --- /dev/null +++ b/internal/lsp/testdata/extract/extract_method/extract_basic.go.golden @@ -0,0 +1,728 @@ +-- functionextraction_extract_basic_13_2 -- +package extract + +type A struct { + x int + y int +} + +func (a *A) XLessThanYP() bool { + return a.x < a.y //@extractmethod("return", "a.y"),extractfunc("return", "a.y") +} + +func (a *A) AddP() int { + sum := newFunction(a) //@extractmethod("sum", "a.y"),extractfunc("sum", "a.y") + return sum //@extractmethod("return", "sum"),extractfunc("return", "sum") +} + +func newFunction(a *A) int { + sum := a.x + a.y + return sum +} + +func (a A) XLessThanY() bool { + return a.x < a.y //@extractmethod("return", "a.y"),extractfunc("return", "a.y") +} + +func (a A) Add() int { + sum := a.x + a.y //@extractmethod("sum", "a.y"),extractfunc("sum", "a.y") + return sum //@extractmethod("return", "sum"),extractfunc("return", "sum") +} + +-- functionextraction_extract_basic_14_2 -- +package extract + +type A struct { + x int + y int +} + +func (a *A) XLessThanYP() bool { + return a.x < a.y //@extractmethod("return", "a.y"),extractfunc("return", "a.y") +} + +func (a *A) AddP() int { + sum := a.x + a.y //@extractmethod("sum", "a.y"),extractfunc("sum", "a.y") + return newFunction(sum) //@extractmethod("return", "sum"),extractfunc("return", "sum") +} + +func newFunction(sum int) int { + return sum +} + +func (a A) XLessThanY() bool { + return a.x < a.y //@extractmethod("return", "a.y"),extractfunc("return", "a.y") +} + +func (a A) Add() int { + sum := a.x + a.y //@extractmethod("sum", "a.y"),extractfunc("sum", "a.y") + return sum //@extractmethod("return", "sum"),extractfunc("return", "sum") +} + +-- functionextraction_extract_basic_18_2 -- +package extract + +type A struct { + x int + y int +} + +func (a *A) XLessThanYP() bool { + return a.x < a.y //@extractmethod("return", "a.y"),extractfunc("return", "a.y") +} + +func (a *A) AddP() int { + sum := a.x + a.y //@extractmethod("sum", "a.y"),extractfunc("sum", "a.y") + return sum //@extractmethod("return", "sum"),extractfunc("return", "sum") +} + +func (a A) XLessThanY() bool { + return newFunction(a) //@extractmethod("return", "a.y"),extractfunc("return", "a.y") +} + +func newFunction(a A) bool { + return a.x < a.y +} + +func (a A) Add() int { + sum := a.x + a.y //@extractmethod("sum", "a.y"),extractfunc("sum", "a.y") + return sum //@extractmethod("return", "sum"),extractfunc("return", "sum") +} + +-- functionextraction_extract_basic_22_2 -- +package extract + +type A struct { + x int + y int +} + +func (a *A) XLessThanYP() bool { + return a.x < a.y //@extractmethod("return", "a.y"),extractfunc("return", "a.y") +} + +func (a *A) AddP() int { + sum := a.x + a.y //@extractmethod("sum", "a.y"),extractfunc("sum", "a.y") + return sum //@extractmethod("return", "sum"),extractfunc("return", "sum") +} + +func (a A) XLessThanY() bool { + return a.x < a.y //@extractmethod("return", "a.y"),extractfunc("return", "a.y") +} + +func (a A) Add() int { + sum := newFunction(a) //@extractmethod("sum", "a.y"),extractfunc("sum", "a.y") + return sum //@extractmethod("return", "sum"),extractfunc("return", "sum") +} + +func newFunction(a A) int { + sum := a.x + a.y + return sum +} + +-- functionextraction_extract_basic_23_2 -- +package extract + +type A struct { + x int + y int +} + +func (a *A) XLessThanYP() bool { + return a.x < a.y //@extractmethod("return", "a.y"),extractfunc("return", "a.y") +} + +func (a *A) AddP() int { + sum := a.x + a.y //@extractmethod("sum", "a.y"),extractfunc("sum", "a.y") + return sum //@extractmethod("return", "sum"),extractfunc("return", "sum") +} + +func (a A) XLessThanY() bool { + return a.x < a.y //@extractmethod("return", "a.y"),extractfunc("return", "a.y") +} + +func (a A) Add() int { + sum := a.x + a.y //@extractmethod("sum", "a.y"),extractfunc("sum", "a.y") + return newFunction(sum) //@extractmethod("return", "sum"),extractfunc("return", "sum") +} + +func newFunction(sum int) int { + return sum +} + +-- functionextraction_extract_basic_9_2 -- +package extract + +type A struct { + x int + y int +} + +func (a *A) XLessThanYP() bool { + return newFunction(a) //@extractmethod("return", "a.y"),extractfunc("return", "a.y") +} + +func newFunction(a *A) bool { + return a.x < a.y +} + +func (a *A) AddP() int { + sum := a.x + a.y //@extractmethod("sum", "a.y"),extractfunc("sum", "a.y") + return sum //@extractmethod("return", "sum"),extractfunc("return", "sum") +} + +func (a A) XLessThanY() bool { + return a.x < a.y //@extractmethod("return", "a.y"),extractfunc("return", "a.y") +} + +func (a A) Add() int { + sum := a.x + a.y //@extractmethod("sum", "a.y"),extractfunc("sum", "a.y") + return sum //@extractmethod("return", "sum"),extractfunc("return", "sum") +} + +-- functionextraction_extract_method_13_2 -- +package extract + +type A struct { + x int + y int +} + +func (a *A) XLessThanYP() bool { + return a.x < a.y //@extractmethod("return", "a.y"),extractfunc("return", "a.y") +} + +func (a *A) AddP() int { + sum := newFunction(a) //@extractmethod("sum", "a.y"),extractfunc("sum", "a.y") + return sum //@extractmethod("return", "sum"),extractfunc("return", "sum") +} + +func newFunction(a *A) int { + sum := a.x + a.y + return sum +} + +func (a A) XLessThanY() bool { + return a.x < a.y //@extractmethod("return", "a.y"),extractfunc("return", "a.y") +} + +func (a A) Add() int { + sum := a.x + a.y //@extractmethod("sum", "a.y"),extractfunc("sum", "a.y") + return sum //@extractmethod("return", "sum"),extractfunc("return", "sum") +} + +-- functionextraction_extract_method_14_2 -- +package extract + +type A struct { + x int + y int +} + +func (a *A) XLessThanYP() bool { + return a.x < a.y //@extractmethod("return", "a.y"),extractfunc("return", "a.y") +} + +func (a *A) AddP() int { + sum := a.x + a.y //@extractmethod("sum", "a.y"),extractfunc("sum", "a.y") + return newFunction(sum) //@extractmethod("return", "sum"),extractfunc("return", "sum") +} + +func newFunction(sum int) int { + return sum +} + +func (a A) XLessThanY() bool { + return a.x < a.y //@extractmethod("return", "a.y"),extractfunc("return", "a.y") +} + +func (a A) Add() int { + sum := a.x + a.y //@extractmethod("sum", "a.y"),extractfunc("sum", "a.y") + return sum //@extractmethod("return", "sum"),extractfunc("return", "sum") +} + +-- functionextraction_extract_method_18_2 -- +package extract + +type A struct { + x int + y int +} + +func (a *A) XLessThanYP() bool { + return a.x < a.y //@extractmethod("return", "a.y"),extractfunc("return", "a.y") +} + +func (a *A) AddP() int { + sum := a.x + a.y //@extractmethod("sum", "a.y"),extractfunc("sum", "a.y") + return sum //@extractmethod("return", "sum"),extractfunc("return", "sum") +} + +func (a A) XLessThanY() bool { + return newFunction(a) //@extractmethod("return", "a.y"),extractfunc("return", "a.y") +} + +func newFunction(a A) bool { + return a.x < a.y +} + +func (a A) Add() int { + sum := a.x + a.y //@extractmethod("sum", "a.y"),extractfunc("sum", "a.y") + return sum //@extractmethod("return", "sum"),extractfunc("return", "sum") +} + +-- functionextraction_extract_method_22_2 -- +package extract + +type A struct { + x int + y int +} + +func (a *A) XLessThanYP() bool { + return a.x < a.y //@extractmethod("return", "a.y"),extractfunc("return", "a.y") +} + +func (a *A) AddP() int { + sum := a.x + a.y //@extractmethod("sum", "a.y"),extractfunc("sum", "a.y") + return sum //@extractmethod("return", "sum"),extractfunc("return", "sum") +} + +func (a A) XLessThanY() bool { + return a.x < a.y //@extractmethod("return", "a.y"),extractfunc("return", "a.y") +} + +func (a A) Add() int { + sum := newFunction(a) //@extractmethod("sum", "a.y"),extractfunc("sum", "a.y") + return sum //@extractmethod("return", "sum"),extractfunc("return", "sum") +} + +func newFunction(a A) int { + sum := a.x + a.y + return sum +} + +-- functionextraction_extract_method_23_2 -- +package extract + +type A struct { + x int + y int +} + +func (a *A) XLessThanYP() bool { + return a.x < a.y //@extractmethod("return", "a.y"),extractfunc("return", "a.y") +} + +func (a *A) AddP() int { + sum := a.x + a.y //@extractmethod("sum", "a.y"),extractfunc("sum", "a.y") + return sum //@extractmethod("return", "sum"),extractfunc("return", "sum") +} + +func (a A) XLessThanY() bool { + return a.x < a.y //@extractmethod("return", "a.y"),extractfunc("return", "a.y") +} + +func (a A) Add() int { + sum := a.x + a.y //@extractmethod("sum", "a.y"),extractfunc("sum", "a.y") + return newFunction(sum) //@extractmethod("return", "sum"),extractfunc("return", "sum") +} + +func newFunction(sum int) int { + return sum +} + +-- functionextraction_extract_method_9_2 -- +package extract + +type A struct { + x int + y int +} + +func (a *A) XLessThanYP() bool { + return newFunction(a) //@extractmethod("return", "a.y"),extractfunc("return", "a.y") +} + +func newFunction(a *A) bool { + return a.x < a.y +} + +func (a *A) AddP() int { + sum := a.x + a.y //@extractmethod("sum", "a.y"),extractfunc("sum", "a.y") + return sum //@extractmethod("return", "sum"),extractfunc("return", "sum") +} + +func (a A) XLessThanY() bool { + return a.x < a.y //@extractmethod("return", "a.y"),extractfunc("return", "a.y") +} + +func (a A) Add() int { + sum := a.x + a.y //@extractmethod("sum", "a.y"),extractfunc("sum", "a.y") + return sum //@extractmethod("return", "sum"),extractfunc("return", "sum") +} + +-- methodextraction_extract_basic_13_2 -- +package extract + +type A struct { + x int + y int +} + +func (a *A) XLessThanYP() bool { + return a.x < a.y //@extractmethod("return", "a.y"),extractfunc("return", "a.y") +} + +func (a *A) AddP() int { + sum := a.newMethod() //@extractmethod("sum", "a.y"),extractfunc("sum", "a.y") + return sum //@extractmethod("return", "sum"),extractfunc("return", "sum") +} + +func (a *A) newMethod() int { + sum := a.x + a.y + return sum +} + +func (a A) XLessThanY() bool { + return a.x < a.y //@extractmethod("return", "a.y"),extractfunc("return", "a.y") +} + +func (a A) Add() int { + sum := a.x + a.y //@extractmethod("sum", "a.y"),extractfunc("sum", "a.y") + return sum //@extractmethod("return", "sum"),extractfunc("return", "sum") +} + +-- methodextraction_extract_basic_14_2 -- +package extract + +type A struct { + x int + y int +} + +func (a *A) XLessThanYP() bool { + return a.x < a.y //@extractmethod("return", "a.y"),extractfunc("return", "a.y") +} + +func (a *A) AddP() int { + sum := a.x + a.y //@extractmethod("sum", "a.y"),extractfunc("sum", "a.y") + return a.newMethod(sum) //@extractmethod("return", "sum"),extractfunc("return", "sum") +} + +func (*A) newMethod(sum int) int { + return sum +} + +func (a A) XLessThanY() bool { + return a.x < a.y //@extractmethod("return", "a.y"),extractfunc("return", "a.y") +} + +func (a A) Add() int { + sum := a.x + a.y //@extractmethod("sum", "a.y"),extractfunc("sum", "a.y") + return sum //@extractmethod("return", "sum"),extractfunc("return", "sum") +} + +-- methodextraction_extract_basic_18_2 -- +package extract + +type A struct { + x int + y int +} + +func (a *A) XLessThanYP() bool { + return a.x < a.y //@extractmethod("return", "a.y"),extractfunc("return", "a.y") +} + +func (a *A) AddP() int { + sum := a.x + a.y //@extractmethod("sum", "a.y"),extractfunc("sum", "a.y") + return sum //@extractmethod("return", "sum"),extractfunc("return", "sum") +} + +func (a A) XLessThanY() bool { + return a.newMethod() //@extractmethod("return", "a.y"),extractfunc("return", "a.y") +} + +func (a A) newMethod() bool { + return a.x < a.y +} + +func (a A) Add() int { + sum := a.x + a.y //@extractmethod("sum", "a.y"),extractfunc("sum", "a.y") + return sum //@extractmethod("return", "sum"),extractfunc("return", "sum") +} + +-- methodextraction_extract_basic_22_2 -- +package extract + +type A struct { + x int + y int +} + +func (a *A) XLessThanYP() bool { + return a.x < a.y //@extractmethod("return", "a.y"),extractfunc("return", "a.y") +} + +func (a *A) AddP() int { + sum := a.x + a.y //@extractmethod("sum", "a.y"),extractfunc("sum", "a.y") + return sum //@extractmethod("return", "sum"),extractfunc("return", "sum") +} + +func (a A) XLessThanY() bool { + return a.x < a.y //@extractmethod("return", "a.y"),extractfunc("return", "a.y") +} + +func (a A) Add() int { + sum := a.newMethod() //@extractmethod("sum", "a.y"),extractfunc("sum", "a.y") + return sum //@extractmethod("return", "sum"),extractfunc("return", "sum") +} + +func (a A) newMethod() int { + sum := a.x + a.y + return sum +} + +-- methodextraction_extract_basic_23_2 -- +package extract + +type A struct { + x int + y int +} + +func (a *A) XLessThanYP() bool { + return a.x < a.y //@extractmethod("return", "a.y"),extractfunc("return", "a.y") +} + +func (a *A) AddP() int { + sum := a.x + a.y //@extractmethod("sum", "a.y"),extractfunc("sum", "a.y") + return sum //@extractmethod("return", "sum"),extractfunc("return", "sum") +} + +func (a A) XLessThanY() bool { + return a.x < a.y //@extractmethod("return", "a.y"),extractfunc("return", "a.y") +} + +func (a A) Add() int { + sum := a.x + a.y //@extractmethod("sum", "a.y"),extractfunc("sum", "a.y") + return a.newMethod(sum) //@extractmethod("return", "sum"),extractfunc("return", "sum") +} + +func (A) newMethod(sum int) int { + return sum +} + +-- methodextraction_extract_basic_9_2 -- +package extract + +type A struct { + x int + y int +} + +func (a *A) XLessThanYP() bool { + return a.newMethod() //@extractmethod("return", "a.y"),extractfunc("return", "a.y") +} + +func (a *A) newMethod() bool { + return a.x < a.y +} + +func (a *A) AddP() int { + sum := a.x + a.y //@extractmethod("sum", "a.y"),extractfunc("sum", "a.y") + return sum //@extractmethod("return", "sum"),extractfunc("return", "sum") +} + +func (a A) XLessThanY() bool { + return a.x < a.y //@extractmethod("return", "a.y"),extractfunc("return", "a.y") +} + +func (a A) Add() int { + sum := a.x + a.y //@extractmethod("sum", "a.y"),extractfunc("sum", "a.y") + return sum //@extractmethod("return", "sum"),extractfunc("return", "sum") +} + +-- methodextraction_extract_method_13_2 -- +package extract + +type A struct { + x int + y int +} + +func (a *A) XLessThanYP() bool { + return a.x < a.y //@extractmethod("return", "a.y"),extractfunc("return", "a.y") +} + +func (a *A) AddP() int { + sum := a.newMethod() //@extractmethod("sum", "a.y"),extractfunc("sum", "a.y") + return sum //@extractmethod("return", "sum"),extractfunc("return", "sum") +} + +func (a *A) newMethod() int { + sum := a.x + a.y + return sum +} + +func (a A) XLessThanY() bool { + return a.x < a.y //@extractmethod("return", "a.y"),extractfunc("return", "a.y") +} + +func (a A) Add() int { + sum := a.x + a.y //@extractmethod("sum", "a.y"),extractfunc("sum", "a.y") + return sum //@extractmethod("return", "sum"),extractfunc("return", "sum") +} + +-- methodextraction_extract_method_14_2 -- +package extract + +type A struct { + x int + y int +} + +func (a *A) XLessThanYP() bool { + return a.x < a.y //@extractmethod("return", "a.y"),extractfunc("return", "a.y") +} + +func (a *A) AddP() int { + sum := a.x + a.y //@extractmethod("sum", "a.y"),extractfunc("sum", "a.y") + return a.newMethod(sum) //@extractmethod("return", "sum"),extractfunc("return", "sum") +} + +func (*A) newMethod(sum int) int { + return sum +} + +func (a A) XLessThanY() bool { + return a.x < a.y //@extractmethod("return", "a.y"),extractfunc("return", "a.y") +} + +func (a A) Add() int { + sum := a.x + a.y //@extractmethod("sum", "a.y"),extractfunc("sum", "a.y") + return sum //@extractmethod("return", "sum"),extractfunc("return", "sum") +} + +-- methodextraction_extract_method_18_2 -- +package extract + +type A struct { + x int + y int +} + +func (a *A) XLessThanYP() bool { + return a.x < a.y //@extractmethod("return", "a.y"),extractfunc("return", "a.y") +} + +func (a *A) AddP() int { + sum := a.x + a.y //@extractmethod("sum", "a.y"),extractfunc("sum", "a.y") + return sum //@extractmethod("return", "sum"),extractfunc("return", "sum") +} + +func (a A) XLessThanY() bool { + return a.newMethod() //@extractmethod("return", "a.y"),extractfunc("return", "a.y") +} + +func (a A) newMethod() bool { + return a.x < a.y +} + +func (a A) Add() int { + sum := a.x + a.y //@extractmethod("sum", "a.y"),extractfunc("sum", "a.y") + return sum //@extractmethod("return", "sum"),extractfunc("return", "sum") +} + +-- methodextraction_extract_method_22_2 -- +package extract + +type A struct { + x int + y int +} + +func (a *A) XLessThanYP() bool { + return a.x < a.y //@extractmethod("return", "a.y"),extractfunc("return", "a.y") +} + +func (a *A) AddP() int { + sum := a.x + a.y //@extractmethod("sum", "a.y"),extractfunc("sum", "a.y") + return sum //@extractmethod("return", "sum"),extractfunc("return", "sum") +} + +func (a A) XLessThanY() bool { + return a.x < a.y //@extractmethod("return", "a.y"),extractfunc("return", "a.y") +} + +func (a A) Add() int { + sum := a.newMethod() //@extractmethod("sum", "a.y"),extractfunc("sum", "a.y") + return sum //@extractmethod("return", "sum"),extractfunc("return", "sum") +} + +func (a A) newMethod() int { + sum := a.x + a.y + return sum +} + +-- methodextraction_extract_method_23_2 -- +package extract + +type A struct { + x int + y int +} + +func (a *A) XLessThanYP() bool { + return a.x < a.y //@extractmethod("return", "a.y"),extractfunc("return", "a.y") +} + +func (a *A) AddP() int { + sum := a.x + a.y //@extractmethod("sum", "a.y"),extractfunc("sum", "a.y") + return sum //@extractmethod("return", "sum"),extractfunc("return", "sum") +} + +func (a A) XLessThanY() bool { + return a.x < a.y //@extractmethod("return", "a.y"),extractfunc("return", "a.y") +} + +func (a A) Add() int { + sum := a.x + a.y //@extractmethod("sum", "a.y"),extractfunc("sum", "a.y") + return a.newMethod(sum) //@extractmethod("return", "sum"),extractfunc("return", "sum") +} + +func (A) newMethod(sum int) int { + return sum +} + +-- methodextraction_extract_method_9_2 -- +package extract + +type A struct { + x int + y int +} + +func (a *A) XLessThanYP() bool { + return a.newMethod() //@extractmethod("return", "a.y"),extractfunc("return", "a.y") +} + +func (a *A) newMethod() bool { + return a.x < a.y +} + +func (a *A) AddP() int { + sum := a.x + a.y //@extractmethod("sum", "a.y"),extractfunc("sum", "a.y") + return sum //@extractmethod("return", "sum"),extractfunc("return", "sum") +} + +func (a A) XLessThanY() bool { + return a.x < a.y //@extractmethod("return", "a.y"),extractfunc("return", "a.y") +} + +func (a A) Add() int { + sum := a.x + a.y //@extractmethod("sum", "a.y"),extractfunc("sum", "a.y") + return sum //@extractmethod("return", "sum"),extractfunc("return", "sum") +} + diff --git a/internal/lsp/testdata/summary.txt.golden b/internal/lsp/testdata/summary.txt.golden index edbb4fa568..d0e1bc7979 100644 --- a/internal/lsp/testdata/summary.txt.golden +++ b/internal/lsp/testdata/summary.txt.golden @@ -14,7 +14,8 @@ FormatCount = 6 ImportCount = 8 SemanticTokenCount = 3 SuggestedFixCount = 40 -FunctionExtractionCount = 18 +FunctionExtractionCount = 24 +MethodExtractionCount = 6 DefinitionsCount = 95 TypeDefinitionsCount = 18 HighlightsCount = 69 diff --git a/internal/lsp/tests/tests.go b/internal/lsp/tests/tests.go index f942ced3bc..d5db454b73 100644 --- a/internal/lsp/tests/tests.go +++ b/internal/lsp/tests/tests.go @@ -70,6 +70,7 @@ type Imports []span.Span type SemanticTokens []span.Span type SuggestedFixes map[span.Span][]string type FunctionExtractions map[span.Span]span.Span +type MethodExtractions map[span.Span]span.Span type Definitions map[span.Span]Definition type Implementations map[span.Span][]span.Span type Highlights map[span.Span][]span.Span @@ -104,6 +105,7 @@ type Data struct { SemanticTokens SemanticTokens SuggestedFixes SuggestedFixes FunctionExtractions FunctionExtractions + MethodExtractions MethodExtractions Definitions Definitions Implementations Implementations Highlights Highlights @@ -147,6 +149,7 @@ type Tests interface { SemanticTokens(*testing.T, span.Span) SuggestedFix(*testing.T, span.Span, []string, int) FunctionExtraction(*testing.T, span.Span, span.Span) + MethodExtraction(*testing.T, span.Span, span.Span) Definition(*testing.T, span.Span, Definition) Implementation(*testing.T, span.Span, []span.Span) Highlight(*testing.T, span.Span, []span.Span) @@ -298,6 +301,7 @@ func load(t testing.TB, mode string, dir string) *Data { PrepareRenames: make(PrepareRenames), SuggestedFixes: make(SuggestedFixes), FunctionExtractions: make(FunctionExtractions), + MethodExtractions: make(MethodExtractions), Symbols: make(Symbols), symbolsChildren: make(SymbolsChildren), symbolInformation: make(SymbolInformation), @@ -465,6 +469,7 @@ func load(t testing.TB, mode string, dir string) *Data { "link": datum.collectLinks, "suggestedfix": datum.collectSuggestedFixes, "extractfunc": datum.collectFunctionExtractions, + "extractmethod": datum.collectMethodExtractions, "incomingcalls": datum.collectIncomingCalls, "outgoingcalls": datum.collectOutgoingCalls, "addimport": datum.collectAddImports, @@ -675,6 +680,20 @@ func Run(t *testing.T, tests Tests, data *Data) { } }) + t.Run("MethodExtraction", func(t *testing.T) { + t.Helper() + for start, end := range data.MethodExtractions { + // Check if we should skip this spn if the -modfile flag is not available. + if shouldSkip(data, start.URI()) { + continue + } + t.Run(SpanName(start), func(t *testing.T) { + t.Helper() + tests.MethodExtraction(t, start, end) + }) + } + }) + t.Run("Definition", func(t *testing.T) { t.Helper() for spn, d := range data.Definitions { @@ -895,6 +914,7 @@ func checkData(t *testing.T, data *Data) { fmt.Fprintf(buf, "SemanticTokenCount = %v\n", len(data.SemanticTokens)) fmt.Fprintf(buf, "SuggestedFixCount = %v\n", len(data.SuggestedFixes)) fmt.Fprintf(buf, "FunctionExtractionCount = %v\n", len(data.FunctionExtractions)) + fmt.Fprintf(buf, "MethodExtractionCount = %v\n", len(data.MethodExtractions)) fmt.Fprintf(buf, "DefinitionsCount = %v\n", definitionCount) fmt.Fprintf(buf, "TypeDefinitionsCount = %v\n", typeDefinitionCount) fmt.Fprintf(buf, "HighlightsCount = %v\n", len(data.Highlights)) @@ -1128,6 +1148,12 @@ func (data *Data) collectFunctionExtractions(start span.Span, end span.Span) { } } +func (data *Data) collectMethodExtractions(start span.Span, end span.Span) { + if _, ok := data.MethodExtractions[start]; !ok { + data.MethodExtractions[start] = end + } +} + func (data *Data) collectDefinitions(src, target span.Span) { data.Definitions[src] = Definition{ Src: src,