diff --git a/internal/lsp/cache/analysis.go b/internal/lsp/cache/analysis.go index b277324a0a..ee80bbcd52 100644 --- a/internal/lsp/cache/analysis.go +++ b/internal/lsp/cache/analysis.go @@ -70,7 +70,7 @@ type actionHandleKey source.Hash // package (as different analyzers are applied, either in sequence or // parallel), and across packages (as dependencies are analyzed). type actionHandle struct { - handle *memoize.Handle + promise *memoize.Promise analyzer *analysis.Analyzer pkg *pkg @@ -155,7 +155,7 @@ func (s *snapshot) actionHandle(ctx context.Context, id PackageID, a *analysis.A } } - handle, release := s.store.Handle(buildActionKey(a, ph), func(ctx context.Context, arg interface{}) interface{} { + promise, release := s.store.Promise(buildActionKey(a, ph), func(ctx context.Context, arg interface{}) interface{} { snapshot := arg.(*snapshot) // Analyze dependencies first. results, err := execAll(ctx, snapshot, deps) @@ -170,7 +170,7 @@ func (s *snapshot) actionHandle(ctx context.Context, id PackageID, a *analysis.A ah := &actionHandle{ analyzer: a, pkg: pkg, - handle: handle, + promise: promise, } s.mu.Lock() @@ -188,7 +188,7 @@ func (s *snapshot) actionHandle(ctx context.Context, id PackageID, a *analysis.A } func (act *actionHandle) analyze(ctx context.Context, snapshot *snapshot) ([]*source.Diagnostic, interface{}, error) { - d, err := snapshot.awaitHandle(ctx, act.handle) + d, err := snapshot.awaitPromise(ctx, act.promise) if err != nil { return nil, nil, err } @@ -218,7 +218,7 @@ func execAll(ctx context.Context, snapshot *snapshot, actions []*actionHandle) ( for _, act := range actions { act := act g.Go(func() error { - v, err := snapshot.awaitHandle(ctx, act.handle) + v, err := snapshot.awaitPromise(ctx, act.promise) if err != nil { return err } diff --git a/internal/lsp/cache/check.go b/internal/lsp/cache/check.go index 79a6ff3eeb..abc1724572 100644 --- a/internal/lsp/cache/check.go +++ b/internal/lsp/cache/check.go @@ -44,7 +44,7 @@ type 
packageHandleKey source.Hash // A packageHandle is a handle to the future result of type-checking a package. // The resulting package is obtained from the check() method. type packageHandle struct { - handle *memoize.Handle // [typeCheckResult] + promise *memoize.Promise // [typeCheckResult] // m is the metadata associated with the package. m *KnownMetadata @@ -141,7 +141,7 @@ func (s *snapshot) buildPackageHandle(ctx context.Context, id PackageID, mode so phKey := computePackageKey(m.ID, compiledGoFiles, m, depKeys, mode, experimentalKey) // TODO(adonovan): extract lambda into a standalone function to // avoid implicit lexical dependencies. - handle, release := s.store.Handle(phKey, func(ctx context.Context, arg interface{}) interface{} { + promise, release := s.store.Promise(phKey, func(ctx context.Context, arg interface{}) interface{} { snapshot := arg.(*snapshot) // Start type checking of direct dependencies, @@ -169,9 +169,9 @@ func (s *snapshot) buildPackageHandle(ctx context.Context, id PackageID, mode so }) ph := &packageHandle{ - handle: handle, - m: m, - key: phKey, + promise: promise, + m: m, + key: phKey, } s.mu.Lock() @@ -289,7 +289,7 @@ func hashConfig(config *packages.Config) source.Hash { } func (ph *packageHandle) check(ctx context.Context, s *snapshot) (*pkg, error) { - v, err := s.awaitHandle(ctx, ph.handle) + v, err := s.awaitPromise(ctx, ph.promise) if err != nil { return nil, err } @@ -306,7 +306,7 @@ func (ph *packageHandle) ID() string { } func (ph *packageHandle) cached() (*pkg, error) { - v := ph.handle.Cached() + v := ph.promise.Cached() if v == nil { return nil, fmt.Errorf("no cached type information for %s", ph.m.PkgPath) } diff --git a/internal/lsp/cache/mod.go b/internal/lsp/cache/mod.go index f9d148b737..57fa1e2d0a 100644 --- a/internal/lsp/cache/mod.go +++ b/internal/lsp/cache/mod.go @@ -39,19 +39,19 @@ func (s *snapshot) ParseMod(ctx context.Context, fh source.FileHandle) (*source. // cache miss? 
if !hit { - handle, release := s.store.Handle(fh.FileIdentity(), func(ctx context.Context, _ interface{}) interface{} { + promise, release := s.store.Promise(fh.FileIdentity(), func(ctx context.Context, _ interface{}) interface{} { parsed, err := parseModImpl(ctx, fh) return parseModResult{parsed, err} }) - entry = handle + entry = promise s.mu.Lock() s.parseModHandles.Set(uri, entry, func(_, _ interface{}) { release() }) s.mu.Unlock() } // Await result. - v, err := s.awaitHandle(ctx, entry.(*memoize.Handle)) + v, err := s.awaitPromise(ctx, entry.(*memoize.Promise)) if err != nil { return nil, err } @@ -116,7 +116,7 @@ func (s *snapshot) ParseWork(ctx context.Context, fh source.FileHandle) (*source // cache miss? if !hit { - handle, release := s.store.Handle(fh.FileIdentity(), func(ctx context.Context, _ interface{}) interface{} { + handle, release := s.store.Promise(fh.FileIdentity(), func(ctx context.Context, _ interface{}) interface{} { parsed, err := parseWorkImpl(ctx, fh) return parseWorkResult{parsed, err} }) @@ -128,7 +128,7 @@ func (s *snapshot) ParseWork(ctx context.Context, fh source.FileHandle) (*source } // Await result. - v, err := s.awaitHandle(ctx, entry.(*memoize.Handle)) + v, err := s.awaitPromise(ctx, entry.(*memoize.Promise)) if err != nil { return nil, err } @@ -223,7 +223,7 @@ func (s *snapshot) ModWhy(ctx context.Context, fh source.FileHandle) (map[string // cache miss? if !hit { - handle := memoize.NewHandle("modWhy", func(ctx context.Context, arg interface{}) interface{} { + handle := memoize.NewPromise("modWhy", func(ctx context.Context, arg interface{}) interface{} { why, err := modWhyImpl(ctx, arg.(*snapshot), fh) return modWhyResult{why, err} }) @@ -235,7 +235,7 @@ func (s *snapshot) ModWhy(ctx context.Context, fh source.FileHandle) (map[string } // Await result. 
- v, err := s.awaitHandle(ctx, entry.(*memoize.Handle)) + v, err := s.awaitPromise(ctx, entry.(*memoize.Promise)) if err != nil { return nil, err } diff --git a/internal/lsp/cache/mod_tidy.go b/internal/lsp/cache/mod_tidy.go index a04bacf8ee..361f526ddf 100644 --- a/internal/lsp/cache/mod_tidy.go +++ b/internal/lsp/cache/mod_tidy.go @@ -69,7 +69,7 @@ func (s *snapshot) ModTidy(ctx context.Context, pm *source.ParsedModule) (*sourc return nil, err } - handle := memoize.NewHandle("modTidy", func(ctx context.Context, arg interface{}) interface{} { + handle := memoize.NewPromise("modTidy", func(ctx context.Context, arg interface{}) interface{} { tidied, err := modTidyImpl(ctx, arg.(*snapshot), uri.Filename(), pm) return modTidyResult{tidied, err} @@ -82,7 +82,7 @@ func (s *snapshot) ModTidy(ctx context.Context, pm *source.ParsedModule) (*sourc } // Await result. - v, err := s.awaitHandle(ctx, entry.(*memoize.Handle)) + v, err := s.awaitPromise(ctx, entry.(*memoize.Promise)) if err != nil { return nil, err } diff --git a/internal/lsp/cache/parse.go b/internal/lsp/cache/parse.go index 62aea2229b..77e893a668 100644 --- a/internal/lsp/cache/parse.go +++ b/internal/lsp/cache/parse.go @@ -58,7 +58,7 @@ func (s *snapshot) ParseGo(ctx context.Context, fh source.FileHandle, mode sourc // cache miss? if !hit { - handle, release := s.store.Handle(key, func(ctx context.Context, arg interface{}) interface{} { + handle, release := s.store.Promise(key, func(ctx context.Context, arg interface{}) interface{} { parsed, err := parseGoImpl(ctx, arg.(*snapshot).FileSet(), fh, mode) return parseGoResult{parsed, err} }) @@ -76,7 +76,7 @@ func (s *snapshot) ParseGo(ctx context.Context, fh source.FileHandle, mode sourc } // Await result. 
- v, err := s.awaitHandle(ctx, entry.(*memoize.Handle)) + v, err := s.awaitPromise(ctx, entry.(*memoize.Promise)) if err != nil { return nil, err } @@ -92,7 +92,7 @@ func (s *snapshot) peekParseGoLocked(fh source.FileHandle, mode source.ParseMode if !hit { return nil, nil // no-one has requested this file } - v := entry.(*memoize.Handle).Cached() + v := entry.(*memoize.Promise).Cached() if v == nil { return nil, nil // parsing is still in progress } diff --git a/internal/lsp/cache/snapshot.go b/internal/lsp/cache/snapshot.go index 9f26b1e59c..9e52cda5af 100644 --- a/internal/lsp/cache/snapshot.go +++ b/internal/lsp/cache/snapshot.go @@ -84,7 +84,7 @@ type snapshot struct { files filesMap // parsedGoFiles maps a parseKey to the handle of the future result of parsing it. - parsedGoFiles *persistent.Map // from parseKey to *memoize.Handle[parseGoResult] + parsedGoFiles *persistent.Map // from parseKey to *memoize.Promise[parseGoResult] // parseKeysByURI records the set of keys of parsedGoFiles that // need to be invalidated for each URI. @@ -94,7 +94,7 @@ type snapshot struct { // symbolizeHandles maps each file URI to a handle for the future // result of computing the symbols declared in that file. - symbolizeHandles *persistent.Map // from span.URI to *memoize.Handle[symbolizeResult] + symbolizeHandles *persistent.Map // from span.URI to *memoize.Promise[symbolizeResult] // packages maps a packageKey to a *packageHandle. // It may be invalidated when a file's content changes. @@ -103,7 +103,7 @@ type snapshot struct { // - packages.Get(id).m.Metadata == meta.metadata[id].Metadata for all ids // - if a package is in packages, then all of its dependencies should also // be in packages, unless there is a missing import - packages *persistent.Map // from packageKey to *memoize.Handle[*packageHandle] + packages *persistent.Map // from packageKey to *memoize.Promise[*packageHandle] // isActivePackageCache maps package ID to the cached value if it is active or not. 
// It may be invalidated when metadata changes or a new file is opened or closed. @@ -122,17 +122,17 @@ type snapshot struct { // parseModHandles keeps track of any parseModHandles for the snapshot. // The handles need not refer to only the view's go.mod file. - parseModHandles *persistent.Map // from span.URI to *memoize.Handle[parseModResult] + parseModHandles *persistent.Map // from span.URI to *memoize.Promise[parseModResult] // parseWorkHandles keeps track of any parseWorkHandles for the snapshot. // The handles need not refer to only the view's go.work file. - parseWorkHandles *persistent.Map // from span.URI to *memoize.Handle[parseWorkResult] + parseWorkHandles *persistent.Map // from span.URI to *memoize.Promise[parseWorkResult] // Preserve go.mod-related handles to avoid garbage-collecting the results // of various calls to the go command. The handles need not refer to only // the view's go.mod file. - modTidyHandles *persistent.Map // from span.URI to *memoize.Handle[modTidyResult] - modWhyHandles *persistent.Map // from span.URI to *memoize.Handle[modWhyResult] + modTidyHandles *persistent.Map // from span.URI to *memoize.Promise[modTidyResult] + modWhyHandles *persistent.Map // from span.URI to *memoize.Promise[modWhyResult] workspace *workspace // (not guarded by mu) @@ -170,8 +170,8 @@ func (s *snapshot) Acquire() func() { return s.refcount.Done } -func (s *snapshot) awaitHandle(ctx context.Context, h *memoize.Handle) (interface{}, error) { - return h.Get(ctx, s) +func (s *snapshot) awaitPromise(ctx context.Context, p *memoize.Promise) (interface{}, error) { + return p.Get(ctx, s) } // destroy waits for all leases on the snapshot to expire then releases diff --git a/internal/lsp/cache/symbols.go b/internal/lsp/cache/symbols.go index b562d5bbdd..e98f554969 100644 --- a/internal/lsp/cache/symbols.go +++ b/internal/lsp/cache/symbols.go @@ -35,12 +35,12 @@ func (s *snapshot) symbolize(ctx context.Context, fh source.FileHandle) ([]sourc if !hit { type 
symbolHandleKey source.Hash key := symbolHandleKey(fh.FileIdentity().Hash) - handle, release := s.store.Handle(key, func(_ context.Context, arg interface{}) interface{} { + promise, release := s.store.Promise(key, func(_ context.Context, arg interface{}) interface{} { symbols, err := symbolizeImpl(arg.(*snapshot), fh) return symbolizeResult{symbols, err} }) - entry = handle + entry = promise s.mu.Lock() s.symbolizeHandles.Set(uri, entry, func(_, _ interface{}) { release() }) @@ -48,7 +48,7 @@ func (s *snapshot) symbolize(ctx context.Context, fh source.FileHandle) ([]sourc } // Await result. - v, err := s.awaitHandle(ctx, entry.(*memoize.Handle)) + v, err := s.awaitPromise(ctx, entry.(*memoize.Promise)) if err != nil { return nil, err } diff --git a/internal/memoize/memoize.go b/internal/memoize/memoize.go index 8c921c7e16..aa4d58d2f2 100644 --- a/internal/memoize/memoize.go +++ b/internal/memoize/memoize.go @@ -2,13 +2,20 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// Package memoize supports memoizing the return values of functions with -// idempotent results that are expensive to compute. +// Package memoize defines a "promise" abstraction that enables +// memoization of the result of calling an expensive but idempotent +// function. // -// To use this package, create a Store, call its Handle method to -// acquire a handle to (aka a "promise" of) the future result of a -// function, and call Handle.Get to obtain the result. Get may block -// if the function has not finished (or started). +// Call p = NewPromise(f) to obtain a promise for the future result of +// calling f(), and call p.Get() to obtain that result. All calls to +// p.Get return the result of a single call of f(). +// Get blocks if the function has not finished (or started). +// +// A Store is a map of arbitrary keys to promises. Use Store.Promise +// to create a promise in the store. 
All calls to Promise(k) return the +// same promise as long as it is in the store. These promises are +// reference-counted and must be explicitly released. Once the last +// reference is released, the promise is removed from the store. package memoize import ( @@ -22,22 +29,13 @@ import ( "golang.org/x/tools/internal/xcontext" ) -// TODO(adonovan): rename Handle to Promise, and present it before Store. - -// Store binds keys to functions, returning handles that can be used to access -// the function's result. -type Store struct { - handlesMu sync.Mutex - handles map[interface{}]*Handle -} - // Function is the type of a function that can be memoized. // // If the arg is a RefCounted, its Acquire/Release operations are called. // // The argument must not materially affect the result of the function -// in ways that are not captured by the handle's key, since if -// Handle.Get is called twice concurrently, with the same (implicit) +// in ways that are not captured by the promise's key, since if +// Promise.Get is called twice concurrently, with the same (implicit) // key but different arguments, the Function is called only once but // its result must be suitable for both callers. // @@ -63,21 +61,13 @@ type RefCounted interface { Acquire() func() } -type state int - -const ( - stateIdle = iota // newly constructed, or last waiter was cancelled - stateRunning // start was called and not cancelled - stateCompleted // function call ran to completion -) - -// A Handle represents the future result of a call to a function. -type Handle struct { +// A Promise represents the future result of a call to a function. +type Promise struct { debug string // for observability - mu sync.Mutex // lock ordering: Store.handlesMu before Handle.mu + mu sync.Mutex - // A Handle starts out IDLE, waiting for something to demand + // A Promise starts out IDLE, waiting for something to demand // its evaluation. It then transitions into RUNNING state.
// // While RUNNING, waiters tracks the number of Get calls @@ -105,128 +95,78 @@ type Handle struct { refcount int32 // accessed using atomic load/store } -// Handle returns a reference-counted handle for the future result of -// calling the specified function. Calls to Handle with the same key -// return the same handle, and all calls to Handle.Get on a given -// handle return the same result but the function is called at most once. +// NewPromise returns a promise for the future result of calling the +// specified function. // -// The caller must call the returned function to decrement the -// handle's reference count when it is no longer needed. -func (store *Store) Handle(key interface{}, function Function) (*Handle, func()) { +// The debug string is used to classify promises in logs and metrics. +// It should be drawn from a small set. +func NewPromise(debug string, function Function) *Promise { if function == nil { panic("nil function") } - - store.handlesMu.Lock() - h, ok := store.handles[key] - if !ok { - // new handle - h = &Handle{ - function: function, - refcount: 1, - debug: reflect.TypeOf(key).String(), - } - - if store.handles == nil { - store.handles = map[interface{}]*Handle{} - } - store.handles[key] = h - } else { - // existing handle - atomic.AddInt32(&h.refcount, 1) - } - store.handlesMu.Unlock() - - release := func() { - if atomic.AddInt32(&h.refcount, -1) == 0 { - store.handlesMu.Lock() - delete(store.handles, key) - store.handlesMu.Unlock() - } - } - return h, release -} - -// Stats returns the number of each type of value in the store. -func (s *Store) Stats() map[reflect.Type]int { - result := map[reflect.Type]int{} - - s.handlesMu.Lock() - defer s.handlesMu.Unlock() - - for k := range s.handles { - result[reflect.TypeOf(k)]++ - } - return result -} - -// DebugOnlyIterate iterates through all live cache entries and calls f on them. -// It should only be used for debugging purposes. 
-func (s *Store) DebugOnlyIterate(f func(k, v interface{})) { - s.handlesMu.Lock() - defer s.handlesMu.Unlock() - - for k, h := range s.handles { - if v := h.Cached(); v != nil { - f(k, v) - } - } -} - -// NewHandle returns a handle for the future result of calling the -// specified function. -// -// The debug string is used to classify handles in logs and metrics. -// It should be drawn from a small set. -func NewHandle(debug string, function Function) *Handle { - return &Handle{ + return &Promise{ debug: debug, function: function, } } -// Cached returns the value associated with a handle. +type state int + +const ( + stateIdle = iota // newly constructed, or last waiter was cancelled + stateRunning // start was called and not cancelled + stateCompleted // function call ran to completion +) + +// Cached returns the value associated with a promise. // // It will never cause the value to be generated. // It will return the cached value, if present. -func (h *Handle) Cached() interface{} { - h.mu.Lock() - defer h.mu.Unlock() - if h.state == stateCompleted { - return h.value +func (p *Promise) Cached() interface{} { + p.mu.Lock() + defer p.mu.Unlock() + if p.state == stateCompleted { + return p.value } return nil } -// Get returns the value associated with a handle. +// Get returns the value associated with a promise. +// +// All calls to Promise.Get on a given promise return the +// same result but the function is called (to completion) at most once. // // If the value is not yet ready, the underlying function will be invoked. +// // If ctx is cancelled, Get returns (nil, Canceled). -func (h *Handle) Get(ctx context.Context, arg interface{}) (interface{}, error) { +// If all concurrent calls to Get are cancelled, the context provided +// to the function is cancelled. A later call to Get may attempt to +// call the function again. 
+func (p *Promise) Get(ctx context.Context, arg interface{}) (interface{}, error) { if ctx.Err() != nil { return nil, ctx.Err() } - h.mu.Lock() - switch h.state { + p.mu.Lock() + switch p.state { case stateIdle: - return h.run(ctx, arg) + return p.run(ctx, arg) case stateRunning: - return h.wait(ctx) + return p.wait(ctx) case stateCompleted: - defer h.mu.Unlock() - return h.value, nil + defer p.mu.Unlock() + return p.value, nil default: panic("unknown state") } } -// run starts h.function and returns the result. h.mu must be locked. -func (h *Handle) run(ctx context.Context, arg interface{}) (interface{}, error) { +// run starts p.function and returns the result. p.mu must be locked. +func (p *Promise) run(ctx context.Context, arg interface{}) (interface{}, error) { childCtx, cancel := context.WithCancel(xcontext.Detach(ctx)) - h.cancel = cancel - h.state = stateRunning - h.done = make(chan struct{}) - function := h.function // Read under the lock + p.cancel = cancel + p.state = stateRunning + p.done = make(chan struct{}) + function := p.function // Read under the lock // Make sure that the argument isn't destroyed while we're running in it. release := func() {} @@ -235,7 +175,7 @@ func (h *Handle) run(ctx context.Context, arg interface{}) (interface{}, error) } go func() { - trace.WithRegion(childCtx, fmt.Sprintf("Handle.run %s", h.debug), func() { + trace.WithRegion(childCtx, fmt.Sprintf("Promise.run %s", p.debug), func() { defer release() // Just in case the function does something expensive without checking // the context, double-check we're still alive. 
@@ -247,51 +187,115 @@ func (h *Handle) run(ctx context.Context, arg interface{}) (interface{}, error) return } - h.mu.Lock() - defer h.mu.Unlock() - // It's theoretically possible that the handle has been cancelled out + p.mu.Lock() + defer p.mu.Unlock() + // It's theoretically possible that the promise has been cancelled out // of the run that started us, and then started running again since we // checked childCtx above. Even so, that should be harmless, since each // run should produce the same results. - if h.state != stateRunning { + if p.state != stateRunning { return } - h.value = v - h.function = nil // aid GC - h.state = stateCompleted - close(h.done) + p.value = v + p.function = nil // aid GC + p.state = stateCompleted + close(p.done) }) }() - return h.wait(ctx) + return p.wait(ctx) } -// wait waits for the value to be computed, or ctx to be cancelled. h.mu must be locked. -func (h *Handle) wait(ctx context.Context) (interface{}, error) { - h.waiters++ - done := h.done - h.mu.Unlock() +// wait waits for the value to be computed, or ctx to be cancelled. p.mu must be locked. +func (p *Promise) wait(ctx context.Context) (interface{}, error) { + p.waiters++ + done := p.done + p.mu.Unlock() select { case <-done: - h.mu.Lock() - defer h.mu.Unlock() - if h.state == stateCompleted { - return h.value, nil + p.mu.Lock() + defer p.mu.Unlock() + if p.state == stateCompleted { + return p.value, nil } return nil, nil case <-ctx.Done(): - h.mu.Lock() - defer h.mu.Unlock() - h.waiters-- - if h.waiters == 0 && h.state == stateRunning { - h.cancel() - close(h.done) - h.state = stateIdle - h.done = nil - h.cancel = nil + p.mu.Lock() + defer p.mu.Unlock() + p.waiters-- + if p.waiters == 0 && p.state == stateRunning { + p.cancel() + close(p.done) + p.state = stateIdle + p.done = nil + p.cancel = nil } return nil, ctx.Err() } } + +// A Store maps arbitrary keys to reference-counted promises. 
+type Store struct { + promisesMu sync.Mutex + promises map[interface{}]*Promise +} + +// Promise returns a reference-counted promise for the future result of +// calling the specified function. +// +// Calls to Promise with the same key return the same promise, +// incrementing its reference count. The caller must call the +// returned function to decrement the promise's reference count when +// it is no longer needed. Once the last reference has been released, +// the promise is removed from the store. +func (store *Store) Promise(key interface{}, function Function) (*Promise, func()) { + store.promisesMu.Lock() + p, ok := store.promises[key] + if !ok { + p = NewPromise(reflect.TypeOf(key).String(), function) + if store.promises == nil { + store.promises = map[interface{}]*Promise{} + } + store.promises[key] = p + } + atomic.AddInt32(&p.refcount, 1) + store.promisesMu.Unlock() + + release := func() { + if atomic.AddInt32(&p.refcount, -1) == 0 { + store.promisesMu.Lock() + delete(store.promises, key) + store.promisesMu.Unlock() + } + } + return p, release +} + +// Stats returns the number of each type of key in the store. +func (s *Store) Stats() map[reflect.Type]int { + result := map[reflect.Type]int{} + + s.promisesMu.Lock() + defer s.promisesMu.Unlock() + + for k := range s.promises { + result[reflect.TypeOf(k)]++ + } + return result +} + +// DebugOnlyIterate iterates through the store and, for each completed +// promise, calls f(k, v) for the map key k and function result v. It +// should only be used for debugging purposes. 
+func (s *Store) DebugOnlyIterate(f func(k, v interface{})) { + s.promisesMu.Lock() + defer s.promisesMu.Unlock() + + for k, p := range s.promises { + if v := p.Cached(); v != nil { + f(k, v) + } + } +} diff --git a/internal/memoize/memoize_test.go b/internal/memoize/memoize_test.go index bde02bf613..3550f1eb14 100644 --- a/internal/memoize/memoize_test.go +++ b/internal/memoize/memoize_test.go @@ -18,7 +18,7 @@ func TestGet(t *testing.T) { evaled := 0 - h, release := store.Handle("key", func(context.Context, interface{}) interface{} { + h, release := store.Promise("key", func(context.Context, interface{}) interface{} { evaled++ return "res" }) @@ -30,7 +30,7 @@ func TestGet(t *testing.T) { } } -func expectGet(t *testing.T, h *memoize.Handle, wantV interface{}) { +func expectGet(t *testing.T, h *memoize.Promise, wantV interface{}) { t.Helper() gotV, gotErr := h.Get(context.Background(), nil) if gotV != wantV || gotErr != nil { @@ -38,29 +38,50 @@ func expectGet(t *testing.T, h *memoize.Handle, wantV interface{}) { } } -func TestHandleRefCounting(t *testing.T) { +func TestNewPromise(t *testing.T) { + calls := 0 + f := func(context.Context, interface{}) interface{} { + calls++ + return calls + } + + // All calls to Get on the same promise return the same result. + p1 := memoize.NewPromise("debug", f) + expectGet(t, p1, 1) + expectGet(t, p1, 1) + + // A new promise calls the function again. + p2 := memoize.NewPromise("debug", f) + expectGet(t, p2, 2) + expectGet(t, p2, 2) + + // The original promise is unchanged. 
+ expectGet(t, p1, 1) +} + +func TestStoredPromiseRefCounting(t *testing.T) { var store memoize.Store v1 := false v2 := false - h1, release1 := store.Handle("key1", func(context.Context, interface{}) interface{} { + p1, release1 := store.Promise("key1", func(context.Context, interface{}) interface{} { return &v1 }) - h2, release2 := store.Handle("key2", func(context.Context, interface{}) interface{} { + p2, release2 := store.Promise("key2", func(context.Context, interface{}) interface{} { return &v2 }) - expectGet(t, h1, &v1) - expectGet(t, h2, &v2) + expectGet(t, p1, &v1) + expectGet(t, p2, &v2) - expectGet(t, h1, &v1) - expectGet(t, h2, &v2) + expectGet(t, p1, &v1) + expectGet(t, p2, &v2) - h2Copy, release2Copy := store.Handle("key2", func(context.Context, interface{}) interface{} { + p2Copy, release2Copy := store.Promise("key2", func(context.Context, interface{}) interface{} { return &v1 }) - if h2 != h2Copy { - t.Error("NewHandle returned a new value while old is not destroyed yet") + if p2 != p2Copy { + t.Error("Promise returned a new value while old is not destroyed yet") } - expectGet(t, h2Copy, &v2) + expectGet(t, p2Copy, &v2) release2() if got, want := v2, false; got != want { @@ -72,23 +93,23 @@ func TestHandleRefCounting(t *testing.T) { } release1() - h2Copy, release2Copy = store.Handle("key2", func(context.Context, interface{}) interface{} { + p2Copy, release2Copy = store.Promise("key2", func(context.Context, interface{}) interface{} { return &v2 }) - if h2 == h2Copy { - t.Error("NewHandle returned previously destroyed value") + if p2 == p2Copy { + t.Error("Promise returned previously destroyed value") } release2Copy() } -func TestHandleDestroyedWhileRunning(t *testing.T) { - // Test that calls to Handle.Get return even if the handle is destroyed while running. +func TestPromiseDestroyedWhileRunning(t *testing.T) { + // Test that calls to Promise.Get return even if the promise is destroyed while running. 
var store memoize.Store c := make(chan int) var v int - h, release := store.Handle("key", func(ctx context.Context, _ interface{}) interface{} { + h, release := store.Promise("key", func(ctx context.Context, _ interface{}) interface{} { <-c <-c if err := ctx.Err(); err != nil { @@ -109,9 +130,9 @@ func TestHandleDestroyedWhileRunning(t *testing.T) { wg.Done() }() - c <- 0 // send once to enter the handle function - release() // release before the handle function returns - c <- 0 // let the handle function proceed + c <- 0 // send once to enter the promise function + release() // release before the promise function returns + c <- 0 // let the promise function proceed wg.Wait()