From 72051f796149b570577cd7e666ea968b19e0745d Mon Sep 17 00:00:00 2001 From: Heschi Kreinick Date: Thu, 16 Jul 2020 17:37:12 -0400 Subject: [PATCH] internal/lsp: pass snapshot/view to memoize.Functions Due to the runtime's inability to collect cycles involving finalizers, we can't close over handles in memoize.Functions without causing memory leaks. Up until now we've dealt with that by closing over all the bits of the snapshot that we want, but it distorts the design of all the code used in the Functions. We can solve the problem another way: instead of closing over the snapshot/view, we can force the caller to pass it in. This is somewhat scary: there is no requirement that the argument matches the data that we're working with. But the reality is that this is not a new problem: the Function used to calculate a cache value is not necessarily the one that the caller expects. As long as the cache key fully identifies all the inputs to the Function, the output should be correct. And since the caller used the snapshot/view to calculate that cache key, it should always be safe to pass in that snapshot/view. If it's not, then we already had a bug. The Arg type in memoize is clumsy, but I thought it would be nice to have at least a little bit of type safety. I'm open to suggestions. Change-Id: I23f546638b0c66a4698620a986949087211f4762 Reviewed-on: https://go-review.googlesource.com/c/tools/+/244019 Reviewed-by: Robert Findley Reviewed-by: Rebecca Stambler --- internal/lsp/cache/analysis.go | 28 ++++----- internal/lsp/cache/check.go | 31 +++++----- internal/lsp/cache/errors.go | 11 ++-- internal/lsp/cache/mod.go | 76 +++++++++++------------- internal/lsp/cache/mod_tidy.go | 30 +++++----- internal/lsp/cache/parse.go | 49 ++++++++------- internal/lsp/cache/snapshot.go | 7 ++- internal/lsp/cache/view.go | 19 +++--- internal/lsp/code_action.go | 6 +- internal/lsp/diagnostics.go | 2 +- internal/lsp/link.go | 10 ++-- internal/lsp/mod/code_lens.go | 4 +- internal/lsp/mod/diagnostics.go | 4 +- internal/lsp/mod/format.go | 2 +- internal/lsp/mod/hover.go | 4 +- internal/lsp/source/code_lens.go | 4 +- internal/lsp/source/completion_format.go | 2 +- internal/lsp/source/diagnostics.go | 2 +- internal/lsp/source/folding_range.go | 2 +- internal/lsp/source/format.go | 6 +- internal/lsp/source/highlight.go | 2 +- internal/lsp/source/identifier.go | 2 +- internal/lsp/source/implementation.go | 4 +- internal/lsp/source/references.go | 2 +- internal/lsp/source/types_format.go | 6 +- internal/lsp/source/util.go | 4 +- internal/lsp/source/view.go | 18 +++--- internal/lsp/source/workspace_symbol.go | 2 +- internal/memoize/memoize.go | 14 +++-- internal/memoize/memoize_test.go | 8 +-- 30 files changed, 183 insertions(+), 178 deletions(-) diff --git a/internal/lsp/cache/analysis.go b/internal/lsp/cache/analysis.go index ba5fb39e6c..058509b1d6 100644 --- a/internal/lsp/cache/analysis.go +++ b/internal/lsp/cache/analysis.go @@ -7,7 +7,6 @@ package cache import ( "context" "fmt" - "go/token" "go/types" "reflect" "sort" @@ -41,7 +40,7 @@ func (s *snapshot) Analyze(ctx context.Context, id string, analyzers ...*analysi var results []*source.Error for _, ah := range roots { - diagnostics, _, err := ah.analyze(ctx) + diagnostics, _, err := ah.analyze(ctx, s) if err != nil { return nil, err } @@ -93,7 +92,7 @@ func (s *snapshot) actionHandle(ctx context.Context, id packageID, a *analysis.A if len(ph.key) == 0 { return nil, errors.Errorf("no key for PackageHandle %s", id) } - pkg, err := ph.check(ctx) + pkg, err := ph.check(ctx, s) if err != nil { return nil, err } @@ -133,17 +132,16 @@ func (s *snapshot) actionHandle(ctx context.Context, id packageID, a *analysis.A } } - fset := s.view.session.cache.fset - - h := s.view.session.cache.store.Bind(buildActionKey(a, ph), func(ctx context.Context) interface{} { + h := s.view.session.cache.store.Bind(buildActionKey(a, ph), func(ctx context.Context, arg memoize.Arg) interface{} { + snapshot := arg.(*snapshot) // Analyze dependencies first. - results, err := execAll(ctx, deps) + results, err := execAll(ctx, snapshot, deps) if err != nil { return &actionData{ err: err, } } - return runAnalysis(ctx, fset, a, pkg, results) + return runAnalysis(ctx, snapshot, a, pkg, results) }) act.handle = h @@ -151,8 +149,8 @@ func (s *snapshot) actionHandle(ctx context.Context, id packageID, a *analysis.A return act, nil } -func (act *actionHandle) analyze(ctx context.Context) ([]*source.Error, interface{}, error) { - v, err := act.handle.Get(ctx) +func (act *actionHandle) analyze(ctx context.Context, snapshot *snapshot) ([]*source.Error, interface{}, error) { + v, err := act.handle.Get(ctx, snapshot) if v == nil { return nil, nil, err } @@ -174,7 +172,7 @@ func (act *actionHandle) String() string { return fmt.Sprintf("%s@%s", act.analyzer, act.pkg.PkgPath()) } -func execAll(ctx context.Context, actions []*actionHandle) (map[*actionHandle]*actionData, error) { +func execAll(ctx context.Context, snapshot *snapshot, actions []*actionHandle) (map[*actionHandle]*actionData, error) { var mu sync.Mutex results := make(map[*actionHandle]*actionData) @@ -182,7 +180,7 @@ func execAll(ctx context.Context, actions []*actionHandle) (map[*actionHandle]*a for _, act := range actions { act := act g.Go(func() error { - v, err := act.handle.Get(ctx) + v, err := act.handle.Get(ctx, snapshot) if err != nil { return err } @@ -201,7 +199,7 @@ func execAll(ctx context.Context, actions []*actionHandle) (map[*actionHandle]*a return results, g.Wait() } -func runAnalysis(ctx context.Context, fset *token.FileSet, analyzer *analysis.Analyzer, pkg *pkg, deps map[*actionHandle]*actionData) (data *actionData) { +func runAnalysis(ctx context.Context, snapshot *snapshot, analyzer *analysis.Analyzer, pkg *pkg, deps map[*actionHandle]*actionData) (data *actionData) { data = &actionData{ objectFacts: make(map[objectFactKey]analysis.Fact), packageFacts: make(map[packageFactKey]analysis.Fact), @@ -251,7 +249,7 @@ func runAnalysis(ctx context.Context, fset *token.FileSet, analyzer *analysis.An // Run the analysis. pass := &analysis.Pass{ Analyzer: analyzer, - Fset: fset, + Fset: snapshot.view.session.cache.fset, Files: pkg.GetSyntax(), Pkg: pkg.GetTypes(), TypesInfo: pkg.GetTypesInfo(), @@ -343,7 +341,7 @@ func runAnalysis(ctx context.Context, fset *token.FileSet, analyzer *analysis.An } for _, diag := range diagnostics { - srcErr, err := sourceError(ctx, fset, pkg, diag) + srcErr, err := sourceError(ctx, snapshot, pkg, diag) if err != nil { event.Error(ctx, "unable to compute analysis error position", err, tag.Category.Of(diag.Category), tag.Package.Of(pkg.ID())) continue diff --git a/internal/lsp/cache/check.go b/internal/lsp/cache/check.go index 5e4a0520a9..2010a47260 100644 --- a/internal/lsp/cache/check.go +++ b/internal/lsp/cache/check.go @@ -9,7 +9,6 @@ import ( "context" "fmt" "go/ast" - "go/token" "go/types" "path" "sort" @@ -83,18 +82,19 @@ func (s *snapshot) buildPackageHandle(ctx context.Context, id packageID, mode so goFiles := ph.goFiles compiledGoFiles := ph.compiledGoFiles key := ph.key - fset := s.view.session.cache.fset - h := s.view.session.cache.store.Bind(key, func(ctx context.Context) interface{} { + h := s.view.session.cache.store.Bind(key, func(ctx context.Context, arg memoize.Arg) interface{} { + snapshot := arg.(*snapshot) + // Begin loading the direct dependencies, in parallel. for _, dep := range deps { go func(dep *packageHandle) { - dep.check(ctx) + dep.check(ctx, snapshot) }(dep) } data := &packageData{} - data.pkg, data.err = typeCheck(ctx, fset, m, mode, goFiles, compiledGoFiles, deps) + data.pkg, data.err = typeCheck(ctx, snapshot, m, mode, goFiles, compiledGoFiles, deps) return data }) @@ -193,12 +193,12 @@ func hashConfig(config *packages.Config) string { return hashContents(b.Bytes()) } -func (ph *packageHandle) Check(ctx context.Context) (source.Package, error) { - return ph.check(ctx) +func (ph *packageHandle) Check(ctx context.Context, s source.Snapshot) (source.Package, error) { + return ph.check(ctx, s.(*snapshot)) } -func (ph *packageHandle) check(ctx context.Context) (*pkg, error) { - v, err := ph.handle.Get(ctx) +func (ph *packageHandle) check(ctx context.Context, s *snapshot) (*pkg, error) { + v, err := ph.handle.Get(ctx, s) if err != nil { return nil, err } @@ -239,7 +239,7 @@ func (s *snapshot) parseGoHandles(ctx context.Context, files []span.URI, mode so return pghs, nil } -func typeCheck(ctx context.Context, fset *token.FileSet, m *metadata, mode source.ParseMode, goFiles, compiledGoFiles []*parseGoHandle, deps map[packagePath]*packageHandle) (*pkg, error) { +func typeCheck(ctx context.Context, snapshot *snapshot, m *metadata, mode source.ParseMode, goFiles, compiledGoFiles []*parseGoHandle, deps map[packagePath]*packageHandle) (*pkg, error) { ctx, done := event.Start(ctx, "cache.importer.typeCheck", tag.Package.Of(string(m.id))) defer done() @@ -248,6 +248,7 @@ func typeCheck(ctx context.Context, fset *token.FileSet, m *metadata, mode sourc rawErrors = append(rawErrors, err) } + fset := snapshot.view.session.cache.fset pkg := &pkg{ m: m, mode: mode, @@ -278,7 +279,7 @@ func typeCheck(ctx context.Context, fset *token.FileSet, m *metadata, mode sourc wg.Add(1) go func(i int, ph *parseGoHandle) { defer wg.Done() - data, err := ph.parse(ctx) + data, err := ph.parse(ctx, snapshot.view) if err != nil { actualErrors[i] = err return @@ -294,7 +295,7 @@ func typeCheck(ctx context.Context, fset *token.FileSet, m *metadata, mode sourc wg.Add(1) // We need to parse the non-compiled go files, but we don't care about their errors. go func(ph source.ParseGoHandle) { - ph.Parse(ctx) + ph.Parse(ctx, snapshot.view) wg.Done() }(ph) } @@ -325,7 +326,7 @@ func typeCheck(ctx context.Context, fset *token.FileSet, m *metadata, mode sourc // Try to attach errors messages to the file as much as possible. var found bool for _, e := range rawErrors { - srcErr, err := sourceError(ctx, fset, pkg, e) + srcErr, err := sourceError(ctx, snapshot, pkg, e) if err != nil { continue } @@ -361,7 +362,7 @@ func typeCheck(ctx context.Context, fset *token.FileSet, m *metadata, mode sourc if !isValidImport(m.pkgPath, dep.m.pkgPath) { return nil, errors.Errorf("invalid use of internal package %s", pkgPath) } - depPkg, err := dep.check(ctx) + depPkg, err := dep.check(ctx, snapshot) if err != nil { return nil, err } @@ -386,7 +387,7 @@ func typeCheck(ctx context.Context, fset *token.FileSet, m *metadata, mode sourc // We don't care about a package's errors unless we have parsed it in full. if mode == source.ParseFull { for _, e := range rawErrors { - srcErr, err := sourceError(ctx, fset, pkg, e) + srcErr, err := sourceError(ctx, snapshot, pkg, e) if err != nil { event.Error(ctx, "unable to compute error positions", err, tag.Package.Of(pkg.ID())) continue diff --git a/internal/lsp/cache/errors.go b/internal/lsp/cache/errors.go index a2d034fd0c..89da166523 100644 --- a/internal/lsp/cache/errors.go +++ b/internal/lsp/cache/errors.go @@ -25,7 +25,8 @@ import ( errors "golang.org/x/xerrors" ) -func sourceError(ctx context.Context, fset *token.FileSet, pkg *pkg, e interface{}) (*source.Error, error) { +func sourceError(ctx context.Context, snapshot *snapshot, pkg *pkg, e interface{}) (*source.Error, error) { + fset := snapshot.view.session.cache.fset var ( spn span.Span err error @@ -38,7 +39,7 @@ func sourceError(ctx context.Context, fset *token.FileSet, pkg *pkg, e interface case packages.Error: kind = toSourceErrorKind(e.Kind) var ok bool - if msg, spn, ok = parseGoListImportCycleError(ctx, fset, e, pkg); ok { + if msg, spn, ok = parseGoListImportCycleError(ctx, snapshot, e, pkg); ok { kind = source.TypeError break } @@ -254,7 +255,7 @@ func parseGoListError(input string) span.Span { return span.Parse(input[:msgIndex]) } -func parseGoListImportCycleError(ctx context.Context, fset *token.FileSet, e packages.Error, pkg *pkg) (string, span.Span, bool) { +func parseGoListImportCycleError(ctx context.Context, snapshot *snapshot, e packages.Error, pkg *pkg) (string, span.Span, bool) { re := regexp.MustCompile(`(.*): import stack: \[(.+)\]`) matches := re.FindStringSubmatch(strings.TrimSpace(e.Msg)) if len(matches) < 3 { @@ -270,14 +271,14 @@ func parseGoListImportCycleError(ctx context.Context, fset *token.FileSet, e pac // Imports have quotation marks around them. circImp := strconv.Quote(importList[1]) for _, ph := range pkg.compiledGoFiles { - fh, _, _, _, err := ph.Parse(ctx) + fh, _, _, _, err := ph.Parse(ctx, snapshot.view) if err != nil { continue } // Search file imports for the import that is causing the import cycle. for _, imp := range fh.Imports { if imp.Path.Value == circImp { - spn, err := span.NewRange(fset, imp.Pos(), imp.End()).Span() + spn, err := span.NewRange(snapshot.view.session.cache.fset, imp.Pos(), imp.End()).Span() if err != nil { return msg, span.Span{}, false } diff --git a/internal/lsp/cache/mod.go b/internal/lsp/cache/mod.go index ef32f2ce32..9476efaf47 100644 --- a/internal/lsp/cache/mod.go +++ b/internal/lsp/cache/mod.go @@ -54,8 +54,8 @@ func (mh *parseModHandle) Sum() source.FileHandle { return mh.sum } -func (mh *parseModHandle) Parse(ctx context.Context) (*modfile.File, *protocol.ColumnMapper, []source.Error, error) { - v, err := mh.handle.Get(ctx) +func (mh *parseModHandle) Parse(ctx context.Context, s source.Snapshot) (*modfile.File, *protocol.ColumnMapper, []source.Error, error) { + v, err := mh.handle.Get(ctx, s.(*snapshot)) if err != nil { return nil, nil, nil, err } @@ -67,7 +67,7 @@ func (s *snapshot) ParseModHandle(ctx context.Context, modFH source.FileHandle) if handle := s.getModHandle(modFH.URI()); handle != nil { return handle, nil } - h := s.view.session.cache.store.Bind(modFH.Identity().String(), func(ctx context.Context) interface{} { + h := s.view.session.cache.store.Bind(modFH.Identity().String(), func(ctx context.Context, _ memoize.Arg) interface{} { _, done := event.Start(ctx, "cache.ParseModHandle", tag.URI.Of(modFH.URI())) defer done() @@ -187,8 +187,6 @@ const ( type modWhyHandle struct { handle *memoize.Handle - - pmh source.ParseModHandle } type modWhyData struct { @@ -199,8 +197,8 @@ type modWhyData struct { err error } -func (mwh *modWhyHandle) Why(ctx context.Context) (map[string]string, error) { - v, err := mwh.handle.Get(ctx) +func (mwh *modWhyHandle) Why(ctx context.Context, s source.Snapshot) (map[string]string, error) { + v, err := mwh.handle.Get(ctx, s.(*snapshot)) if err != nil { return nil, err } @@ -216,26 +214,26 @@ func (s *snapshot) ModWhyHandle(ctx context.Context) (source.ModWhyHandle, error if err != nil { return nil, err } - pmh, err := s.ParseModHandle(ctx, fh) - if err != nil { - return nil, err - } - var ( - cfg = s.config(ctx) - tmpMod = s.view.tmpMod - ) + cfg := s.config(ctx) key := modKey{ sessionID: s.view.session.id, - cfg: hashConfig(cfg), - mod: pmh.Mod().Identity().String(), + cfg: hashConfig(s.config(ctx)), + mod: fh.Identity().String(), view: s.view.root.Filename(), verb: why, } - h := s.view.session.cache.store.Bind(key, func(ctx context.Context) interface{} { - ctx, done := event.Start(ctx, "cache.ModHandle", tag.URI.Of(pmh.Mod().URI())) + h := s.view.session.cache.store.Bind(key, func(ctx context.Context, arg memoize.Arg) interface{} { + ctx, done := event.Start(ctx, "cache.ModWhyHandle", tag.URI.Of(fh.URI())) defer done() - parsed, _, _, err := pmh.Parse(ctx) + snapshot := arg.(*snapshot) + + pmh, err := snapshot.ParseModHandle(ctx, fh) + if err != nil { + return &modWhyData{err: err} + } + + parsed, _, _, err := pmh.Parse(ctx, snapshot) if err != nil { return &modWhyData{err: err} } @@ -248,7 +246,7 @@ func (s *snapshot) ModWhyHandle(ctx context.Context) (source.ModWhyHandle, error for _, req := range parsed.Require { args = append(args, req.Mod.Path) } - _, stdout, err := runGoCommand(ctx, cfg, pmh, tmpMod, "mod", args) + _, stdout, err := runGoCommand(ctx, cfg, pmh, snapshot.view.tmpMod, "mod", args) if err != nil { return &modWhyData{err: err} } @@ -268,15 +266,12 @@ func (s *snapshot) ModWhyHandle(ctx context.Context) (source.ModWhyHandle, error defer s.mu.Unlock() s.modWhyHandle = &modWhyHandle{ handle: h, - pmh: pmh, } return s.modWhyHandle, nil } type modUpgradeHandle struct { handle *memoize.Handle - - pmh source.ParseModHandle } type modUpgradeData struct { @@ -286,8 +281,8 @@ type modUpgradeData struct { err error } -func (muh *modUpgradeHandle) Upgrades(ctx context.Context) (map[string]string, error) { - v, err := muh.handle.Get(ctx) +func (muh *modUpgradeHandle) Upgrades(ctx context.Context, s source.Snapshot) (map[string]string, error) { + v, err := muh.handle.Get(ctx, s.(*snapshot)) if v == nil { return nil, err } @@ -303,26 +298,26 @@ func (s *snapshot) ModUpgradeHandle(ctx context.Context) (source.ModUpgradeHandl if err != nil { return nil, err } - pmh, err := s.ParseModHandle(ctx, fh) - if err != nil { - return nil, err - } - var ( - cfg = s.config(ctx) - tmpMod = s.view.tmpMod - ) + cfg := s.config(ctx) key := modKey{ sessionID: s.view.session.id, cfg: hashConfig(cfg), - mod: pmh.Mod().Identity().String(), + mod: fh.Identity().String(), view: s.view.root.Filename(), verb: upgrade, } - h := s.view.session.cache.store.Bind(key, func(ctx context.Context) interface{} { - ctx, done := event.Start(ctx, "cache.ModUpgradeHandle", tag.URI.Of(pmh.Mod().URI())) + h := s.view.session.cache.store.Bind(key, func(ctx context.Context, arg memoize.Arg) interface{} { + ctx, done := event.Start(ctx, "cache.ModUpgradeHandle", tag.URI.Of(fh.URI())) defer done() - parsed, _, _, err := pmh.Parse(ctx) + snapshot := arg.(*snapshot) + + pmh, err := s.ParseModHandle(ctx, fh) + if err != nil { + return &modUpgradeData{err: err} + } + + parsed, _, _, err := pmh.Parse(ctx, snapshot) if err != nil { return &modUpgradeData{err: err} } @@ -333,12 +328,12 @@ func (s *snapshot) ModUpgradeHandle(ctx context.Context) (source.ModUpgradeHandl // Run "go list -mod readonly -u -m all" to be able to see which deps can be // upgraded without modifying mod file. args := []string{"-u", "-m", "all"} - if !tmpMod || containsVendor(pmh.Mod().URI()) { + if !snapshot.view.tmpMod || containsVendor(pmh.Mod().URI()) { // Use -mod=readonly if the module contains a vendor directory // (see golang/go#38711). args = append([]string{"-mod", "readonly"}, args...) } - _, stdout, err := runGoCommand(ctx, cfg, pmh, tmpMod, "list", args) + _, stdout, err := runGoCommand(ctx, cfg, pmh, snapshot.view.tmpMod, "list", args) if err != nil { return &modUpgradeData{err: err} } @@ -370,7 +365,6 @@ func (s *snapshot) ModUpgradeHandle(ctx context.Context) (source.ModUpgradeHandl defer s.mu.Unlock() s.modUpgradeHandle = &modUpgradeHandle{ handle: h, - pmh: pmh, } return s.modUpgradeHandle, nil } diff --git a/internal/lsp/cache/mod_tidy.go b/internal/lsp/cache/mod_tidy.go index 566babbb8d..853757f9c7 100644 --- a/internal/lsp/cache/mod_tidy.go +++ b/internal/lsp/cache/mod_tidy.go @@ -8,7 +8,6 @@ import ( "context" "fmt" "go/ast" - "go/token" "io/ioutil" "sort" "strconv" @@ -56,8 +55,8 @@ func (mth *modTidyHandle) ParseModHandle() source.ParseModHandle { return mth.pmh } -func (mth *modTidyHandle) Tidy(ctx context.Context) ([]source.Error, error) { - v, err := mth.handle.Get(ctx) +func (mth *modTidyHandle) Tidy(ctx context.Context, s source.Snapshot) ([]source.Error, error) { + v, err := mth.handle.Get(ctx, s.(*snapshot)) if err != nil { return nil, err } @@ -65,8 +64,8 @@ func (mth *modTidyHandle) Tidy(ctx context.Context) ([]source.Error, error) { return data.diagnostics, data.err } -func (mth *modTidyHandle) TidiedContent(ctx context.Context) ([]byte, error) { - v, err := mth.handle.Get(ctx) +func (mth *modTidyHandle) TidiedContent(ctx context.Context, s source.Snapshot) ([]byte, error) { + v, err := mth.handle.Get(ctx, s.(*snapshot)) if err != nil { return nil, err } @@ -98,7 +97,7 @@ func (s *snapshot) ModTidyHandle(ctx context.Context) (source.ModTidyHandle, err } var workspacePkgs []source.Package for _, ph := range wsPhs { - pkg, err := ph.Check(ctx) + pkg, err := ph.Check(ctx, s) if err != nil { return nil, err } @@ -117,7 +116,6 @@ func (s *snapshot) ModTidyHandle(ctx context.Context) (source.ModTidyHandle, err modURI = s.view.modURI cfg = s.config(ctx) options = s.view.Options() - fset = s.view.session.cache.fset ) key := modTidyKey{ sessionID: s.view.session.id, @@ -127,11 +125,13 @@ func (s *snapshot) ModTidyHandle(ctx context.Context) (source.ModTidyHandle, err gomod: pmh.Mod().Identity().String(), cfg: hashConfig(cfg), } - h := s.view.session.cache.store.Bind(key, func(ctx context.Context) interface{} { + h := s.view.session.cache.store.Bind(key, func(ctx context.Context, arg memoize.Arg) interface{} { ctx, done := event.Start(ctx, "cache.ModTidyHandle", tag.URI.Of(modURI)) defer done() - original, m, parseErrors, err := pmh.Parse(ctx) + snapshot := arg.(*snapshot) + + original, m, parseErrors, err := pmh.Parse(ctx, snapshot) if err != nil || len(parseErrors) > 0 { return &modTidyData{ diagnostics: parseErrors, @@ -191,7 +191,7 @@ func (s *snapshot) ModTidyHandle(ctx context.Context) (source.ModTidyHandle, err // go.mod file. The fixes will be for the go.mod file, but the // diagnostics should appear on the import statements in the Go or // go.mod files. - missingModuleErrs, err := missingModuleErrors(ctx, fset, m, workspacePkgs, ideal.Require, missingDeps, original, options) + missingModuleErrs, err := missingModuleErrors(ctx, snapshot, m, workspacePkgs, ideal.Require, missingDeps, original, options) if err != nil { return &modTidyData{err: err} } @@ -430,7 +430,7 @@ func rangeFromPositions(m *protocol.ColumnMapper, s, e modfile.Position) (protoc // missingModuleErrors returns diagnostics for each file in each workspace // package that has dependencies that are not reflected in the go.mod file. -func missingModuleErrors(ctx context.Context, fset *token.FileSet, modMapper *protocol.ColumnMapper, pkgs []source.Package, modules []*modfile.Require, missingMods map[string]*modfile.Require, original *modfile.File, options source.Options) ([]source.Error, error) { +func missingModuleErrors(ctx context.Context, snapshot *snapshot, modMapper *protocol.ColumnMapper, pkgs []source.Package, modules []*modfile.Require, missingMods map[string]*modfile.Require, original *modfile.File, options source.Options) ([]source.Error, error) { var moduleErrs []source.Error matchedMissingMods := make(map[*modfile.Require]struct{}) for _, pkg := range pkgs { @@ -462,7 +462,7 @@ func missingModuleErrors(ctx context.Context, fset *token.FileSet, modMapper *pr } } if len(missingPkgs) > 0 { - errs, err := missingModules(ctx, fset, modMapper, pkg, missingPkgs, options) + errs, err := missingModules(ctx, snapshot, modMapper, pkg, missingPkgs, options) if err != nil { return nil, err } @@ -499,10 +499,10 @@ func missingModuleErrors(ctx context.Context, fset *token.FileSet, modMapper *pr return moduleErrs, nil } -func missingModules(ctx context.Context, fset *token.FileSet, modMapper *protocol.ColumnMapper, pkg source.Package, missing map[string]*modfile.Require, options source.Options) ([]source.Error, error) { +func missingModules(ctx context.Context, snapshot *snapshot, modMapper *protocol.ColumnMapper, pkg source.Package, missing map[string]*modfile.Require, options source.Options) ([]source.Error, error) { var errors []source.Error for _, pgh := range pkg.CompiledGoFiles() { - file, _, m, _, err := pgh.Parse(ctx) + file, _, m, _, err := pgh.Parse(ctx, snapshot.view) if err != nil { return nil, err } @@ -526,7 +526,7 @@ func missingModules(ctx context.Context, fset *token.FileSet, modMapper *protoco if !ok { continue } - spn, err := span.NewRange(fset, imp.Path.Pos(), imp.Path.End()).Span() + spn, err := span.NewRange(snapshot.view.session.cache.fset, imp.Path.Pos(), imp.Path.End()).Span() if err != nil { return nil, err } diff --git a/internal/lsp/cache/parse.go b/internal/lsp/cache/parse.go index d5165ba56c..04a78bdfef 100644 --- a/internal/lsp/cache/parse.go +++ b/internal/lsp/cache/parse.go @@ -65,18 +65,21 @@ func (c *Cache) parseGoHandle(ctx context.Context, fh source.FileHandle, mode so file: fh.Identity(), mode: mode, } - fset := c.fset - h := c.store.Bind(key, func(ctx context.Context) interface{} { - return parseGo(ctx, fset, fh, mode) + parseHandle := c.store.Bind(key, func(ctx context.Context, arg memoize.Arg) interface{} { + view := arg.(*View) + return parseGo(ctx, view.session.cache.fset, fh, mode) + }) + + astHandle := c.store.Bind(astCacheKey(key), func(ctx context.Context, arg memoize.Arg) interface{} { + view := arg.(*View) + return buildASTCache(ctx, view, parseHandle) }) return &parseGoHandle{ - handle: h, - file: fh, - mode: mode, - astCacheHandle: c.store.Bind(astCacheKey(key), func(ctx context.Context) interface{} { - return buildASTCache(ctx, h) - }), + handle: parseHandle, + file: fh, + mode: mode, + astCacheHandle: astHandle, } } @@ -92,20 +95,20 @@ func (pgh *parseGoHandle) Mode() source.ParseMode { return pgh.mode } -func (pgh *parseGoHandle) Parse(ctx context.Context) (*ast.File, []byte, *protocol.ColumnMapper, error, error) { - data, err := pgh.parse(ctx) +func (pgh *parseGoHandle) Parse(ctx context.Context, v source.View) (*ast.File, []byte, *protocol.ColumnMapper, error, error) { + data, err := pgh.parse(ctx, v.(*View)) if err != nil { return nil, nil, nil, nil, err } return data.ast, data.src, data.mapper, data.parseError, data.err } -func (pgh *parseGoHandle) parse(ctx context.Context) (*parseGoData, error) { - v, err := pgh.handle.Get(ctx) +func (pgh *parseGoHandle) parse(ctx context.Context, v *View) (*parseGoData, error) { + d, err := pgh.handle.Get(ctx, v) if err != nil { return nil, err } - data, ok := v.(*parseGoData) + data, ok := d.(*parseGoData) if !ok { return nil, errors.Errorf("no parsed file for %s", pgh.File().URI()) } @@ -129,13 +132,13 @@ func (pgh *parseGoHandle) cached() (*parseGoData, error) { return data, nil } -func (pgh *parseGoHandle) PosToDecl(ctx context.Context) (map[token.Pos]ast.Decl, error) { - v, err := pgh.astCacheHandle.Get(ctx) +func (pgh *parseGoHandle) PosToDecl(ctx context.Context, v source.View) (map[token.Pos]ast.Decl, error) { + d, err := pgh.astCacheHandle.Get(ctx, v.(*View)) if err != nil || v == nil { return nil, err } - data := v.(*astCacheData) + data := d.(*astCacheData) if data.err != nil { return nil, data.err } @@ -143,13 +146,13 @@ func (pgh *parseGoHandle) PosToDecl(ctx context.Context) (map[token.Pos]ast.Decl return data.posToDecl, nil } -func (pgh *parseGoHandle) PosToField(ctx context.Context) (map[token.Pos]*ast.Field, error) { - v, err := pgh.astCacheHandle.Get(ctx) - if err != nil || v == nil { +func (pgh *parseGoHandle) PosToField(ctx context.Context, v source.View) (map[token.Pos]*ast.Field, error) { + d, err := pgh.astCacheHandle.Get(ctx, v.(*View)) + if err != nil || d == nil { return nil, err } - data := v.(*astCacheData) + data := d.(*astCacheData) if data.err != nil { return nil, data.err } @@ -168,7 +171,7 @@ type astCacheData struct { // buildASTCache builds caches to aid in quickly going from the typed // world to the syntactic world. -func buildASTCache(ctx context.Context, parseHandle *memoize.Handle) *astCacheData { +func buildASTCache(ctx context.Context, view *View, parseHandle *memoize.Handle) *astCacheData { var ( // path contains all ancestors, including n. path []ast.Node @@ -176,7 +179,7 @@ func buildASTCache(ctx context.Context, parseHandle *memoize.Handle) *astCacheDa decls []ast.Decl ) - v, err := parseHandle.Get(ctx) + v, err := parseHandle.Get(ctx, view) if err != nil || v == nil || v.(*parseGoData).ast == nil { return &astCacheData{err: err} } diff --git a/internal/lsp/cache/snapshot.go b/internal/lsp/cache/snapshot.go index 28e02e69f4..4ca9277583 100644 --- a/internal/lsp/cache/snapshot.go +++ b/internal/lsp/cache/snapshot.go @@ -24,6 +24,7 @@ import ( "golang.org/x/tools/internal/gocommand" "golang.org/x/tools/internal/lsp/debug/tag" "golang.org/x/tools/internal/lsp/source" + "golang.org/x/tools/internal/memoize" "golang.org/x/tools/internal/packagesinternal" "golang.org/x/tools/internal/span" "golang.org/x/tools/internal/typesinternal" @@ -31,6 +32,8 @@ import ( ) type snapshot struct { + memoize.Arg // allow as a memoize.Function arg + id uint64 view *View @@ -1004,8 +1007,8 @@ func (s *snapshot) shouldInvalidateMetadata(ctx context.Context, originalFH, cur return originalFH.URI() == s.view.modURI } // Get the original and current parsed files in order to check package name and imports. - original, _, _, _, originalErr := s.view.session.cache.ParseGoHandle(ctx, originalFH, source.ParseHeader).Parse(ctx) - current, _, _, _, currentErr := s.view.session.cache.ParseGoHandle(ctx, currentFH, source.ParseHeader).Parse(ctx) + original, _, _, _, originalErr := s.view.session.cache.ParseGoHandle(ctx, originalFH, source.ParseHeader).Parse(ctx, s.view) + current, _, _, _, currentErr := s.view.session.cache.ParseGoHandle(ctx, currentFH, source.ParseHeader).Parse(ctx, s.view) if originalErr != nil || currentErr != nil { return (originalErr == nil) != (currentErr == nil) } diff --git a/internal/lsp/cache/view.go b/internal/lsp/cache/view.go index 105b8aeaa8..cd1b763b0e 100644 --- a/internal/lsp/cache/view.go +++ b/internal/lsp/cache/view.go @@ -33,6 +33,8 @@ import ( ) type View struct { + memoize.Arg // allow as a memoize.Function arg + session *Session id string @@ -143,7 +145,6 @@ type View struct { type builtinPackageHandle struct { handle *memoize.Handle - file source.ParseGoHandle } type builtinPackageData struct { @@ -300,7 +301,7 @@ func (v *View) BuiltinPackage(ctx context.Context) (source.BuiltinPackage, error if v.builtin == nil { return nil, errors.Errorf("no builtin package for view %s", v.name) } - data, err := v.builtin.handle.Get(ctx) + data, err := v.builtin.handle.Get(ctx, v) if err != nil { return nil, err } @@ -332,14 +333,15 @@ func (v *View) buildBuiltinPackage(ctx context.Context, goFiles []string) error if err != nil { return err } - pgh := v.session.cache.parseGoHandle(ctx, fh, source.ParseFull) - fset := v.session.cache.fset - h := v.session.cache.store.Bind(fh.Identity(), func(ctx context.Context) interface{} { - file, _, _, _, err := pgh.Parse(ctx) + h := v.session.cache.store.Bind(fh.Identity(), func(ctx context.Context, arg memoize.Arg) interface{} { + view := arg.(*View) + + pgh := view.session.cache.parseGoHandle(ctx, fh, source.ParseFull) + file, _, _, _, err := pgh.Parse(ctx, view) if err != nil { return &builtinPackageData{err: err} } - pkg, err := ast.NewPackage(fset, map[string]*ast.File{ + pkg, err := ast.NewPackage(view.session.cache.fset, map[string]*ast.File{ pgh.File().URI().Filename(): file, }, nil, nil) if err != nil { @@ -352,7 +354,6 @@ func (v *View) buildBuiltinPackage(ctx context.Context, goFiles []string) error }) v.builtin = &builtinPackageHandle{ handle: h, - file: pgh, } return nil } @@ -565,7 +566,7 @@ func (v *View) WorkspaceDirectories(ctx context.Context) ([]string, error) { if err != nil { return nil, err } - parsed, _, _, err := pmh.Parse(ctx) + parsed, _, _, err := pmh.Parse(ctx, v.Snapshot()) if err != nil { return nil, err } diff --git a/internal/lsp/code_action.go b/internal/lsp/code_action.go index 090567fb46..29c99ed3af 100644 --- a/internal/lsp/code_action.go +++ b/internal/lsp/code_action.go @@ -461,7 +461,7 @@ func moduleQuickFixes(ctx context.Context, snapshot source.Snapshot, diagnostics if err != nil { return nil, err } - errors, err := mth.Tidy(ctx) + errors, err := mth.Tidy(ctx, snapshot) if err != nil { return nil, err } @@ -515,7 +515,7 @@ func goModTidy(ctx context.Context, snapshot source.Snapshot) (*protocol.CodeAct return nil, err } uri := mth.ParseModHandle().Mod().URI() - _, m, _, err := mth.ParseModHandle().Parse(ctx) + _, m, _, err := mth.ParseModHandle().Parse(ctx, snapshot) if err != nil { return nil, err } @@ -523,7 +523,7 @@ func goModTidy(ctx context.Context, snapshot source.Snapshot) (*protocol.CodeAct if err != nil { return nil, err } - right, err := mth.TidiedContent(ctx) + right, err := mth.TidiedContent(ctx, snapshot) if err != nil { return nil, err } diff --git a/internal/lsp/diagnostics.go b/internal/lsp/diagnostics.go index d2907cb558..f5ae5f004e 100644 --- a/internal/lsp/diagnostics.go +++ b/internal/lsp/diagnostics.go @@ -115,7 +115,7 @@ If you believe this is a mistake, please file an issue: https://github.com/golan go func(ph source.PackageHandle) { defer wg.Done() - pkg, err := ph.Check(ctx) + pkg, err := ph.Check(ctx, snapshot) if err != nil { event.Error(ctx, "warning: diagnose package", err, tag.Snapshot.Of(snapshot.ID()), tag.Package.Of(ph.ID())) return diff --git a/internal/lsp/link.go b/internal/lsp/link.go index 135d4574dd..63de726c76 100644 --- a/internal/lsp/link.go +++ b/internal/lsp/link.go @@ -50,7 +50,7 @@ func modLinks(ctx context.Context, snapshot source.Snapshot, fh source.FileHandl if err != nil { return nil, err } - file, m, _, err := pmh.Parse(ctx) + file, m, _, err := pmh.Parse(ctx, snapshot) if err != nil { return nil, err } @@ -109,7 +109,7 @@ func goLinks(ctx context.Context, view source.View, fh source.FileHandle) ([]pro return nil, err } pgh := view.Session().Cache().ParseGoHandle(ctx, fh, source.ParseFull) - file, _, m, _, err := pgh.Parse(ctx) + file, _, m, _, err := pgh.Parse(ctx, view) if err != nil { return nil, err } @@ -142,7 +142,7 @@ func goLinks(ctx context.Context, view source.View, fh source.FileHandle) ([]pro if view.IsGoPrivatePath(target) { continue } - if mod, version, ok := moduleAtVersion(ctx, target, ph); ok && strings.ToLower(view.Options().LinkTarget) == "pkg.go.dev" { + if mod, version, ok := moduleAtVersion(ctx, view.Snapshot(), target, ph); ok && strings.ToLower(view.Options().LinkTarget) == "pkg.go.dev" { target = strings.Replace(target, mod, mod+"@"+version, 1) } // Account for the quotation marks in the positions. @@ -175,8 +175,8 @@ func goLinks(ctx context.Context, view source.View, fh source.FileHandle) ([]pro return links, nil } -func moduleAtVersion(ctx context.Context, target string, ph source.PackageHandle) (string, string, bool) { - pkg, err := ph.Check(ctx) +func moduleAtVersion(ctx context.Context, snapshot source.Snapshot, target string, ph source.PackageHandle) (string, string, bool) { + pkg, err := ph.Check(ctx, snapshot) if err != nil { return "", "", false } diff --git a/internal/lsp/mod/code_lens.go b/internal/lsp/mod/code_lens.go index 96879445ed..5f0a0b4305 100644 --- a/internal/lsp/mod/code_lens.go +++ b/internal/lsp/mod/code_lens.go @@ -32,7 +32,7 @@ func CodeLens(ctx context.Context, snapshot source.Snapshot, uri span.URI) ([]pr if err != nil { return nil, err } - file, m, _, err := pmh.Parse(ctx) + file, m, _, err := pmh.Parse(ctx, snapshot) if err != nil { return nil, err } @@ -40,7 +40,7 @@ func CodeLens(ctx context.Context, snapshot source.Snapshot, uri span.URI) ([]pr if err != nil { return nil, err } - upgrades, err := muh.Upgrades(ctx) + upgrades, err := muh.Upgrades(ctx, snapshot) if err != nil { return nil, err } diff --git a/internal/lsp/mod/diagnostics.go b/internal/lsp/mod/diagnostics.go index b70b64a4c1..878294dfa9 100644 --- a/internal/lsp/mod/diagnostics.go +++ b/internal/lsp/mod/diagnostics.go @@ -44,7 +44,7 @@ func Diagnostics(ctx context.Context, snapshot source.Snapshot) (map[source.File if err != nil { return nil, err } - diagnostics, err := mth.Tidy(ctx) + diagnostics, err := mth.Tidy(ctx, snapshot) if err != nil { return nil, err } @@ -102,7 +102,7 @@ func ExtractGoCommandError(ctx context.Context, snapshot source.Snapshot, fh sou if err != nil { return nil, err } - parsed, m, _, err := pmh.Parse(ctx) + parsed, m, _, err := pmh.Parse(ctx, snapshot) if err != nil { return nil, err } diff --git a/internal/lsp/mod/format.go b/internal/lsp/mod/format.go index fdb52e483d..5d73f04bd8 100644 --- a/internal/lsp/mod/format.go +++ b/internal/lsp/mod/format.go @@ -16,7 +16,7 @@ func Format(ctx context.Context, snapshot source.Snapshot, fh source.FileHandle) if err != nil { return nil, err } - file, m, _, err := pmh.Parse(ctx) + file, m, _, err := pmh.Parse(ctx, snapshot) if err != nil { return nil, err } diff --git a/internal/lsp/mod/hover.go b/internal/lsp/mod/hover.go index 98b9ceef09..b66471496f 100644 --- a/internal/lsp/mod/hover.go +++ b/internal/lsp/mod/hover.go @@ -30,7 +30,7 @@ func Hover(ctx context.Context, snapshot source.Snapshot, fh source.FileHandle, if err != nil { return nil, fmt.Errorf("getting modfile handle: %w", err) } - file, m, _, err := pmh.Parse(ctx) + file, m, _, err := pmh.Parse(ctx, snapshot) if err != nil { return nil, err } @@ -72,7 +72,7 @@ func Hover(ctx context.Context, snapshot source.Snapshot, fh source.FileHandle, if err != nil { return nil, err } - why, err := mwh.Why(ctx) + why, err := mwh.Why(ctx, snapshot) if err != nil { return nil, fmt.Errorf("running go mod why: %w", err) } diff --git a/internal/lsp/source/code_lens.go b/internal/lsp/source/code_lens.go index fbf881d041..b38aecff22 100644 --- a/internal/lsp/source/code_lens.go +++ b/internal/lsp/source/code_lens.go @@ -144,7 +144,7 @@ func matchTestFunc(fn *ast.FuncDecl, pkg Package, nameRe *regexp.Regexp, paramID func goGenerateCodeLens(ctx context.Context, snapshot Snapshot, fh FileHandle) ([]protocol.CodeLens, error) { pgh := snapshot.View().Session().Cache().ParseGoHandle(ctx, fh, ParseFull) - file, _, m, _, err := pgh.Parse(ctx) + file, _, m, _, err := pgh.Parse(ctx, snapshot.View()) if err != nil { return nil, err } @@ -194,7 +194,7 @@ func goGenerateCodeLens(ctx context.Context, snapshot Snapshot, fh FileHandle) ( func regenerateCgoLens(ctx context.Context, snapshot Snapshot, fh FileHandle) ([]protocol.CodeLens, error) { pgh := snapshot.View().Session().Cache().ParseGoHandle(ctx, fh, ParseFull) - file, _, m, _, err := pgh.Parse(ctx) + file, _, m, _, err := pgh.Parse(ctx, snapshot.View()) if err != nil { return nil, err } diff --git a/internal/lsp/source/completion_format.go b/internal/lsp/source/completion_format.go index 664f8fff09..399ef9d1d1 100644 --- a/internal/lsp/source/completion_format.go +++ b/internal/lsp/source/completion_format.go @@ -193,7 +193,7 @@ func (c *completer) item(ctx context.Context, cand candidate) (CompletionItem, e return item, nil } - posToDecl, err := ph.PosToDecl(ctx) + posToDecl, err := ph.PosToDecl(ctx, c.snapshot.View()) if err != nil { return CompletionItem{}, err } diff --git a/internal/lsp/source/diagnostics.go b/internal/lsp/source/diagnostics.go index b38c626dd7..661dd66e9f 100644 --- a/internal/lsp/source/diagnostics.go +++ b/internal/lsp/source/diagnostics.go @@ -137,7 +137,7 @@ func FileDiagnostics(ctx context.Context, snapshot Snapshot, uri span.URI) (File if err != nil { return FileIdentity{}, nil, err } - pkg, err := ph.Check(ctx) + pkg, err := ph.Check(ctx, snapshot) if err != nil { return FileIdentity{}, nil, err } diff --git a/internal/lsp/source/folding_range.go b/internal/lsp/source/folding_range.go index 74ad3e657d..fbe3fcd735 100644 --- a/internal/lsp/source/folding_range.go +++ b/internal/lsp/source/folding_range.go @@ -19,7 +19,7 @@ func FoldingRange(ctx context.Context, snapshot Snapshot, fh FileHandle, lineFol // TODO(suzmue): consider limiting the number of folding ranges returned, and // implement a way to prioritize folding ranges in that case. pgh := snapshot.View().Session().Cache().ParseGoHandle(ctx, fh, ParseFull) - file, _, m, _, err := pgh.Parse(ctx) + file, _, m, _, err := pgh.Parse(ctx, snapshot.View()) if err != nil { return nil, err } diff --git a/internal/lsp/source/format.go b/internal/lsp/source/format.go index 09a903fd5b..04042cd89f 100644 --- a/internal/lsp/source/format.go +++ b/internal/lsp/source/format.go @@ -27,7 +27,7 @@ func Format(ctx context.Context, snapshot Snapshot, fh FileHandle) ([]protocol.T defer done() pgh := snapshot.View().Session().Cache().ParseGoHandle(ctx, fh, ParseFull) - file, _, m, parseErrors, err := pgh.Parse(ctx) + file, _, m, parseErrors, err := pgh.Parse(ctx, snapshot.View()) if err != nil { return nil, err } @@ -110,7 +110,7 @@ func computeImportEdits(ctx context.Context, view View, ph ParseGoHandle, option if err != nil { return nil, nil, err } - _, _, origMapper, _, err := ph.Parse(ctx) + _, _, origMapper, _, err := ph.Parse(ctx, view) if err != nil { return nil, nil, err } @@ -145,7 +145,7 @@ func computeOneImportFixEdits(ctx context.Context, view View, ph ParseGoHandle, if err != nil { return nil, err } - _, _, origMapper, _, err := ph.Parse(ctx) // ph.Parse returns values never used + _, _, origMapper, _, err := ph.Parse(ctx, view) // ph.Parse returns values never used if err != nil { return nil, err } diff --git a/internal/lsp/source/highlight.go b/internal/lsp/source/highlight.go index fe9121754b..ac6cc36217 100644 --- a/internal/lsp/source/highlight.go +++ b/internal/lsp/source/highlight.go @@ -26,7 +26,7 @@ func Highlight(ctx context.Context, snapshot Snapshot, fh FileHandle, pos protoc if err != nil { return nil, fmt.Errorf("getting file for Highlight: %w", err) } - file, _, m, _, err := pgh.Parse(ctx) + file, _, m, _, err := pgh.Parse(ctx, snapshot.View()) if err != nil { return nil, err } diff --git a/internal/lsp/source/identifier.go b/internal/lsp/source/identifier.go index 5b7e43ac54..585c250395 100644 --- a/internal/lsp/source/identifier.go +++ b/internal/lsp/source/identifier.go @@ -293,7 +293,7 @@ func objToDecl(ctx context.Context, v View, srcPkg Package, obj types.Object) (a if err != nil { return nil, err } - posToDecl, err := ph.PosToDecl(ctx) + posToDecl, err := ph.PosToDecl(ctx, v) if err != nil { return nil, err } diff --git a/internal/lsp/source/implementation.go b/internal/lsp/source/implementation.go index 3bf876b925..38ec43b4b1 100644 --- a/internal/lsp/source/implementation.go +++ b/internal/lsp/source/implementation.go @@ -95,7 +95,7 @@ func implementations(ctx context.Context, s Snapshot, f FileHandle, pp protocol. return nil, err } for _, ph := range knownPkgs { - pkg, err := ph.Check(ctx) + pkg, err := ph.Check(ctx, s) if err != nil { return nil, err } @@ -212,7 +212,7 @@ func qualifiedObjsAtProtocolPos(ctx context.Context, s Snapshot, fh FileHandle, // Check all the packages that the file belongs to. var qualifiedObjs []qualifiedObject for _, ph := range phs { - searchpkg, err := ph.Check(ctx) + searchpkg, err := ph.Check(ctx, s) if err != nil { return nil, err } diff --git a/internal/lsp/source/references.go b/internal/lsp/source/references.go index 74e2fe4d6d..3ba71dd855 100644 --- a/internal/lsp/source/references.go +++ b/internal/lsp/source/references.go @@ -77,7 +77,7 @@ func references(ctx context.Context, s Snapshot, qos []qualifiedObject, includeD return nil, err } for _, ph := range reverseDeps { - pkg, err := ph.Check(ctx) + pkg, err := ph.Check(ctx, s) if err != nil { return nil, err } diff --git a/internal/lsp/source/types_format.go b/internal/lsp/source/types_format.go index 6c96b802e2..f0e09885b2 100644 --- a/internal/lsp/source/types_format.go +++ b/internal/lsp/source/types_format.go @@ -204,7 +204,7 @@ func formatVarType(ctx context.Context, s Snapshot, srcpkg Package, srcfile *ast return types.TypeString(obj.Type(), qf) } - expr, err := varType(ctx, ph, obj) + expr, err := varType(ctx, s, ph, obj) if err != nil { return types.TypeString(obj.Type(), qf) } @@ -224,8 +224,8 @@ func formatVarType(ctx context.Context, s Snapshot, srcpkg Package, srcfile *ast } // varType returns the type expression for a *types.Var. -func varType(ctx context.Context, ph ParseGoHandle, obj *types.Var) (ast.Expr, error) { - posToField, err := ph.PosToField(ctx) +func varType(ctx context.Context, snapshot Snapshot, ph ParseGoHandle, obj *types.Var) (ast.Expr, error) { + posToField, err := ph.PosToField(ctx, snapshot.View()) if err != nil { return nil, err } diff --git a/internal/lsp/source/util.go b/internal/lsp/source/util.go index 3fce0a4ec2..5a0ec62b02 100644 --- a/internal/lsp/source/util.go +++ b/internal/lsp/source/util.go @@ -77,7 +77,7 @@ func getParsedFile(ctx context.Context, snapshot Snapshot, fh FileHandle, select if err != nil { return nil, nil, err } - pkg, err := ph.Check(ctx) + pkg, err := ph.Check(ctx, snapshot) if err != nil { return nil, nil, err } @@ -148,7 +148,7 @@ func IsGenerated(ctx context.Context, snapshot Snapshot, uri span.URI) bool { return false } ph := snapshot.View().Session().Cache().ParseGoHandle(ctx, fh, ParseHeader) - parsed, _, _, _, err := ph.Parse(ctx) + parsed, _, _, _, err := ph.Parse(ctx, snapshot.View()) if err != nil { return false } diff --git a/internal/lsp/source/view.go b/internal/lsp/source/view.go index 11ba8b6013..63b122a49e 100644 --- a/internal/lsp/source/view.go +++ b/internal/lsp/source/view.go @@ -108,7 +108,7 @@ type PackageHandle interface { CompiledGoFiles() []span.URI // Check returns the type-checked Package for the PackageHandle. - Check(ctx context.Context) (Package, error) + Check(ctx context.Context, snapshot Snapshot) (Package, error) // Cached returns the Package for the PackageHandle if it has already been stored. Cached() (Package, error) @@ -320,7 +320,7 @@ type ParseGoHandle interface { // Parse returns the parsed AST for the file. // If the file is not available, returns nil and an error. - Parse(ctx context.Context) (file *ast.File, src []byte, m *protocol.ColumnMapper, parseErr error, err error) + Parse(ctx context.Context, view View) (file *ast.File, src []byte, m *protocol.ColumnMapper, parseErr error, err error) // Cached returns the AST for this handle, if it has already been stored. Cached() (file *ast.File, src []byte, m *protocol.ColumnMapper, parseErr error, err error) @@ -329,12 +329,12 @@ type ParseGoHandle interface { // to quickly find corresponding *ast.Field node given a *types.Var. // We must refer to the AST to render type aliases properly when // formatting signatures and other types. - PosToField(context.Context) (map[token.Pos]*ast.Field, error) + PosToField(ctx context.Context, view View) (map[token.Pos]*ast.Field, error) // PosToDecl maps certain objects' positions to their surrounding // ast.Decl. This mapping is used when building the documentation // string for the objects. - PosToDecl(context.Context) (map[token.Pos]ast.Decl, error) + PosToDecl(ctx context.Context, view View) (map[token.Pos]ast.Decl, error) } type ParseModHandle interface { @@ -346,19 +346,19 @@ type ParseModHandle interface { // Parse returns the parsed go.mod file, a column mapper, and a list of // parse for the go.mod file. - Parse(ctx context.Context) (*modfile.File, *protocol.ColumnMapper, []Error, error) + Parse(ctx context.Context, snapshot Snapshot) (*modfile.File, *protocol.ColumnMapper, []Error, error) } type ModUpgradeHandle interface { // Upgrades returns the latest versions for each of the module's // dependencies. - Upgrades(ctx context.Context) (map[string]string, error) + Upgrades(ctx context.Context, snapshot Snapshot) (map[string]string, error) } type ModWhyHandle interface { // Why returns the results of `go mod why` for every dependency of the // module. - Why(ctx context.Context) (map[string]string, error) + Why(ctx context.Context, snapshot Snapshot) (map[string]string, error) } type ModTidyHandle interface { @@ -366,10 +366,10 @@ type ModTidyHandle interface { ParseModHandle() ParseModHandle // Tidy returns the results of `go mod tidy` for the module. - Tidy(ctx context.Context) ([]Error, error) + Tidy(ctx context.Context, snapshot Snapshot) ([]Error, error) // TidiedContent is the content of the tidied go.mod file. - TidiedContent(ctx context.Context) ([]byte, error) + TidiedContent(ctx context.Context, snapshot Snapshot) ([]byte, error) } var ErrTmpModfileUnsupported = errors.New("-modfile is unsupported for this Go version") diff --git a/internal/lsp/source/workspace_symbol.go b/internal/lsp/source/workspace_symbol.go index 8d470fb2d8..54613b7810 100644 --- a/internal/lsp/source/workspace_symbol.go +++ b/internal/lsp/source/workspace_symbol.go @@ -52,7 +52,7 @@ outer: return nil, err } for _, ph := range knownPkgs { - pkg, err := ph.Check(ctx) + pkg, err := ph.Check(ctx, view.Snapshot()) symbolMatcher := makePackageSymbolMatcher(style, pkg, queryMatcher) if err != nil { return nil, err diff --git a/internal/memoize/memoize.go b/internal/memoize/memoize.go index 595d438252..a5e81f096e 100644 --- a/internal/memoize/memoize.go +++ b/internal/memoize/memoize.go @@ -32,9 +32,13 @@ type Store struct { entries map[interface{}]uintptr } +// Arg is a marker interface that can be embedded to indicate a type is +// intended for use as a Function argument. +type Arg interface{ memoizeArg() } + // Function is the type for functions that can be memoized. // The result must be a pointer. -type Function func(ctx context.Context) interface{} +type Function func(ctx context.Context, arg Arg) interface{} type state int @@ -203,14 +207,14 @@ func (h *Handle) Cached() interface{} { // If the value is not yet ready, the underlying function will be invoked. // This activates the handle, and it will remember the value for as long as it exists. // If ctx is cancelled, Get returns nil. -func (h *Handle) Get(ctx context.Context) (interface{}, error) { +func (h *Handle) Get(ctx context.Context, arg Arg) (interface{}, error) { if ctx.Err() != nil { return nil, ctx.Err() } h.mu.Lock() switch h.state { case stateIdle: - return h.run(ctx) + return h.run(ctx, arg) case stateRunning: return h.wait(ctx) case stateCompleted: @@ -222,7 +226,7 @@ func (h *Handle) Get(ctx context.Context) (interface{}, error) { } // run starts h.function and returns the result. h.mu must be locked. -func (h *Handle) run(ctx context.Context) (interface{}, error) { +func (h *Handle) run(ctx context.Context, arg Arg) (interface{}, error) { childCtx, cancel := context.WithCancel(xcontext.Detach(ctx)) h.cancel = cancel h.state = stateRunning @@ -234,7 +238,7 @@ func (h *Handle) run(ctx context.Context) (interface{}, error) { if childCtx.Err() != nil { return } - v := function(childCtx) + v := function(childCtx, arg) if childCtx.Err() != nil { return } diff --git a/internal/memoize/memoize_test.go b/internal/memoize/memoize_test.go index 305b594db9..00d003a6c4 100644 --- a/internal/memoize/memoize_test.go +++ b/internal/memoize/memoize_test.go @@ -83,7 +83,7 @@ end @3 = G !fail H `[1:], }, } { - s.Bind(test.key, generate(s, test.key)).Get(ctx) + s.Bind(test.key, generate(s, test.key)).Get(ctx, nil) got := logBuffer.String() if got != test.want { t.Errorf("at %q expected:\n%v\ngot:\n%s", test.name, test.want, got) @@ -103,7 +103,7 @@ end @3 = G !fail H var pins []*memoize.Handle for _, key := range pinned { h := s.Bind(key, generate(s, key)) - h.Get(ctx) + h.Get(ctx, nil) pins = append(pins, h) } @@ -175,7 +175,7 @@ func asValue(v interface{}) *stringOrError { } func generate(s *memoize.Store, key interface{}) memoize.Function { - return func(ctx context.Context) interface{} { + return func(ctx context.Context, _ memoize.Arg) interface{} { name := key.(string) switch name { case "": @@ -217,7 +217,7 @@ func joinValues(ctx context.Context, s *memoize.Store, name string, keys ...stri fmt.Fprintf(w, "start %v\n", name) value := "" for _, key := range keys { - i, err := s.Bind(key, generate(s, key)).Get(ctx) + i, err := s.Bind(key, generate(s, key)).Get(ctx, nil) if err != nil { return &stringOrError{err: err} }