internal/lsp: pass snapshot/view to memoize.Functions

Due to the runtime's inability to collect cycles involving finalizers,
we can't close over handles in memoize.Functions without causing memory
leaks. Up until now we've dealt with that by closing over all the bits
of the snapshot that we want, but it distorts the design of all the code
used in the Functions.

We can solve the problem another way: instead of closing over the
snapshot/view, we can force the caller to pass it in. This is somewhat
scary: there is no requirement that the argument matches the data that
we're working with. But the reality is that this is not a new problem:
the Function used to calculate a cache value is not necessarily the one
that the caller expects. As long as the cache key fully identifies all
the inputs to the Function, the output should be correct. And since the
caller used the snapshot/view to calculate that cache key, it should
always be safe to pass in that snapshot/view. If it's not, then we
already had a bug.

The Arg type in memoize is clumsy, but I thought it would be nice to
have at least a little bit of type safety. I'm open to suggestions.

Change-Id: I23f546638b0c66a4698620a986949087211f4762
Reviewed-on: https://go-review.googlesource.com/c/tools/+/244019
Reviewed-by: Robert Findley <rfindley@google.com>
Reviewed-by: Rebecca Stambler <rstambler@golang.org>
This commit is contained in:
Heschi Kreinick 2020-07-16 17:37:12 -04:00
parent 60da08ac03
commit 72051f7961
30 changed files with 183 additions and 178 deletions

View File

@ -7,7 +7,6 @@ package cache
import (
"context"
"fmt"
"go/token"
"go/types"
"reflect"
"sort"
@ -41,7 +40,7 @@ func (s *snapshot) Analyze(ctx context.Context, id string, analyzers ...*analysi
var results []*source.Error
for _, ah := range roots {
diagnostics, _, err := ah.analyze(ctx)
diagnostics, _, err := ah.analyze(ctx, s)
if err != nil {
return nil, err
}
@ -93,7 +92,7 @@ func (s *snapshot) actionHandle(ctx context.Context, id packageID, a *analysis.A
if len(ph.key) == 0 {
return nil, errors.Errorf("no key for PackageHandle %s", id)
}
pkg, err := ph.check(ctx)
pkg, err := ph.check(ctx, s)
if err != nil {
return nil, err
}
@ -133,17 +132,16 @@ func (s *snapshot) actionHandle(ctx context.Context, id packageID, a *analysis.A
}
}
fset := s.view.session.cache.fset
h := s.view.session.cache.store.Bind(buildActionKey(a, ph), func(ctx context.Context) interface{} {
h := s.view.session.cache.store.Bind(buildActionKey(a, ph), func(ctx context.Context, arg memoize.Arg) interface{} {
snapshot := arg.(*snapshot)
// Analyze dependencies first.
results, err := execAll(ctx, deps)
results, err := execAll(ctx, snapshot, deps)
if err != nil {
return &actionData{
err: err,
}
}
return runAnalysis(ctx, fset, a, pkg, results)
return runAnalysis(ctx, snapshot, a, pkg, results)
})
act.handle = h
@ -151,8 +149,8 @@ func (s *snapshot) actionHandle(ctx context.Context, id packageID, a *analysis.A
return act, nil
}
func (act *actionHandle) analyze(ctx context.Context) ([]*source.Error, interface{}, error) {
v, err := act.handle.Get(ctx)
func (act *actionHandle) analyze(ctx context.Context, snapshot *snapshot) ([]*source.Error, interface{}, error) {
v, err := act.handle.Get(ctx, snapshot)
if v == nil {
return nil, nil, err
}
@ -174,7 +172,7 @@ func (act *actionHandle) String() string {
return fmt.Sprintf("%s@%s", act.analyzer, act.pkg.PkgPath())
}
func execAll(ctx context.Context, actions []*actionHandle) (map[*actionHandle]*actionData, error) {
func execAll(ctx context.Context, snapshot *snapshot, actions []*actionHandle) (map[*actionHandle]*actionData, error) {
var mu sync.Mutex
results := make(map[*actionHandle]*actionData)
@ -182,7 +180,7 @@ func execAll(ctx context.Context, actions []*actionHandle) (map[*actionHandle]*a
for _, act := range actions {
act := act
g.Go(func() error {
v, err := act.handle.Get(ctx)
v, err := act.handle.Get(ctx, snapshot)
if err != nil {
return err
}
@ -201,7 +199,7 @@ func execAll(ctx context.Context, actions []*actionHandle) (map[*actionHandle]*a
return results, g.Wait()
}
func runAnalysis(ctx context.Context, fset *token.FileSet, analyzer *analysis.Analyzer, pkg *pkg, deps map[*actionHandle]*actionData) (data *actionData) {
func runAnalysis(ctx context.Context, snapshot *snapshot, analyzer *analysis.Analyzer, pkg *pkg, deps map[*actionHandle]*actionData) (data *actionData) {
data = &actionData{
objectFacts: make(map[objectFactKey]analysis.Fact),
packageFacts: make(map[packageFactKey]analysis.Fact),
@ -251,7 +249,7 @@ func runAnalysis(ctx context.Context, fset *token.FileSet, analyzer *analysis.An
// Run the analysis.
pass := &analysis.Pass{
Analyzer: analyzer,
Fset: fset,
Fset: snapshot.view.session.cache.fset,
Files: pkg.GetSyntax(),
Pkg: pkg.GetTypes(),
TypesInfo: pkg.GetTypesInfo(),
@ -343,7 +341,7 @@ func runAnalysis(ctx context.Context, fset *token.FileSet, analyzer *analysis.An
}
for _, diag := range diagnostics {
srcErr, err := sourceError(ctx, fset, pkg, diag)
srcErr, err := sourceError(ctx, snapshot, pkg, diag)
if err != nil {
event.Error(ctx, "unable to compute analysis error position", err, tag.Category.Of(diag.Category), tag.Package.Of(pkg.ID()))
continue

View File

@ -9,7 +9,6 @@ import (
"context"
"fmt"
"go/ast"
"go/token"
"go/types"
"path"
"sort"
@ -83,18 +82,19 @@ func (s *snapshot) buildPackageHandle(ctx context.Context, id packageID, mode so
goFiles := ph.goFiles
compiledGoFiles := ph.compiledGoFiles
key := ph.key
fset := s.view.session.cache.fset
h := s.view.session.cache.store.Bind(key, func(ctx context.Context) interface{} {
h := s.view.session.cache.store.Bind(key, func(ctx context.Context, arg memoize.Arg) interface{} {
snapshot := arg.(*snapshot)
// Begin loading the direct dependencies, in parallel.
for _, dep := range deps {
go func(dep *packageHandle) {
dep.check(ctx)
dep.check(ctx, snapshot)
}(dep)
}
data := &packageData{}
data.pkg, data.err = typeCheck(ctx, fset, m, mode, goFiles, compiledGoFiles, deps)
data.pkg, data.err = typeCheck(ctx, snapshot, m, mode, goFiles, compiledGoFiles, deps)
return data
})
@ -193,12 +193,12 @@ func hashConfig(config *packages.Config) string {
return hashContents(b.Bytes())
}
func (ph *packageHandle) Check(ctx context.Context) (source.Package, error) {
return ph.check(ctx)
func (ph *packageHandle) Check(ctx context.Context, s source.Snapshot) (source.Package, error) {
return ph.check(ctx, s.(*snapshot))
}
func (ph *packageHandle) check(ctx context.Context) (*pkg, error) {
v, err := ph.handle.Get(ctx)
func (ph *packageHandle) check(ctx context.Context, s *snapshot) (*pkg, error) {
v, err := ph.handle.Get(ctx, s)
if err != nil {
return nil, err
}
@ -239,7 +239,7 @@ func (s *snapshot) parseGoHandles(ctx context.Context, files []span.URI, mode so
return pghs, nil
}
func typeCheck(ctx context.Context, fset *token.FileSet, m *metadata, mode source.ParseMode, goFiles, compiledGoFiles []*parseGoHandle, deps map[packagePath]*packageHandle) (*pkg, error) {
func typeCheck(ctx context.Context, snapshot *snapshot, m *metadata, mode source.ParseMode, goFiles, compiledGoFiles []*parseGoHandle, deps map[packagePath]*packageHandle) (*pkg, error) {
ctx, done := event.Start(ctx, "cache.importer.typeCheck", tag.Package.Of(string(m.id)))
defer done()
@ -248,6 +248,7 @@ func typeCheck(ctx context.Context, fset *token.FileSet, m *metadata, mode sourc
rawErrors = append(rawErrors, err)
}
fset := snapshot.view.session.cache.fset
pkg := &pkg{
m: m,
mode: mode,
@ -278,7 +279,7 @@ func typeCheck(ctx context.Context, fset *token.FileSet, m *metadata, mode sourc
wg.Add(1)
go func(i int, ph *parseGoHandle) {
defer wg.Done()
data, err := ph.parse(ctx)
data, err := ph.parse(ctx, snapshot.view)
if err != nil {
actualErrors[i] = err
return
@ -294,7 +295,7 @@ func typeCheck(ctx context.Context, fset *token.FileSet, m *metadata, mode sourc
wg.Add(1)
// We need to parse the non-compiled go files, but we don't care about their errors.
go func(ph source.ParseGoHandle) {
ph.Parse(ctx)
ph.Parse(ctx, snapshot.view)
wg.Done()
}(ph)
}
@ -325,7 +326,7 @@ func typeCheck(ctx context.Context, fset *token.FileSet, m *metadata, mode sourc
// Try to attach errors messages to the file as much as possible.
var found bool
for _, e := range rawErrors {
srcErr, err := sourceError(ctx, fset, pkg, e)
srcErr, err := sourceError(ctx, snapshot, pkg, e)
if err != nil {
continue
}
@ -361,7 +362,7 @@ func typeCheck(ctx context.Context, fset *token.FileSet, m *metadata, mode sourc
if !isValidImport(m.pkgPath, dep.m.pkgPath) {
return nil, errors.Errorf("invalid use of internal package %s", pkgPath)
}
depPkg, err := dep.check(ctx)
depPkg, err := dep.check(ctx, snapshot)
if err != nil {
return nil, err
}
@ -386,7 +387,7 @@ func typeCheck(ctx context.Context, fset *token.FileSet, m *metadata, mode sourc
// We don't care about a package's errors unless we have parsed it in full.
if mode == source.ParseFull {
for _, e := range rawErrors {
srcErr, err := sourceError(ctx, fset, pkg, e)
srcErr, err := sourceError(ctx, snapshot, pkg, e)
if err != nil {
event.Error(ctx, "unable to compute error positions", err, tag.Package.Of(pkg.ID()))
continue

View File

@ -25,7 +25,8 @@ import (
errors "golang.org/x/xerrors"
)
func sourceError(ctx context.Context, fset *token.FileSet, pkg *pkg, e interface{}) (*source.Error, error) {
func sourceError(ctx context.Context, snapshot *snapshot, pkg *pkg, e interface{}) (*source.Error, error) {
fset := snapshot.view.session.cache.fset
var (
spn span.Span
err error
@ -38,7 +39,7 @@ func sourceError(ctx context.Context, fset *token.FileSet, pkg *pkg, e interface
case packages.Error:
kind = toSourceErrorKind(e.Kind)
var ok bool
if msg, spn, ok = parseGoListImportCycleError(ctx, fset, e, pkg); ok {
if msg, spn, ok = parseGoListImportCycleError(ctx, snapshot, e, pkg); ok {
kind = source.TypeError
break
}
@ -254,7 +255,7 @@ func parseGoListError(input string) span.Span {
return span.Parse(input[:msgIndex])
}
func parseGoListImportCycleError(ctx context.Context, fset *token.FileSet, e packages.Error, pkg *pkg) (string, span.Span, bool) {
func parseGoListImportCycleError(ctx context.Context, snapshot *snapshot, e packages.Error, pkg *pkg) (string, span.Span, bool) {
re := regexp.MustCompile(`(.*): import stack: \[(.+)\]`)
matches := re.FindStringSubmatch(strings.TrimSpace(e.Msg))
if len(matches) < 3 {
@ -270,14 +271,14 @@ func parseGoListImportCycleError(ctx context.Context, fset *token.FileSet, e pac
// Imports have quotation marks around them.
circImp := strconv.Quote(importList[1])
for _, ph := range pkg.compiledGoFiles {
fh, _, _, _, err := ph.Parse(ctx)
fh, _, _, _, err := ph.Parse(ctx, snapshot.view)
if err != nil {
continue
}
// Search file imports for the import that is causing the import cycle.
for _, imp := range fh.Imports {
if imp.Path.Value == circImp {
spn, err := span.NewRange(fset, imp.Pos(), imp.End()).Span()
spn, err := span.NewRange(snapshot.view.session.cache.fset, imp.Pos(), imp.End()).Span()
if err != nil {
return msg, span.Span{}, false
}

View File

@ -54,8 +54,8 @@ func (mh *parseModHandle) Sum() source.FileHandle {
return mh.sum
}
func (mh *parseModHandle) Parse(ctx context.Context) (*modfile.File, *protocol.ColumnMapper, []source.Error, error) {
v, err := mh.handle.Get(ctx)
func (mh *parseModHandle) Parse(ctx context.Context, s source.Snapshot) (*modfile.File, *protocol.ColumnMapper, []source.Error, error) {
v, err := mh.handle.Get(ctx, s.(*snapshot))
if err != nil {
return nil, nil, nil, err
}
@ -67,7 +67,7 @@ func (s *snapshot) ParseModHandle(ctx context.Context, modFH source.FileHandle)
if handle := s.getModHandle(modFH.URI()); handle != nil {
return handle, nil
}
h := s.view.session.cache.store.Bind(modFH.Identity().String(), func(ctx context.Context) interface{} {
h := s.view.session.cache.store.Bind(modFH.Identity().String(), func(ctx context.Context, _ memoize.Arg) interface{} {
_, done := event.Start(ctx, "cache.ParseModHandle", tag.URI.Of(modFH.URI()))
defer done()
@ -187,8 +187,6 @@ const (
type modWhyHandle struct {
handle *memoize.Handle
pmh source.ParseModHandle
}
type modWhyData struct {
@ -199,8 +197,8 @@ type modWhyData struct {
err error
}
func (mwh *modWhyHandle) Why(ctx context.Context) (map[string]string, error) {
v, err := mwh.handle.Get(ctx)
func (mwh *modWhyHandle) Why(ctx context.Context, s source.Snapshot) (map[string]string, error) {
v, err := mwh.handle.Get(ctx, s.(*snapshot))
if err != nil {
return nil, err
}
@ -216,26 +214,26 @@ func (s *snapshot) ModWhyHandle(ctx context.Context) (source.ModWhyHandle, error
if err != nil {
return nil, err
}
pmh, err := s.ParseModHandle(ctx, fh)
if err != nil {
return nil, err
}
var (
cfg = s.config(ctx)
tmpMod = s.view.tmpMod
)
cfg := s.config(ctx)
key := modKey{
sessionID: s.view.session.id,
cfg: hashConfig(cfg),
mod: pmh.Mod().Identity().String(),
cfg: hashConfig(s.config(ctx)),
mod: fh.Identity().String(),
view: s.view.root.Filename(),
verb: why,
}
h := s.view.session.cache.store.Bind(key, func(ctx context.Context) interface{} {
ctx, done := event.Start(ctx, "cache.ModHandle", tag.URI.Of(pmh.Mod().URI()))
h := s.view.session.cache.store.Bind(key, func(ctx context.Context, arg memoize.Arg) interface{} {
ctx, done := event.Start(ctx, "cache.ModWhyHandle", tag.URI.Of(fh.URI()))
defer done()
parsed, _, _, err := pmh.Parse(ctx)
snapshot := arg.(*snapshot)
pmh, err := snapshot.ParseModHandle(ctx, fh)
if err != nil {
return &modWhyData{err: err}
}
parsed, _, _, err := pmh.Parse(ctx, snapshot)
if err != nil {
return &modWhyData{err: err}
}
@ -248,7 +246,7 @@ func (s *snapshot) ModWhyHandle(ctx context.Context) (source.ModWhyHandle, error
for _, req := range parsed.Require {
args = append(args, req.Mod.Path)
}
_, stdout, err := runGoCommand(ctx, cfg, pmh, tmpMod, "mod", args)
_, stdout, err := runGoCommand(ctx, cfg, pmh, snapshot.view.tmpMod, "mod", args)
if err != nil {
return &modWhyData{err: err}
}
@ -268,15 +266,12 @@ func (s *snapshot) ModWhyHandle(ctx context.Context) (source.ModWhyHandle, error
defer s.mu.Unlock()
s.modWhyHandle = &modWhyHandle{
handle: h,
pmh: pmh,
}
return s.modWhyHandle, nil
}
type modUpgradeHandle struct {
handle *memoize.Handle
pmh source.ParseModHandle
}
type modUpgradeData struct {
@ -286,8 +281,8 @@ type modUpgradeData struct {
err error
}
func (muh *modUpgradeHandle) Upgrades(ctx context.Context) (map[string]string, error) {
v, err := muh.handle.Get(ctx)
func (muh *modUpgradeHandle) Upgrades(ctx context.Context, s source.Snapshot) (map[string]string, error) {
v, err := muh.handle.Get(ctx, s.(*snapshot))
if v == nil {
return nil, err
}
@ -303,26 +298,26 @@ func (s *snapshot) ModUpgradeHandle(ctx context.Context) (source.ModUpgradeHandl
if err != nil {
return nil, err
}
pmh, err := s.ParseModHandle(ctx, fh)
if err != nil {
return nil, err
}
var (
cfg = s.config(ctx)
tmpMod = s.view.tmpMod
)
cfg := s.config(ctx)
key := modKey{
sessionID: s.view.session.id,
cfg: hashConfig(cfg),
mod: pmh.Mod().Identity().String(),
mod: fh.Identity().String(),
view: s.view.root.Filename(),
verb: upgrade,
}
h := s.view.session.cache.store.Bind(key, func(ctx context.Context) interface{} {
ctx, done := event.Start(ctx, "cache.ModUpgradeHandle", tag.URI.Of(pmh.Mod().URI()))
h := s.view.session.cache.store.Bind(key, func(ctx context.Context, arg memoize.Arg) interface{} {
ctx, done := event.Start(ctx, "cache.ModUpgradeHandle", tag.URI.Of(fh.URI()))
defer done()
parsed, _, _, err := pmh.Parse(ctx)
snapshot := arg.(*snapshot)
pmh, err := s.ParseModHandle(ctx, fh)
if err != nil {
return &modUpgradeData{err: err}
}
parsed, _, _, err := pmh.Parse(ctx, snapshot)
if err != nil {
return &modUpgradeData{err: err}
}
@ -333,12 +328,12 @@ func (s *snapshot) ModUpgradeHandle(ctx context.Context) (source.ModUpgradeHandl
// Run "go list -mod readonly -u -m all" to be able to see which deps can be
// upgraded without modifying mod file.
args := []string{"-u", "-m", "all"}
if !tmpMod || containsVendor(pmh.Mod().URI()) {
if !snapshot.view.tmpMod || containsVendor(pmh.Mod().URI()) {
// Use -mod=readonly if the module contains a vendor directory
// (see golang/go#38711).
args = append([]string{"-mod", "readonly"}, args...)
}
_, stdout, err := runGoCommand(ctx, cfg, pmh, tmpMod, "list", args)
_, stdout, err := runGoCommand(ctx, cfg, pmh, snapshot.view.tmpMod, "list", args)
if err != nil {
return &modUpgradeData{err: err}
}
@ -370,7 +365,6 @@ func (s *snapshot) ModUpgradeHandle(ctx context.Context) (source.ModUpgradeHandl
defer s.mu.Unlock()
s.modUpgradeHandle = &modUpgradeHandle{
handle: h,
pmh: pmh,
}
return s.modUpgradeHandle, nil
}

View File

@ -8,7 +8,6 @@ import (
"context"
"fmt"
"go/ast"
"go/token"
"io/ioutil"
"sort"
"strconv"
@ -56,8 +55,8 @@ func (mth *modTidyHandle) ParseModHandle() source.ParseModHandle {
return mth.pmh
}
func (mth *modTidyHandle) Tidy(ctx context.Context) ([]source.Error, error) {
v, err := mth.handle.Get(ctx)
func (mth *modTidyHandle) Tidy(ctx context.Context, s source.Snapshot) ([]source.Error, error) {
v, err := mth.handle.Get(ctx, s.(*snapshot))
if err != nil {
return nil, err
}
@ -65,8 +64,8 @@ func (mth *modTidyHandle) Tidy(ctx context.Context) ([]source.Error, error) {
return data.diagnostics, data.err
}
func (mth *modTidyHandle) TidiedContent(ctx context.Context) ([]byte, error) {
v, err := mth.handle.Get(ctx)
func (mth *modTidyHandle) TidiedContent(ctx context.Context, s source.Snapshot) ([]byte, error) {
v, err := mth.handle.Get(ctx, s.(*snapshot))
if err != nil {
return nil, err
}
@ -98,7 +97,7 @@ func (s *snapshot) ModTidyHandle(ctx context.Context) (source.ModTidyHandle, err
}
var workspacePkgs []source.Package
for _, ph := range wsPhs {
pkg, err := ph.Check(ctx)
pkg, err := ph.Check(ctx, s)
if err != nil {
return nil, err
}
@ -117,7 +116,6 @@ func (s *snapshot) ModTidyHandle(ctx context.Context) (source.ModTidyHandle, err
modURI = s.view.modURI
cfg = s.config(ctx)
options = s.view.Options()
fset = s.view.session.cache.fset
)
key := modTidyKey{
sessionID: s.view.session.id,
@ -127,11 +125,13 @@ func (s *snapshot) ModTidyHandle(ctx context.Context) (source.ModTidyHandle, err
gomod: pmh.Mod().Identity().String(),
cfg: hashConfig(cfg),
}
h := s.view.session.cache.store.Bind(key, func(ctx context.Context) interface{} {
h := s.view.session.cache.store.Bind(key, func(ctx context.Context, arg memoize.Arg) interface{} {
ctx, done := event.Start(ctx, "cache.ModTidyHandle", tag.URI.Of(modURI))
defer done()
original, m, parseErrors, err := pmh.Parse(ctx)
snapshot := arg.(*snapshot)
original, m, parseErrors, err := pmh.Parse(ctx, snapshot)
if err != nil || len(parseErrors) > 0 {
return &modTidyData{
diagnostics: parseErrors,
@ -191,7 +191,7 @@ func (s *snapshot) ModTidyHandle(ctx context.Context) (source.ModTidyHandle, err
// go.mod file. The fixes will be for the go.mod file, but the
// diagnostics should appear on the import statements in the Go or
// go.mod files.
missingModuleErrs, err := missingModuleErrors(ctx, fset, m, workspacePkgs, ideal.Require, missingDeps, original, options)
missingModuleErrs, err := missingModuleErrors(ctx, snapshot, m, workspacePkgs, ideal.Require, missingDeps, original, options)
if err != nil {
return &modTidyData{err: err}
}
@ -430,7 +430,7 @@ func rangeFromPositions(m *protocol.ColumnMapper, s, e modfile.Position) (protoc
// missingModuleErrors returns diagnostics for each file in each workspace
// package that has dependencies that are not reflected in the go.mod file.
func missingModuleErrors(ctx context.Context, fset *token.FileSet, modMapper *protocol.ColumnMapper, pkgs []source.Package, modules []*modfile.Require, missingMods map[string]*modfile.Require, original *modfile.File, options source.Options) ([]source.Error, error) {
func missingModuleErrors(ctx context.Context, snapshot *snapshot, modMapper *protocol.ColumnMapper, pkgs []source.Package, modules []*modfile.Require, missingMods map[string]*modfile.Require, original *modfile.File, options source.Options) ([]source.Error, error) {
var moduleErrs []source.Error
matchedMissingMods := make(map[*modfile.Require]struct{})
for _, pkg := range pkgs {
@ -462,7 +462,7 @@ func missingModuleErrors(ctx context.Context, fset *token.FileSet, modMapper *pr
}
}
if len(missingPkgs) > 0 {
errs, err := missingModules(ctx, fset, modMapper, pkg, missingPkgs, options)
errs, err := missingModules(ctx, snapshot, modMapper, pkg, missingPkgs, options)
if err != nil {
return nil, err
}
@ -499,10 +499,10 @@ func missingModuleErrors(ctx context.Context, fset *token.FileSet, modMapper *pr
return moduleErrs, nil
}
func missingModules(ctx context.Context, fset *token.FileSet, modMapper *protocol.ColumnMapper, pkg source.Package, missing map[string]*modfile.Require, options source.Options) ([]source.Error, error) {
func missingModules(ctx context.Context, snapshot *snapshot, modMapper *protocol.ColumnMapper, pkg source.Package, missing map[string]*modfile.Require, options source.Options) ([]source.Error, error) {
var errors []source.Error
for _, pgh := range pkg.CompiledGoFiles() {
file, _, m, _, err := pgh.Parse(ctx)
file, _, m, _, err := pgh.Parse(ctx, snapshot.view)
if err != nil {
return nil, err
}
@ -526,7 +526,7 @@ func missingModules(ctx context.Context, fset *token.FileSet, modMapper *protoco
if !ok {
continue
}
spn, err := span.NewRange(fset, imp.Path.Pos(), imp.Path.End()).Span()
spn, err := span.NewRange(snapshot.view.session.cache.fset, imp.Path.Pos(), imp.Path.End()).Span()
if err != nil {
return nil, err
}

View File

@ -65,18 +65,21 @@ func (c *Cache) parseGoHandle(ctx context.Context, fh source.FileHandle, mode so
file: fh.Identity(),
mode: mode,
}
fset := c.fset
h := c.store.Bind(key, func(ctx context.Context) interface{} {
return parseGo(ctx, fset, fh, mode)
parseHandle := c.store.Bind(key, func(ctx context.Context, arg memoize.Arg) interface{} {
view := arg.(*View)
return parseGo(ctx, view.session.cache.fset, fh, mode)
})
astHandle := c.store.Bind(astCacheKey(key), func(ctx context.Context, arg memoize.Arg) interface{} {
view := arg.(*View)
return buildASTCache(ctx, view, parseHandle)
})
return &parseGoHandle{
handle: h,
file: fh,
mode: mode,
astCacheHandle: c.store.Bind(astCacheKey(key), func(ctx context.Context) interface{} {
return buildASTCache(ctx, h)
}),
handle: parseHandle,
file: fh,
mode: mode,
astCacheHandle: astHandle,
}
}
@ -92,20 +95,20 @@ func (pgh *parseGoHandle) Mode() source.ParseMode {
return pgh.mode
}
func (pgh *parseGoHandle) Parse(ctx context.Context) (*ast.File, []byte, *protocol.ColumnMapper, error, error) {
data, err := pgh.parse(ctx)
func (pgh *parseGoHandle) Parse(ctx context.Context, v source.View) (*ast.File, []byte, *protocol.ColumnMapper, error, error) {
data, err := pgh.parse(ctx, v.(*View))
if err != nil {
return nil, nil, nil, nil, err
}
return data.ast, data.src, data.mapper, data.parseError, data.err
}
func (pgh *parseGoHandle) parse(ctx context.Context) (*parseGoData, error) {
v, err := pgh.handle.Get(ctx)
func (pgh *parseGoHandle) parse(ctx context.Context, v *View) (*parseGoData, error) {
d, err := pgh.handle.Get(ctx, v)
if err != nil {
return nil, err
}
data, ok := v.(*parseGoData)
data, ok := d.(*parseGoData)
if !ok {
return nil, errors.Errorf("no parsed file for %s", pgh.File().URI())
}
@ -129,13 +132,13 @@ func (pgh *parseGoHandle) cached() (*parseGoData, error) {
return data, nil
}
func (pgh *parseGoHandle) PosToDecl(ctx context.Context) (map[token.Pos]ast.Decl, error) {
v, err := pgh.astCacheHandle.Get(ctx)
func (pgh *parseGoHandle) PosToDecl(ctx context.Context, v source.View) (map[token.Pos]ast.Decl, error) {
d, err := pgh.astCacheHandle.Get(ctx, v.(*View))
if err != nil || v == nil {
return nil, err
}
data := v.(*astCacheData)
data := d.(*astCacheData)
if data.err != nil {
return nil, data.err
}
@ -143,13 +146,13 @@ func (pgh *parseGoHandle) PosToDecl(ctx context.Context) (map[token.Pos]ast.Decl
return data.posToDecl, nil
}
func (pgh *parseGoHandle) PosToField(ctx context.Context) (map[token.Pos]*ast.Field, error) {
v, err := pgh.astCacheHandle.Get(ctx)
if err != nil || v == nil {
func (pgh *parseGoHandle) PosToField(ctx context.Context, v source.View) (map[token.Pos]*ast.Field, error) {
d, err := pgh.astCacheHandle.Get(ctx, v.(*View))
if err != nil || d == nil {
return nil, err
}
data := v.(*astCacheData)
data := d.(*astCacheData)
if data.err != nil {
return nil, data.err
}
@ -168,7 +171,7 @@ type astCacheData struct {
// buildASTCache builds caches to aid in quickly going from the typed
// world to the syntactic world.
func buildASTCache(ctx context.Context, parseHandle *memoize.Handle) *astCacheData {
func buildASTCache(ctx context.Context, view *View, parseHandle *memoize.Handle) *astCacheData {
var (
// path contains all ancestors, including n.
path []ast.Node
@ -176,7 +179,7 @@ func buildASTCache(ctx context.Context, parseHandle *memoize.Handle) *astCacheDa
decls []ast.Decl
)
v, err := parseHandle.Get(ctx)
v, err := parseHandle.Get(ctx, view)
if err != nil || v == nil || v.(*parseGoData).ast == nil {
return &astCacheData{err: err}
}

View File

@ -24,6 +24,7 @@ import (
"golang.org/x/tools/internal/gocommand"
"golang.org/x/tools/internal/lsp/debug/tag"
"golang.org/x/tools/internal/lsp/source"
"golang.org/x/tools/internal/memoize"
"golang.org/x/tools/internal/packagesinternal"
"golang.org/x/tools/internal/span"
"golang.org/x/tools/internal/typesinternal"
@ -31,6 +32,8 @@ import (
)
type snapshot struct {
memoize.Arg // allow as a memoize.Function arg
id uint64
view *View
@ -1004,8 +1007,8 @@ func (s *snapshot) shouldInvalidateMetadata(ctx context.Context, originalFH, cur
return originalFH.URI() == s.view.modURI
}
// Get the original and current parsed files in order to check package name and imports.
original, _, _, _, originalErr := s.view.session.cache.ParseGoHandle(ctx, originalFH, source.ParseHeader).Parse(ctx)
current, _, _, _, currentErr := s.view.session.cache.ParseGoHandle(ctx, currentFH, source.ParseHeader).Parse(ctx)
original, _, _, _, originalErr := s.view.session.cache.ParseGoHandle(ctx, originalFH, source.ParseHeader).Parse(ctx, s.view)
current, _, _, _, currentErr := s.view.session.cache.ParseGoHandle(ctx, currentFH, source.ParseHeader).Parse(ctx, s.view)
if originalErr != nil || currentErr != nil {
return (originalErr == nil) != (currentErr == nil)
}

View File

@ -33,6 +33,8 @@ import (
)
type View struct {
memoize.Arg // allow as a memoize.Function arg
session *Session
id string
@ -143,7 +145,6 @@ type View struct {
type builtinPackageHandle struct {
handle *memoize.Handle
file source.ParseGoHandle
}
type builtinPackageData struct {
@ -300,7 +301,7 @@ func (v *View) BuiltinPackage(ctx context.Context) (source.BuiltinPackage, error
if v.builtin == nil {
return nil, errors.Errorf("no builtin package for view %s", v.name)
}
data, err := v.builtin.handle.Get(ctx)
data, err := v.builtin.handle.Get(ctx, v)
if err != nil {
return nil, err
}
@ -332,14 +333,15 @@ func (v *View) buildBuiltinPackage(ctx context.Context, goFiles []string) error
if err != nil {
return err
}
pgh := v.session.cache.parseGoHandle(ctx, fh, source.ParseFull)
fset := v.session.cache.fset
h := v.session.cache.store.Bind(fh.Identity(), func(ctx context.Context) interface{} {
file, _, _, _, err := pgh.Parse(ctx)
h := v.session.cache.store.Bind(fh.Identity(), func(ctx context.Context, arg memoize.Arg) interface{} {
view := arg.(*View)
pgh := view.session.cache.parseGoHandle(ctx, fh, source.ParseFull)
file, _, _, _, err := pgh.Parse(ctx, view)
if err != nil {
return &builtinPackageData{err: err}
}
pkg, err := ast.NewPackage(fset, map[string]*ast.File{
pkg, err := ast.NewPackage(view.session.cache.fset, map[string]*ast.File{
pgh.File().URI().Filename(): file,
}, nil, nil)
if err != nil {
@ -352,7 +354,6 @@ func (v *View) buildBuiltinPackage(ctx context.Context, goFiles []string) error
})
v.builtin = &builtinPackageHandle{
handle: h,
file: pgh,
}
return nil
}
@ -565,7 +566,7 @@ func (v *View) WorkspaceDirectories(ctx context.Context) ([]string, error) {
if err != nil {
return nil, err
}
parsed, _, _, err := pmh.Parse(ctx)
parsed, _, _, err := pmh.Parse(ctx, v.Snapshot())
if err != nil {
return nil, err
}

View File

@ -461,7 +461,7 @@ func moduleQuickFixes(ctx context.Context, snapshot source.Snapshot, diagnostics
if err != nil {
return nil, err
}
errors, err := mth.Tidy(ctx)
errors, err := mth.Tidy(ctx, snapshot)
if err != nil {
return nil, err
}
@ -515,7 +515,7 @@ func goModTidy(ctx context.Context, snapshot source.Snapshot) (*protocol.CodeAct
return nil, err
}
uri := mth.ParseModHandle().Mod().URI()
_, m, _, err := mth.ParseModHandle().Parse(ctx)
_, m, _, err := mth.ParseModHandle().Parse(ctx, snapshot)
if err != nil {
return nil, err
}
@ -523,7 +523,7 @@ func goModTidy(ctx context.Context, snapshot source.Snapshot) (*protocol.CodeAct
if err != nil {
return nil, err
}
right, err := mth.TidiedContent(ctx)
right, err := mth.TidiedContent(ctx, snapshot)
if err != nil {
return nil, err
}

View File

@ -115,7 +115,7 @@ If you believe this is a mistake, please file an issue: https://github.com/golan
go func(ph source.PackageHandle) {
defer wg.Done()
pkg, err := ph.Check(ctx)
pkg, err := ph.Check(ctx, snapshot)
if err != nil {
event.Error(ctx, "warning: diagnose package", err, tag.Snapshot.Of(snapshot.ID()), tag.Package.Of(ph.ID()))
return

View File

@ -50,7 +50,7 @@ func modLinks(ctx context.Context, snapshot source.Snapshot, fh source.FileHandl
if err != nil {
return nil, err
}
file, m, _, err := pmh.Parse(ctx)
file, m, _, err := pmh.Parse(ctx, snapshot)
if err != nil {
return nil, err
}
@ -109,7 +109,7 @@ func goLinks(ctx context.Context, view source.View, fh source.FileHandle) ([]pro
return nil, err
}
pgh := view.Session().Cache().ParseGoHandle(ctx, fh, source.ParseFull)
file, _, m, _, err := pgh.Parse(ctx)
file, _, m, _, err := pgh.Parse(ctx, view)
if err != nil {
return nil, err
}
@ -142,7 +142,7 @@ func goLinks(ctx context.Context, view source.View, fh source.FileHandle) ([]pro
if view.IsGoPrivatePath(target) {
continue
}
if mod, version, ok := moduleAtVersion(ctx, target, ph); ok && strings.ToLower(view.Options().LinkTarget) == "pkg.go.dev" {
if mod, version, ok := moduleAtVersion(ctx, view.Snapshot(), target, ph); ok && strings.ToLower(view.Options().LinkTarget) == "pkg.go.dev" {
target = strings.Replace(target, mod, mod+"@"+version, 1)
}
// Account for the quotation marks in the positions.
@ -175,8 +175,8 @@ func goLinks(ctx context.Context, view source.View, fh source.FileHandle) ([]pro
return links, nil
}
func moduleAtVersion(ctx context.Context, target string, ph source.PackageHandle) (string, string, bool) {
pkg, err := ph.Check(ctx)
func moduleAtVersion(ctx context.Context, snapshot source.Snapshot, target string, ph source.PackageHandle) (string, string, bool) {
pkg, err := ph.Check(ctx, snapshot)
if err != nil {
return "", "", false
}

View File

@ -32,7 +32,7 @@ func CodeLens(ctx context.Context, snapshot source.Snapshot, uri span.URI) ([]pr
if err != nil {
return nil, err
}
file, m, _, err := pmh.Parse(ctx)
file, m, _, err := pmh.Parse(ctx, snapshot)
if err != nil {
return nil, err
}
@ -40,7 +40,7 @@ func CodeLens(ctx context.Context, snapshot source.Snapshot, uri span.URI) ([]pr
if err != nil {
return nil, err
}
upgrades, err := muh.Upgrades(ctx)
upgrades, err := muh.Upgrades(ctx, snapshot)
if err != nil {
return nil, err
}

View File

@ -44,7 +44,7 @@ func Diagnostics(ctx context.Context, snapshot source.Snapshot) (map[source.File
if err != nil {
return nil, err
}
diagnostics, err := mth.Tidy(ctx)
diagnostics, err := mth.Tidy(ctx, snapshot)
if err != nil {
return nil, err
}
@ -102,7 +102,7 @@ func ExtractGoCommandError(ctx context.Context, snapshot source.Snapshot, fh sou
if err != nil {
return nil, err
}
parsed, m, _, err := pmh.Parse(ctx)
parsed, m, _, err := pmh.Parse(ctx, snapshot)
if err != nil {
return nil, err
}

View File

@ -16,7 +16,7 @@ func Format(ctx context.Context, snapshot source.Snapshot, fh source.FileHandle)
if err != nil {
return nil, err
}
file, m, _, err := pmh.Parse(ctx)
file, m, _, err := pmh.Parse(ctx, snapshot)
if err != nil {
return nil, err
}

View File

@ -30,7 +30,7 @@ func Hover(ctx context.Context, snapshot source.Snapshot, fh source.FileHandle,
if err != nil {
return nil, fmt.Errorf("getting modfile handle: %w", err)
}
file, m, _, err := pmh.Parse(ctx)
file, m, _, err := pmh.Parse(ctx, snapshot)
if err != nil {
return nil, err
}
@ -72,7 +72,7 @@ func Hover(ctx context.Context, snapshot source.Snapshot, fh source.FileHandle,
if err != nil {
return nil, err
}
why, err := mwh.Why(ctx)
why, err := mwh.Why(ctx, snapshot)
if err != nil {
return nil, fmt.Errorf("running go mod why: %w", err)
}

View File

@ -144,7 +144,7 @@ func matchTestFunc(fn *ast.FuncDecl, pkg Package, nameRe *regexp.Regexp, paramID
func goGenerateCodeLens(ctx context.Context, snapshot Snapshot, fh FileHandle) ([]protocol.CodeLens, error) {
pgh := snapshot.View().Session().Cache().ParseGoHandle(ctx, fh, ParseFull)
file, _, m, _, err := pgh.Parse(ctx)
file, _, m, _, err := pgh.Parse(ctx, snapshot.View())
if err != nil {
return nil, err
}
@ -194,7 +194,7 @@ func goGenerateCodeLens(ctx context.Context, snapshot Snapshot, fh FileHandle) (
func regenerateCgoLens(ctx context.Context, snapshot Snapshot, fh FileHandle) ([]protocol.CodeLens, error) {
pgh := snapshot.View().Session().Cache().ParseGoHandle(ctx, fh, ParseFull)
file, _, m, _, err := pgh.Parse(ctx)
file, _, m, _, err := pgh.Parse(ctx, snapshot.View())
if err != nil {
return nil, err
}

View File

@ -193,7 +193,7 @@ func (c *completer) item(ctx context.Context, cand candidate) (CompletionItem, e
return item, nil
}
posToDecl, err := ph.PosToDecl(ctx)
posToDecl, err := ph.PosToDecl(ctx, c.snapshot.View())
if err != nil {
return CompletionItem{}, err
}

View File

@ -137,7 +137,7 @@ func FileDiagnostics(ctx context.Context, snapshot Snapshot, uri span.URI) (File
if err != nil {
return FileIdentity{}, nil, err
}
pkg, err := ph.Check(ctx)
pkg, err := ph.Check(ctx, snapshot)
if err != nil {
return FileIdentity{}, nil, err
}

View File

@ -19,7 +19,7 @@ func FoldingRange(ctx context.Context, snapshot Snapshot, fh FileHandle, lineFol
// TODO(suzmue): consider limiting the number of folding ranges returned, and
// implement a way to prioritize folding ranges in that case.
pgh := snapshot.View().Session().Cache().ParseGoHandle(ctx, fh, ParseFull)
file, _, m, _, err := pgh.Parse(ctx)
file, _, m, _, err := pgh.Parse(ctx, snapshot.View())
if err != nil {
return nil, err
}

View File

@ -27,7 +27,7 @@ func Format(ctx context.Context, snapshot Snapshot, fh FileHandle) ([]protocol.T
defer done()
pgh := snapshot.View().Session().Cache().ParseGoHandle(ctx, fh, ParseFull)
file, _, m, parseErrors, err := pgh.Parse(ctx)
file, _, m, parseErrors, err := pgh.Parse(ctx, snapshot.View())
if err != nil {
return nil, err
}
@ -110,7 +110,7 @@ func computeImportEdits(ctx context.Context, view View, ph ParseGoHandle, option
if err != nil {
return nil, nil, err
}
_, _, origMapper, _, err := ph.Parse(ctx)
_, _, origMapper, _, err := ph.Parse(ctx, view)
if err != nil {
return nil, nil, err
}
@ -145,7 +145,7 @@ func computeOneImportFixEdits(ctx context.Context, view View, ph ParseGoHandle,
if err != nil {
return nil, err
}
_, _, origMapper, _, err := ph.Parse(ctx) // ph.Parse returns values never used
_, _, origMapper, _, err := ph.Parse(ctx, view) // ph.Parse returns values never used
if err != nil {
return nil, err
}

View File

@ -26,7 +26,7 @@ func Highlight(ctx context.Context, snapshot Snapshot, fh FileHandle, pos protoc
if err != nil {
return nil, fmt.Errorf("getting file for Highlight: %w", err)
}
file, _, m, _, err := pgh.Parse(ctx)
file, _, m, _, err := pgh.Parse(ctx, snapshot.View())
if err != nil {
return nil, err
}

View File

@ -293,7 +293,7 @@ func objToDecl(ctx context.Context, v View, srcPkg Package, obj types.Object) (a
if err != nil {
return nil, err
}
posToDecl, err := ph.PosToDecl(ctx)
posToDecl, err := ph.PosToDecl(ctx, v)
if err != nil {
return nil, err
}

View File

@ -95,7 +95,7 @@ func implementations(ctx context.Context, s Snapshot, f FileHandle, pp protocol.
return nil, err
}
for _, ph := range knownPkgs {
pkg, err := ph.Check(ctx)
pkg, err := ph.Check(ctx, s)
if err != nil {
return nil, err
}
@ -212,7 +212,7 @@ func qualifiedObjsAtProtocolPos(ctx context.Context, s Snapshot, fh FileHandle,
// Check all the packages that the file belongs to.
var qualifiedObjs []qualifiedObject
for _, ph := range phs {
searchpkg, err := ph.Check(ctx)
searchpkg, err := ph.Check(ctx, s)
if err != nil {
return nil, err
}

View File

@ -77,7 +77,7 @@ func references(ctx context.Context, s Snapshot, qos []qualifiedObject, includeD
return nil, err
}
for _, ph := range reverseDeps {
pkg, err := ph.Check(ctx)
pkg, err := ph.Check(ctx, s)
if err != nil {
return nil, err
}

View File

@ -204,7 +204,7 @@ func formatVarType(ctx context.Context, s Snapshot, srcpkg Package, srcfile *ast
return types.TypeString(obj.Type(), qf)
}
expr, err := varType(ctx, ph, obj)
expr, err := varType(ctx, s, ph, obj)
if err != nil {
return types.TypeString(obj.Type(), qf)
}
@ -224,8 +224,8 @@ func formatVarType(ctx context.Context, s Snapshot, srcpkg Package, srcfile *ast
}
// varType returns the type expression for a *types.Var.
func varType(ctx context.Context, ph ParseGoHandle, obj *types.Var) (ast.Expr, error) {
posToField, err := ph.PosToField(ctx)
func varType(ctx context.Context, snapshot Snapshot, ph ParseGoHandle, obj *types.Var) (ast.Expr, error) {
posToField, err := ph.PosToField(ctx, snapshot.View())
if err != nil {
return nil, err
}

View File

@ -77,7 +77,7 @@ func getParsedFile(ctx context.Context, snapshot Snapshot, fh FileHandle, select
if err != nil {
return nil, nil, err
}
pkg, err := ph.Check(ctx)
pkg, err := ph.Check(ctx, snapshot)
if err != nil {
return nil, nil, err
}
@ -148,7 +148,7 @@ func IsGenerated(ctx context.Context, snapshot Snapshot, uri span.URI) bool {
return false
}
ph := snapshot.View().Session().Cache().ParseGoHandle(ctx, fh, ParseHeader)
parsed, _, _, _, err := ph.Parse(ctx)
parsed, _, _, _, err := ph.Parse(ctx, snapshot.View())
if err != nil {
return false
}

View File

@ -108,7 +108,7 @@ type PackageHandle interface {
CompiledGoFiles() []span.URI
// Check returns the type-checked Package for the PackageHandle.
Check(ctx context.Context) (Package, error)
Check(ctx context.Context, snapshot Snapshot) (Package, error)
// Cached returns the Package for the PackageHandle if it has already been stored.
Cached() (Package, error)
@ -320,7 +320,7 @@ type ParseGoHandle interface {
// Parse returns the parsed AST for the file.
// If the file is not available, returns nil and an error.
Parse(ctx context.Context) (file *ast.File, src []byte, m *protocol.ColumnMapper, parseErr error, err error)
Parse(ctx context.Context, view View) (file *ast.File, src []byte, m *protocol.ColumnMapper, parseErr error, err error)
// Cached returns the AST for this handle, if it has already been stored.
Cached() (file *ast.File, src []byte, m *protocol.ColumnMapper, parseErr error, err error)
@ -329,12 +329,12 @@ type ParseGoHandle interface {
// to quickly find corresponding *ast.Field node given a *types.Var.
// We must refer to the AST to render type aliases properly when
// formatting signatures and other types.
PosToField(context.Context) (map[token.Pos]*ast.Field, error)
PosToField(ctx context.Context, view View) (map[token.Pos]*ast.Field, error)
// PosToDecl maps certain objects' positions to their surrounding
// ast.Decl. This mapping is used when building the documentation
// string for the objects.
PosToDecl(context.Context) (map[token.Pos]ast.Decl, error)
PosToDecl(ctx context.Context, view View) (map[token.Pos]ast.Decl, error)
}
type ParseModHandle interface {
@ -346,19 +346,19 @@ type ParseModHandle interface {
// Parse returns the parsed go.mod file, a column mapper, and a list of
// parse for the go.mod file.
Parse(ctx context.Context) (*modfile.File, *protocol.ColumnMapper, []Error, error)
Parse(ctx context.Context, snapshot Snapshot) (*modfile.File, *protocol.ColumnMapper, []Error, error)
}
type ModUpgradeHandle interface {
// Upgrades returns the latest versions for each of the module's
// dependencies.
Upgrades(ctx context.Context) (map[string]string, error)
Upgrades(ctx context.Context, snapshot Snapshot) (map[string]string, error)
}
type ModWhyHandle interface {
// Why returns the results of `go mod why` for every dependency of the
// module.
Why(ctx context.Context) (map[string]string, error)
Why(ctx context.Context, snapshot Snapshot) (map[string]string, error)
}
type ModTidyHandle interface {
@ -366,10 +366,10 @@ type ModTidyHandle interface {
ParseModHandle() ParseModHandle
// Tidy returns the results of `go mod tidy` for the module.
Tidy(ctx context.Context) ([]Error, error)
Tidy(ctx context.Context, snapshot Snapshot) ([]Error, error)
// TidiedContent is the content of the tidied go.mod file.
TidiedContent(ctx context.Context) ([]byte, error)
TidiedContent(ctx context.Context, snapshot Snapshot) ([]byte, error)
}
var ErrTmpModfileUnsupported = errors.New("-modfile is unsupported for this Go version")

View File

@ -52,7 +52,7 @@ outer:
return nil, err
}
for _, ph := range knownPkgs {
pkg, err := ph.Check(ctx)
pkg, err := ph.Check(ctx, view.Snapshot())
symbolMatcher := makePackageSymbolMatcher(style, pkg, queryMatcher)
if err != nil {
return nil, err

View File

@ -32,9 +32,13 @@ type Store struct {
entries map[interface{}]uintptr
}
// Arg is a marker interface that can be embedded to indicate a type is
// intended for use as a Function argument.
type Arg interface{ memoizeArg() }
// Function is the type for functions that can be memoized.
// The result must be a pointer.
type Function func(ctx context.Context) interface{}
type Function func(ctx context.Context, arg Arg) interface{}
type state int
@ -203,14 +207,14 @@ func (h *Handle) Cached() interface{} {
// If the value is not yet ready, the underlying function will be invoked.
// This activates the handle, and it will remember the value for as long as it exists.
// If ctx is cancelled, Get returns nil.
func (h *Handle) Get(ctx context.Context) (interface{}, error) {
func (h *Handle) Get(ctx context.Context, arg Arg) (interface{}, error) {
if ctx.Err() != nil {
return nil, ctx.Err()
}
h.mu.Lock()
switch h.state {
case stateIdle:
return h.run(ctx)
return h.run(ctx, arg)
case stateRunning:
return h.wait(ctx)
case stateCompleted:
@ -222,7 +226,7 @@ func (h *Handle) Get(ctx context.Context) (interface{}, error) {
}
// run starts h.function and returns the result. h.mu must be locked.
func (h *Handle) run(ctx context.Context) (interface{}, error) {
func (h *Handle) run(ctx context.Context, arg Arg) (interface{}, error) {
childCtx, cancel := context.WithCancel(xcontext.Detach(ctx))
h.cancel = cancel
h.state = stateRunning
@ -234,7 +238,7 @@ func (h *Handle) run(ctx context.Context) (interface{}, error) {
if childCtx.Err() != nil {
return
}
v := function(childCtx)
v := function(childCtx, arg)
if childCtx.Err() != nil {
return
}

View File

@ -83,7 +83,7 @@ end @3 = G !fail H
`[1:],
},
} {
s.Bind(test.key, generate(s, test.key)).Get(ctx)
s.Bind(test.key, generate(s, test.key)).Get(ctx, nil)
got := logBuffer.String()
if got != test.want {
t.Errorf("at %q expected:\n%v\ngot:\n%s", test.name, test.want, got)
@ -103,7 +103,7 @@ end @3 = G !fail H
var pins []*memoize.Handle
for _, key := range pinned {
h := s.Bind(key, generate(s, key))
h.Get(ctx)
h.Get(ctx, nil)
pins = append(pins, h)
}
@ -175,7 +175,7 @@ func asValue(v interface{}) *stringOrError {
}
func generate(s *memoize.Store, key interface{}) memoize.Function {
return func(ctx context.Context) interface{} {
return func(ctx context.Context, _ memoize.Arg) interface{} {
name := key.(string)
switch name {
case "":
@ -217,7 +217,7 @@ func joinValues(ctx context.Context, s *memoize.Store, name string, keys ...stri
fmt.Fprintf(w, "start %v\n", name)
value := ""
for _, key := range keys {
i, err := s.Bind(key, generate(s, key)).Get(ctx)
i, err := s.Bind(key, generate(s, key)).Get(ctx, nil)
if err != nil {
return &stringOrError{err: err}
}