diff --git a/internal/lsp/cache/analysis.go b/internal/lsp/cache/analysis.go index 4b437858ef..e196d1c4a3 100644 --- a/internal/lsp/cache/analysis.go +++ b/internal/lsp/cache/analysis.go @@ -137,7 +137,7 @@ func (s *snapshot) actionHandle(ctx context.Context, id PackageID, a *analysis.A } } - handle, release := s.generation.GetHandle(buildActionKey(a, ph), func(ctx context.Context, arg memoize.Arg) interface{} { + handle, release := s.store.Handle(buildActionKey(a, ph), func(ctx context.Context, arg interface{}) interface{} { snapshot := arg.(*snapshot) // Analyze dependencies first. results, err := execAll(ctx, snapshot, deps) @@ -159,7 +159,7 @@ func (s *snapshot) actionHandle(ctx context.Context, id PackageID, a *analysis.A } func (act *actionHandle) analyze(ctx context.Context, snapshot *snapshot) ([]*source.Diagnostic, interface{}, error) { - d, err := act.handle.Get(ctx, snapshot.generation, snapshot) + d, err := snapshot.awaitHandle(ctx, act.handle) if err != nil { return nil, nil, err } @@ -189,7 +189,7 @@ func execAll(ctx context.Context, snapshot *snapshot, actions []*actionHandle) ( for _, act := range actions { act := act g.Go(func() error { - v, err := act.handle.Get(ctx, snapshot.generation, snapshot) + v, err := snapshot.awaitHandle(ctx, act.handle) if err != nil { return err } diff --git a/internal/lsp/cache/check.go b/internal/lsp/cache/check.go index c17288c9e1..4680c6e728 100644 --- a/internal/lsp/cache/check.go +++ b/internal/lsp/cache/check.go @@ -167,7 +167,7 @@ func (s *snapshot) buildPackageHandle(ctx context.Context, id PackageID, mode so // Create a handle for the result of type checking. experimentalKey := s.View().Options().ExperimentalPackageCacheKey key := computePackageKey(m.ID, compiledGoFiles, m, depKeys, mode, experimentalKey) - handle, release := s.generation.GetHandle(key, func(ctx context.Context, arg memoize.Arg) interface{} { + handle, release := s.store.Handle(key, func(ctx context.Context, arg interface{}) interface{} { // TODO(adonovan): eliminate use of arg with this handle. // (In all cases snapshot is equal to the enclosing s.) snapshot := arg.(*snapshot) @@ -286,7 +286,7 @@ func hashConfig(config *packages.Config) source.Hash { } func (ph *packageHandle) check(ctx context.Context, s *snapshot) (*pkg, error) { - v, err := ph.handle.Get(ctx, s.generation, s) + v, err := s.awaitHandle(ctx, ph.handle) if err != nil { return nil, err } @@ -302,8 +302,8 @@ func (ph *packageHandle) ID() string { return string(ph.m.ID) } -func (ph *packageHandle) cached(g *memoize.Generation) (*pkg, error) { - v := ph.handle.Cached(g) +func (ph *packageHandle) cached() (*pkg, error) { + v := ph.handle.Cached() if v == nil { return nil, fmt.Errorf("no cached type information for %s", ph.m.PkgPath) } diff --git a/internal/lsp/cache/imports.go b/internal/lsp/cache/imports.go index f333f700dd..710a1f3407 100644 --- a/internal/lsp/cache/imports.go +++ b/internal/lsp/cache/imports.go @@ -143,11 +143,12 @@ func (s *importsState) populateProcessEnv(ctx context.Context, snapshot *snapsho // Take an extra reference to the snapshot so that its workspace directory // (if any) isn't destroyed while we're using it. - release := snapshot.generation.Acquire() + release := snapshot.Acquire() _, inv, cleanupInvocation, err := snapshot.goCommandInvocation(ctx, source.LoadWorkspace, &gocommand.Invocation{ WorkingDir: snapshot.view.rootURI.Filename(), }) if err != nil { + release() return nil, err } pe.WorkingDir = inv.WorkingDir diff --git a/internal/lsp/cache/mod.go b/internal/lsp/cache/mod.go index 1963feea5a..79b3fd016d 100644 --- a/internal/lsp/cache/mod.go +++ b/internal/lsp/cache/mod.go @@ -39,7 +39,7 @@ func (s *snapshot) ParseMod(ctx context.Context, fh source.FileHandle) (*source. // cache miss? if !hit { - handle, release := s.generation.GetHandle(fh.FileIdentity(), func(ctx context.Context, _ memoize.Arg) interface{} { + handle, release := s.store.Handle(fh.FileIdentity(), func(ctx context.Context, _ interface{}) interface{} { parsed, err := parseModImpl(ctx, fh) return parseModResult{parsed, err} }) @@ -51,7 +51,7 @@ func (s *snapshot) ParseMod(ctx context.Context, fh source.FileHandle) (*source. } // Await result. - v, err := entry.(*memoize.Handle).Get(ctx, s.generation, s) + v, err := s.awaitHandle(ctx, entry.(*memoize.Handle)) if err != nil { return nil, err } @@ -116,7 +116,7 @@ func (s *snapshot) ParseWork(ctx context.Context, fh source.FileHandle) (*source // cache miss? if !hit { - handle, release := s.generation.GetHandle(fh.FileIdentity(), func(ctx context.Context, _ memoize.Arg) interface{} { + handle, release := s.store.Handle(fh.FileIdentity(), func(ctx context.Context, _ interface{}) interface{} { parsed, err := parseWorkImpl(ctx, fh) return parseWorkResult{parsed, err} }) @@ -128,7 +128,7 @@ func (s *snapshot) ParseWork(ctx context.Context, fh source.FileHandle) (*source } // Await result. - v, err := entry.(*memoize.Handle).Get(ctx, s.generation, s) + v, err := s.awaitHandle(ctx, entry.(*memoize.Handle)) if err != nil { return nil, err } @@ -241,7 +241,7 @@ func (s *snapshot) ModWhy(ctx context.Context, fh source.FileHandle) (map[string mod: fh.FileIdentity(), view: s.view.rootURI.Filename(), } - handle, release := s.generation.GetHandle(key, func(ctx context.Context, arg memoize.Arg) interface{} { + handle, release := s.store.Handle(key, func(ctx context.Context, arg interface{}) interface{} { why, err := modWhyImpl(ctx, arg.(*snapshot), fh) return modWhyResult{why, err} }) @@ -253,7 +253,7 @@ func (s *snapshot) ModWhy(ctx context.Context, fh source.FileHandle) (map[string } // Await result. - v, err := entry.(*memoize.Handle).Get(ctx, s.generation, s) + v, err := s.awaitHandle(ctx, entry.(*memoize.Handle)) if err != nil { return nil, err } diff --git a/internal/lsp/cache/mod_tidy.go b/internal/lsp/cache/mod_tidy.go index 84f369ef3d..b59b4fd883 100644 --- a/internal/lsp/cache/mod_tidy.go +++ b/internal/lsp/cache/mod_tidy.go @@ -29,9 +29,6 @@ import ( ) // modTidyImpl runs "go mod tidy" on a go.mod file, using a cache. -// -// REVIEWERS: what does it mean to cache an operation that has side effects? -// Or are we de-duplicating operations in flight on the same file? func (s *snapshot) ModTidy(ctx context.Context, pm *source.ParsedModule) (*source.TidiedModule, error) { uri := pm.URI if pm.File == nil { @@ -77,6 +74,8 @@ func (s *snapshot) ModTidy(ctx context.Context, pm *source.ParsedModule) (*sourc // There's little reason at to use the shared cache for mod // tidy (and mod why) as their key includes the view and session. + // Its only real value is to de-dup requests in flight, for + // which a singleflight in the View would suffice. // TODO(adonovan): use a simpler cache of promises that // is shared across snapshots. type modTidyKey struct { @@ -96,7 +95,7 @@ func (s *snapshot) ModTidy(ctx context.Context, pm *source.ParsedModule) (*sourc gomod: fh.FileIdentity(), env: hashEnv(s), } - handle, release := s.generation.GetHandle(key, func(ctx context.Context, arg memoize.Arg) interface{} { + handle, release := s.store.Handle(key, func(ctx context.Context, arg interface{}) interface{} { tidied, err := modTidyImpl(ctx, arg.(*snapshot), fh, pm, workspacePkgs) return modTidyResult{tidied, err} }) @@ -108,7 +107,7 @@ func (s *snapshot) ModTidy(ctx context.Context, pm *source.ParsedModule) (*sourc } // Await result. - v, err := entry.(*memoize.Handle).Get(ctx, s.generation, s) + v, err := s.awaitHandle(ctx, entry.(*memoize.Handle)) if err != nil { return nil, err } diff --git a/internal/lsp/cache/parse.go b/internal/lsp/cache/parse.go index c8c751f0b2..ef588c6059 100644 --- a/internal/lsp/cache/parse.go +++ b/internal/lsp/cache/parse.go @@ -59,7 +59,7 @@ func (s *snapshot) ParseGo(ctx context.Context, fh source.FileHandle, mode sourc // cache miss? if !hit { - handle, release := s.generation.GetHandle(key, func(ctx context.Context, arg memoize.Arg) interface{} { + handle, release := s.store.Handle(key, func(ctx context.Context, arg interface{}) interface{} { parsed, err := parseGoImpl(ctx, arg.(*snapshot).FileSet(), fh, mode) return parseGoResult{parsed, err} }) @@ -77,7 +77,7 @@ func (s *snapshot) ParseGo(ctx context.Context, fh source.FileHandle, mode sourc } // Await result. - v, err := entry.(*memoize.Handle).Get(ctx, s.generation, s) + v, err := s.awaitHandle(ctx, entry.(*memoize.Handle)) if err != nil { return nil, err } @@ -93,7 +93,7 @@ func (s *snapshot) peekParseGoLocked(fh source.FileHandle, mode source.ParseMode if !hit { return nil, nil // no-one has requested this file } - v := entry.(*memoize.Handle).Cached(s.generation) + v := entry.(*memoize.Handle).Cached() if v == nil { return nil, nil // parsing is still in progress } @@ -147,12 +147,12 @@ func (s *snapshot) astCacheData(ctx context.Context, spkg source.Package, pos to // the search Pos.) // // A representative benchmark would help. - astHandle, release := s.generation.GetHandle(astCacheKey{pkgHandle.key, pgf.URI}, func(ctx context.Context, arg memoize.Arg) interface{} { + astHandle, release := s.store.Handle(astCacheKey{pkgHandle.key, pgf.URI}, func(ctx context.Context, arg interface{}) interface{} { return buildASTCache(pgf) }) defer release() - d, err := astHandle.Get(ctx, s.generation, s) + d, err := s.awaitHandle(ctx, astHandle) if err != nil { return nil, err } diff --git a/internal/lsp/cache/session.go b/internal/lsp/cache/session.go index 9ea612a3ab..a46b7928c7 100644 --- a/internal/lsp/cache/session.go +++ b/internal/lsp/cache/session.go @@ -231,7 +231,7 @@ func (s *Session) createView(ctx context.Context, name string, folder span.URI, backgroundCtx: backgroundCtx, cancel: cancel, initializeOnce: &sync.Once{}, - generation: s.cache.store.Generation(generationName(v, 0)), + store: &s.cache.store, packages: newPackagesMap(), meta: &metadataGraph{}, files: newFilesMap(), @@ -254,12 +254,28 @@ func (s *Session) createView(ctx context.Context, name string, folder span.URI, initCtx, initCancel := context.WithCancel(xcontext.Detach(ctx)) v.initCancelFirstAttempt = initCancel snapshot := v.snapshot - release := snapshot.generation.Acquire() + + // Acquire both references before the possibility + // of releasing either one, to avoid premature + // destruction if initialize returns quickly. + // + // TODO(adonovan): our reference counting discipline is not sound: + // the count is initially zero and incremented/decremented by + // acquire/release, but there is a race between object birth + // and the first call to acquire during which the snapshot may be + // destroyed. + // + // In most systems, an object is born with a count of 1 and + // destroyed by any decref that brings the count to zero. + // We should do that too. + release1 := snapshot.Acquire() + release2 := snapshot.Acquire() go func() { - defer release() + defer release2() snapshot.initialize(initCtx, true) }() - return v, snapshot, snapshot.generation.Acquire(), nil + + return v, snapshot, release1, nil } // View returns the view by name. @@ -539,6 +555,8 @@ func (s *Session) ExpandModificationsToDirectories(ctx context.Context, changes defer release() snapshots = append(snapshots, snapshot) } + // TODO(adonovan): opt: release lock here. + knownDirs := knownDirectories(ctx, snapshots) defer knownDirs.Destroy() diff --git a/internal/lsp/cache/snapshot.go b/internal/lsp/cache/snapshot.go index c228db9655..fa71fbd8a8 100644 --- a/internal/lsp/cache/snapshot.go +++ b/internal/lsp/cache/snapshot.go @@ -14,6 +14,7 @@ import ( "go/types" "io" "io/ioutil" + "log" "os" "path/filepath" "regexp" @@ -22,6 +23,8 @@ import ( "strconv" "strings" "sync" + "sync/atomic" + "unsafe" "golang.org/x/mod/modfile" "golang.org/x/mod/module" @@ -42,16 +45,16 @@ import ( ) type snapshot struct { - memoize.Arg // allow as a memoize.Function arg - id uint64 view *View cancel func() backgroundCtx context.Context - // the cache generation that contains the data for this snapshot. - generation *memoize.Generation + store *memoize.Store // cache of handles shared by all snapshots + + refcount sync.WaitGroup // number of references + destroyedBy *string // atomically set to non-nil in Destroy once refcount = 0 // The snapshot's initialization state is controlled by the fields below. // @@ -148,6 +151,22 @@ type snapshot struct { unprocessedSubdirChanges []*fileChange } +var _ memoize.RefCounted = (*snapshot)(nil) // snapshots are reference-counted + +// Acquire prevents the snapshot from being destroyed until the returned function is called. +func (s *snapshot) Acquire() func() { + type uP = unsafe.Pointer + if destroyedBy := atomic.LoadPointer((*uP)(uP(&s.destroyedBy))); destroyedBy != nil { + log.Panicf("%d: acquire() after Destroy(%q)", s.id, *(*string)(destroyedBy)) + } + s.refcount.Add(1) + return s.refcount.Done +} + +func (s *snapshot) awaitHandle(ctx context.Context, h *memoize.Handle) (interface{}, error) { + return h.Get(ctx, s) +} + type packageKey struct { mode source.ParseMode id PackageID @@ -159,7 +178,16 @@ type actionKey struct { } func (s *snapshot) Destroy(destroyedBy string) { - s.generation.Destroy(destroyedBy) + // Wait for all leases to end before commencing destruction. + s.refcount.Wait() + + // Report bad state as a debugging aid. + // Not foolproof: another thread could acquire() at this moment. + type uP = unsafe.Pointer // looking forward to generics... + if old := atomic.SwapPointer((*uP)(uP(&s.destroyedBy)), uP(&destroyedBy)); old != nil { + log.Panicf("%d: Destroy(%q) after Destroy(%q)", s.id, destroyedBy, *(*string)(old)) + } + s.packages.Destroy() s.isActivePackageCache.Destroy() s.actions.Destroy() @@ -355,6 +383,7 @@ func (s *snapshot) RunGoCommands(ctx context.Context, allowNetwork bool, wd stri return true, modBytes, sumBytes, nil } +// TODO(adonovan): remove unused cleanup mechanism. func (s *snapshot) goCommandInvocation(ctx context.Context, flags source.InvocationFlags, inv *gocommand.Invocation) (tmpURI span.URI, updatedInv *gocommand.Invocation, cleanup func(), err error) { s.view.optionsMu.Lock() allowModfileModificationOption := s.view.options.AllowModfileModifications @@ -1092,7 +1121,7 @@ func (s *snapshot) CachedImportPaths(ctx context.Context) (map[string]source.Pac results := map[string]source.Package{} s.packages.Range(func(key packageKey, ph *packageHandle) { - cachedPkg, err := ph.cached(s.generation) + cachedPkg, err := ph.cached() if err != nil { return } @@ -1645,10 +1674,6 @@ func inVendor(uri span.URI) bool { return strings.Contains(split[1], "/") } -func generationName(v *View, snapshotID uint64) string { - return fmt.Sprintf("v%v/%v", v.id, snapshotID) -} - // unappliedChanges is a file source that handles an uncloned snapshot. type unappliedChanges struct { originalSnapshot *snapshot @@ -1675,11 +1700,10 @@ func (s *snapshot) clone(ctx, bgCtx context.Context, changes map[span.URI]*fileC s.mu.Lock() defer s.mu.Unlock() - newGen := s.view.session.cache.store.Generation(generationName(s.view, s.id+1)) bgCtx, cancel := context.WithCancel(bgCtx) result := &snapshot{ id: s.id + 1, - generation: newGen, + store: s.store, view: s.view, backgroundCtx: bgCtx, cancel: cancel, diff --git a/internal/lsp/cache/symbols.go b/internal/lsp/cache/symbols.go index 4cbf858902..b562d5bbdd 100644 --- a/internal/lsp/cache/symbols.go +++ b/internal/lsp/cache/symbols.go @@ -35,7 +35,7 @@ func (s *snapshot) symbolize(ctx context.Context, fh source.FileHandle) ([]sourc if !hit { type symbolHandleKey source.Hash key := symbolHandleKey(fh.FileIdentity().Hash) - handle, release := s.generation.GetHandle(key, func(_ context.Context, arg memoize.Arg) interface{} { + handle, release := s.store.Handle(key, func(_ context.Context, arg interface{}) interface{} { symbols, err := symbolizeImpl(arg.(*snapshot), fh) return symbolizeResult{symbols, err} }) @@ -48,7 +48,7 @@ func (s *snapshot) symbolize(ctx context.Context, fh source.FileHandle) ([]sourc } // Await result. - v, err := entry.(*memoize.Handle).Get(ctx, s.generation, s) + v, err := s.awaitHandle(ctx, entry.(*memoize.Handle)) if err != nil { return nil, err } diff --git a/internal/lsp/cache/view.go b/internal/lsp/cache/view.go index 1810f6e641..f95c475921 100644 --- a/internal/lsp/cache/view.go +++ b/internal/lsp/cache/view.go @@ -594,7 +594,7 @@ func (v *View) getSnapshot() (*snapshot, func()) { if v.snapshot == nil { panic("getSnapshot called after shutdown") } - return v.snapshot, v.snapshot.generation.Acquire() + return v.snapshot, v.snapshot.Acquire() } func (s *snapshot) initialize(ctx context.Context, firstAttempt bool) { @@ -734,7 +734,7 @@ func (v *View) invalidateContent(ctx context.Context, changes map[span.URI]*file v.snapshot = oldSnapshot.clone(ctx, v.baseCtx, changes, forceReloadMetadata) go oldSnapshot.Destroy("View.invalidateContent") - return v.snapshot, v.snapshot.generation.Acquire() + return v.snapshot, v.snapshot.Acquire() } func (s *Session) getWorkspaceInformation(ctx context.Context, folder span.URI, options *source.Options) (*workspaceInformation, error) { diff --git a/internal/lsp/command.go b/internal/lsp/command.go index 862af6088e..cd4c727310 100644 --- a/internal/lsp/command.go +++ b/internal/lsp/command.go @@ -691,7 +691,7 @@ func (c *commandHandler) GenerateGoplsMod(ctx context.Context, args command.URIA if err != nil { return fmt.Errorf("formatting mod file: %w", err) } - filename := filepath.Join(snapshot.View().Folder().Filename(), "gopls.mod") + filename := filepath.Join(v.Folder().Filename(), "gopls.mod") if err := ioutil.WriteFile(filename, content, 0644); err != nil { return fmt.Errorf("writing mod file: %w", err) } diff --git a/internal/lsp/general.go b/internal/lsp/general.go index 385a04a25f..06633acb0c 100644 --- a/internal/lsp/general.go +++ b/internal/lsp/general.go @@ -474,8 +474,7 @@ func (s *Server) beginFileRequest(ctx context.Context, pURI protocol.DocumentURI release() return nil, nil, false, func() {}, err } - kind := snapshot.View().FileKind(fh) - if expectKind != source.UnknownKind && kind != expectKind { + if expectKind != source.UnknownKind && view.FileKind(fh) != expectKind { // Wrong kind of file. Nothing to do. release() return nil, nil, false, func() {}, nil diff --git a/internal/lsp/source/view.go b/internal/lsp/source/view.go index caf1850585..d7e212a121 100644 --- a/internal/lsp/source/view.go +++ b/internal/lsp/source/view.go @@ -260,10 +260,14 @@ type View interface { // original one will be. SetOptions(context.Context, *Options) (View, error) - // Snapshot returns the current snapshot for the view. + // Snapshot returns the current snapshot for the view, and a + // release function that must be called when the Snapshot is + // no longer needed. Snapshot(ctx context.Context) (Snapshot, func()) - // Rebuild rebuilds the current view, replacing the original view in its session. + // Rebuild rebuilds the current view, replacing the original + // view in its session. It returns a Snapshot and a release + // function that must be called when the Snapshot is no longer needed. Rebuild(ctx context.Context) (Snapshot, func(), error) // IsGoPrivatePath reports whether target is a private import path, as identified @@ -348,7 +352,8 @@ type Session interface { // NewView creates a new View, returning it and its first snapshot. If a // non-empty tempWorkspace directory is provided, the View will record a copy // of its gopls workspace module in that directory, so that client tooling - // can execute in the same main module. + // can execute in the same main module. It returns a release + // function that must be called when the Snapshot is no longer needed. NewView(ctx context.Context, name string, folder span.URI, options *Options) (View, Snapshot, func(), error) // Cache returns the cache that created this session, for debugging only. @@ -372,6 +377,8 @@ type Session interface { // DidModifyFile reports a file modification to the session. It returns // the new snapshots after the modifications have been applied, paired with // the affected file URIs for those snapshots. + // On success, it returns a list of release functions that + // must be called when the snapshots are no longer needed. DidModifyFiles(ctx context.Context, changes []FileModification) (map[Snapshot][]span.URI, []func(), error) // ExpandModificationsToDirectories returns the set of changes with the diff --git a/internal/memoize/memoize.go b/internal/memoize/memoize.go index 4b84410d50..2db7945428 100644 --- a/internal/memoize/memoize.go +++ b/internal/memoize/memoize.go @@ -5,13 +5,14 @@ // Package memoize supports memoizing the return values of functions with // idempotent results that are expensive to compute. // -// To use this package, build a store and use it to acquire handles with the -// Bind method. +// To use this package, create a Store, call its Handle method to +// acquire a handle to (aka a "promise" of) the future result of a +// function, and call Handle.Get to obtain the result. Get may block +// if the function has not finished (or started). package memoize import ( "context" - "flag" "fmt" "reflect" "runtime/trace" @@ -21,107 +22,44 @@ import ( "golang.org/x/tools/internal/xcontext" ) -var ( - panicOnDestroyed = flag.Bool("memoize_panic_on_destroyed", false, - "Panic when a destroyed generation is read rather than returning an error. "+ - "Panicking may make it easier to debug lifetime errors, especially when "+ - "used with GOTRACEBACK=crash to see all running goroutines.") -) - // Store binds keys to functions, returning handles that can be used to access // the functions results. type Store struct { handlesMu sync.Mutex // lock ordering: Store.handlesMu before Handle.mu handles map[interface{}]*Handle - // handles which are bound to generations for GC purposes. - // (It is the subset of values of 'handles' with trackGenerations enabled.) - boundHandles map[*Handle]struct{} } -// Generation creates a new Generation associated with s. Destroy must be -// called on the returned Generation once it is no longer in use. name is -// for debugging purposes only. -func (s *Store) Generation(name string) *Generation { - return &Generation{store: s, name: name} +// A RefCounted is a value whose functional lifetime is determined by +// reference counting. +// +// Its Acquire method is called before the Function is invoked, and +// the corresponding release is called when the Function returns. +// Usually both events happen within a single call to Get, so Get +// would be fine with a "borrowed" reference, but if the context is +// cancelled, Get may return before the Function is complete, causing +// the argument to escape, and potential premature destruction of the +// value. For a reference-counted type, this requires a pair of +// increment/decrement operations to extend its life. +type RefCounted interface { + // Acquire prevents the value from being destroyed until the + // returned function is called. + Acquire() func() } -// A Generation is a logical point in time of the cache life-cycle. Cache -// entries associated with a Generation will not be removed until the -// Generation is destroyed. -type Generation struct { - // destroyed is 1 after the generation is destroyed. Atomic. - destroyed uint32 - store *Store - name string - // destroyedBy describes the caller that togged destroyed from 0 to 1. - destroyedBy string - // wg tracks the reference count of this generation. - wg sync.WaitGroup -} - -// Destroy waits for all operations referencing g to complete, then removes -// all references to g from cache entries. Cache entries that no longer -// reference any non-destroyed generation are removed. Destroy must be called -// exactly once for each generation, and destroyedBy describes the caller. -func (g *Generation) Destroy(destroyedBy string) { - g.wg.Wait() - - prevDestroyedBy := g.destroyedBy - g.destroyedBy = destroyedBy - if ok := atomic.CompareAndSwapUint32(&g.destroyed, 0, 1); !ok { - panic("Destroy on generation " + g.name + " already destroyed by " + prevDestroyedBy) - } - - g.store.handlesMu.Lock() - defer g.store.handlesMu.Unlock() - for h := range g.store.boundHandles { - h.mu.Lock() - if _, ok := h.generations[g]; ok { - delete(h.generations, g) // delete even if it's dead, in case of dangling references to the entry. - if len(h.generations) == 0 { - h.state = stateDestroyed - delete(g.store.handles, h.key) - if h.trackGenerations { - delete(g.store.boundHandles, h) - } - } - } - h.mu.Unlock() - } -} - -// Acquire creates a new reference to g, and returns a func to release that -// reference. -func (g *Generation) Acquire() func() { - destroyed := atomic.LoadUint32(&g.destroyed) - if destroyed != 0 { - panic("acquire on generation " + g.name + " destroyed by " + g.destroyedBy) - } - g.wg.Add(1) - return g.wg.Done -} - -// Arg is a marker interface that can be embedded to indicate a type is -// intended for use as a Function argument. -type Arg interface{ memoizeArg() } - // Function is the type for functions that can be memoized. -// The result must be a pointer. -type Function func(ctx context.Context, arg Arg) interface{} +// +// If the arg is a RefCounted, its Acquire/Release operations are called. +type Function func(ctx context.Context, arg interface{}) interface{} type state int -// TODO(rfindley): remove stateDestroyed; Handles should not need to know -// whether or not they have been destroyed. -// -// TODO(rfindley): also consider removing stateIdle. Why create a handle if you +// TODO(rfindley): consider removing stateIdle. Why create a handle if you // aren't certain you're going to need its result? And if you know you need its // result, why wait to begin computing it? const ( stateIdle = iota stateRunning stateCompleted - stateDestroyed ) // Handle is returned from a store when a key is bound to a function. @@ -136,19 +74,10 @@ const ( // they decrement waiters. If it drops to zero, the inner context is cancelled, // computation is abandoned, and state resets to idle to start the process over // again. -// -// Handles may be tracked by generations, or directly reference counted, as -// determined by the trackGenerations field. See the field comments for more -// information about the differences between these two forms. -// -// TODO(rfindley): eliminate generational handles. type Handle struct { key interface{} mu sync.Mutex // lock ordering: Store.handlesMu before Handle.mu - // generations is the set of generations in which this handle is valid. - generations map[*Generation]struct{} - state state // done is set in running state, and closed when exiting it. done chan struct{} @@ -161,90 +90,51 @@ type Handle struct { // value is set in completed state. value interface{} - // If trackGenerations is set, this handle tracks generations in which it - // is valid, via the generations field. Otherwise, it is explicitly reference - // counted via the refCounter field. - trackGenerations bool - refCounter int32 + refcount int32 // accessed using atomic load/store } -// Bind returns a "generational" handle for the given key and function. +// Handle returns a reference-counted handle for the future result of +// calling the specified function. Calls to Handle with the same key +// return the same handle, and all calls to Handle.Get on a given +// handle return the same result but the function is called at most once. // -// Each call to bind will return the same handle if it is already bound. Bind -// will always return a valid handle, creating one if needed. Each key can -// only have one handle at any given time. The value will be held at least -// until the associated generation is destroyed. Bind does not cause the value -// to be generated. -// -// It is responsibility of the caller to call Inherit on the handler whenever -// it should still be accessible by a next generation. -func (g *Generation) Bind(key interface{}, function Function) *Handle { - return g.getHandle(key, function, true) -} +// The caller must call the returned function to decrement the +// handle's reference count when it is no longer needed. +func (store *Store) Handle(key interface{}, function Function) (*Handle, func()) { + if function == nil { + panic("nil function") + } + + store.handlesMu.Lock() + h, ok := store.handles[key] + if !ok { + // new handle + h = &Handle{ + key: key, + function: function, + refcount: 1, + } + + if store.handles == nil { + store.handles = map[interface{}]*Handle{} + } + store.handles[key] = h + } else { + // existing handle + atomic.AddInt32(&h.refcount, 1) + } + store.handlesMu.Unlock() -// GetHandle returns a "reference-counted" handle for the given key -// and function with similar properties and behavior as Bind. Unlike -// Bind, it returns a release callback which must be called once the -// handle is no longer needed. -func (g *Generation) GetHandle(key interface{}, function Function) (*Handle, func()) { - h := g.getHandle(key, function, false) - store := g.store release := func() { - // Acquire store.handlesMu before mutating refCounter - store.handlesMu.Lock() - defer store.handlesMu.Unlock() - - h.mu.Lock() - defer h.mu.Unlock() - - h.refCounter-- - if h.refCounter == 0 { - // Don't mark destroyed: for reference counted handles we can't know when - // they are no longer reachable from runnable goroutines. For example, - // gopls could have a current operation that is using a packageHandle. - // Destroying the handle here would cause that operation to hang. + if atomic.AddInt32(&h.refcount, -1) == 0 { + store.handlesMu.Lock() delete(store.handles, h.key) + store.handlesMu.Unlock() } } return h, release } -func (g *Generation) getHandle(key interface{}, function Function, trackGenerations bool) *Handle { - // panic early if the function is nil - // it would panic later anyway, but in a way that was much harder to debug - if function == nil { - panic("the function passed to bind must not be nil") - } - if atomic.LoadUint32(&g.destroyed) != 0 { - panic("operation on generation " + g.name + " destroyed by " + g.destroyedBy) - } - g.store.handlesMu.Lock() - defer g.store.handlesMu.Unlock() - h, ok := g.store.handles[key] - if !ok { - h = &Handle{ - key: key, - function: function, - trackGenerations: trackGenerations, - } - if trackGenerations { - if g.store.boundHandles == nil { - g.store.boundHandles = map[*Handle]struct{}{} - } - h.generations = make(map[*Generation]struct{}, 1) - g.store.boundHandles[h] = struct{}{} - } - - if g.store.handles == nil { - g.store.handles = map[interface{}]*Handle{} - } - g.store.handles[key] = h - } - - h.incrementRef(g) - return h -} - // Stats returns the number of each type of value in the store. func (s *Store) Stats() map[reflect.Type]int { result := map[reflect.Type]int{} @@ -278,53 +168,13 @@ func (s *Store) DebugOnlyIterate(f func(k, v interface{})) { } } -// Inherit makes h valid in generation g. It is concurrency-safe. -func (g *Generation) Inherit(h *Handle) { - if atomic.LoadUint32(&g.destroyed) != 0 { - panic("inherit on generation " + g.name + " destroyed by " + g.destroyedBy) - } - if !h.trackGenerations { - panic("called Inherit on handle not created by Generation.Bind") - } - - h.incrementRef(g) -} - -func (h *Handle) incrementRef(g *Generation) { - h.mu.Lock() - defer h.mu.Unlock() - - if h.state == stateDestroyed { - panic(fmt.Sprintf("inheriting destroyed handle %#v (type %T) into generation %v", h.key, h.key, g.name)) - } - - if h.trackGenerations { - h.generations[g] = struct{}{} - } else { - h.refCounter++ - } -} - -// hasRefLocked reports whether h is valid in generation g. h.mu must be held. -func (h *Handle) hasRefLocked(g *Generation) bool { - if !h.trackGenerations { - return true - } - - _, ok := h.generations[g] - return ok -} - // Cached returns the value associated with a handle. // // It will never cause the value to be generated. // It will return the cached value, if present. -func (h *Handle) Cached(g *Generation) interface{} { +func (h *Handle) Cached() interface{} { h.mu.Lock() defer h.mu.Unlock() - if !h.hasRefLocked(g) { - return nil - } if h.state == stateCompleted { return h.value } @@ -334,54 +184,39 @@ func (h *Handle) Cached(g *Generation) interface{} { // Get returns the value associated with a handle. // // If the value is not yet ready, the underlying function will be invoked. -// If ctx is cancelled, Get returns nil. -func (h *Handle) Get(ctx context.Context, g *Generation, arg Arg) (interface{}, error) { - release := g.Acquire() - defer release() - +// If ctx is cancelled, Get returns (nil, Canceled). +func (h *Handle) Get(ctx context.Context, arg interface{}) (interface{}, error) { if ctx.Err() != nil { return nil, ctx.Err() } h.mu.Lock() - if !h.hasRefLocked(g) { - h.mu.Unlock() - - err := fmt.Errorf("reading key %#v: generation %v is not known", h.key, g.name) - if *panicOnDestroyed && ctx.Err() != nil { - panic(err) - } - return nil, err - } switch h.state { case stateIdle: - return h.run(ctx, g, arg) + return h.run(ctx, arg) case stateRunning: return h.wait(ctx) case stateCompleted: defer h.mu.Unlock() return h.value, nil - case stateDestroyed: - h.mu.Unlock() - err := fmt.Errorf("Get on destroyed entry %#v (type %T) in generation %v", h.key, h.key, g.name) - if *panicOnDestroyed { - panic(err) - } - return nil, err default: panic("unknown state") } } // run starts h.function and returns the result. h.mu must be locked. -func (h *Handle) run(ctx context.Context, g *Generation, arg Arg) (interface{}, error) { +func (h *Handle) run(ctx context.Context, arg interface{}) (interface{}, error) { childCtx, cancel := context.WithCancel(xcontext.Detach(ctx)) h.cancel = cancel h.state = stateRunning h.done = make(chan struct{}) function := h.function // Read under the lock - // Make sure that the generation isn't destroyed while we're running in it. - release := g.Acquire() + // Make sure that the argument isn't destroyed while we're running in it. + release := func() {} + if rc, ok := arg.(RefCounted); ok { + release = rc.Acquire() + } + go func() { trace.WithRegion(childCtx, fmt.Sprintf("Handle.run %T", h.key), func() { defer release() diff --git a/internal/memoize/memoize_test.go b/internal/memoize/memoize_test.go index 48bb181173..bde02bf613 100644 --- a/internal/memoize/memoize_test.go +++ b/internal/memoize/memoize_test.go @@ -6,7 +6,6 @@ package memoize_test import ( "context" - "strings" "sync" "testing" "time" @@ -15,90 +14,53 @@ import ( ) func TestGet(t *testing.T) { - s := &memoize.Store{} - g := s.Generation("x") + var store memoize.Store evaled := 0 - h := g.Bind("key", func(context.Context, memoize.Arg) interface{} { + h, release := store.Handle("key", func(context.Context, interface{}) interface{} { evaled++ return "res" }) - expectGet(t, h, g, "res") - expectGet(t, h, g, "res") + defer release() + expectGet(t, h, "res") + expectGet(t, h, "res") if evaled != 1 { t.Errorf("got %v calls to function, wanted 1", evaled) } } -func expectGet(t *testing.T, h *memoize.Handle, g *memoize.Generation, wantV interface{}) { +func expectGet(t *testing.T, h *memoize.Handle, wantV interface{}) { t.Helper() - gotV, gotErr := h.Get(context.Background(), g, nil) + gotV, gotErr := h.Get(context.Background(), nil) if gotV != wantV || gotErr != nil { t.Fatalf("Get() = %v, %v, wanted %v, nil", gotV, gotErr, wantV) } } -func expectGetError(t *testing.T, h *memoize.Handle, g *memoize.Generation, substr string) { - gotV, gotErr := h.Get(context.Background(), g, nil) - if gotErr == nil || !strings.Contains(gotErr.Error(), substr) { - t.Fatalf("Get() = %v, %v, wanted err %q", gotV, gotErr, substr) - } -} - -func TestGenerations(t *testing.T) { - s := &memoize.Store{} - // Evaluate key in g1. - g1 := s.Generation("g1") - h1 := g1.Bind("key", func(context.Context, memoize.Arg) interface{} { return "res" }) - expectGet(t, h1, g1, "res") - - // Get key in g2. It should inherit the value from g1. - g2 := s.Generation("g2") - h2 := g2.Bind("key", func(context.Context, memoize.Arg) interface{} { - t.Fatal("h2 should not need evaluation") - return "error" - }) - expectGet(t, h2, g2, "res") - - // With g1 destroyed, g2 should still work. - g1.Destroy("TestGenerations") - expectGet(t, h2, g2, "res") - - // With all generations destroyed, key should be re-evaluated. - g2.Destroy("TestGenerations") - g3 := s.Generation("g3") - h3 := g3.Bind("key", func(context.Context, memoize.Arg) interface{} { return "new res" }) - expectGet(t, h3, g3, "new res") -} - func TestHandleRefCounting(t *testing.T) { - s := &memoize.Store{} - g1 := s.Generation("g1") + var store memoize.Store v1 := false v2 := false - h1, release1 := g1.GetHandle("key1", func(context.Context, memoize.Arg) interface{} { + h1, release1 := store.Handle("key1", func(context.Context, interface{}) interface{} { return &v1 }) - h2, release2 := g1.GetHandle("key2", func(context.Context, memoize.Arg) interface{} { + h2, release2 := store.Handle("key2", func(context.Context, interface{}) interface{} { return &v2 }) - expectGet(t, h1, g1, &v1) - expectGet(t, h2, g1, &v2) + expectGet(t, h1, &v1) + expectGet(t, h2, &v2) - g2 := s.Generation("g2") - expectGet(t, h1, g2, &v1) - g1.Destroy("by test") - expectGet(t, h2, g2, &v2) + expectGet(t, h1, &v1) + expectGet(t, h2, &v2) - h2Copy, release2Copy := g2.GetHandle("key2", func(context.Context, memoize.Arg) interface{} { + h2Copy, release2Copy := store.Handle("key2", func(context.Context, interface{}) interface{} { return &v1 }) if h2 != h2Copy { t.Error("NewHandle returned a new value while old is not destroyed yet") } - expectGet(t, h2Copy, g2, &v2) - g2.Destroy("by test") + expectGet(t, h2Copy, &v2) release2() if got, want := v2, false; got != want { @@ -110,27 +72,23 @@ func TestHandleRefCounting(t *testing.T) { } release1() - g3 := s.Generation("g3") - h2Copy, release2Copy = g3.GetHandle("key2", func(context.Context, memoize.Arg) interface{} { + h2Copy, release2Copy = store.Handle("key2", func(context.Context, interface{}) interface{} { return &v2 }) if h2 == h2Copy { t.Error("NewHandle returned previously destroyed value") } release2Copy() - g3.Destroy("by test") } func TestHandleDestroyedWhileRunning(t *testing.T) { - // Test that calls to Handle.Get return even if the handle is destroyed while - // running. + // Test that calls to Handle.Get return even if the handle is destroyed while running. - s := &memoize.Store{} - g := s.Generation("g") + var store memoize.Store c := make(chan int) var v int - h, release := g.GetHandle("key", func(ctx context.Context, _ memoize.Arg) interface{} { + h, release := store.Handle("key", func(ctx context.Context, _ interface{}) interface{} { <-c <-c if err := ctx.Err(); err != nil { @@ -147,7 +105,7 @@ func TestHandleDestroyedWhileRunning(t *testing.T) { var got interface{} var err error go func() { - got, err = h.Get(ctx, g, nil) + got, err = h.Get(ctx, nil) wg.Done() }()