internal/sync: add Swap to HashTrieMap

This change adds the Swap operation (with the same semantics as
sync.Map's Swap) to HashTrieMap.

Change-Id: I8697a0c8c2eb761e2452a41b868b590ccbfa5c03
Reviewed-on: https://go-review.googlesource.com/c/go/+/594064
Reviewed-by: David Chase <drchase@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Auto-Submit: Michael Knyszek <mknyszek@google.com>
This commit is contained in:
Michael Anthony Knyszek 2024-06-21 21:00:15 +00:00 committed by Gopher Robot
parent 85fa418fd5
commit 28bac5640c
2 changed files with 284 additions and 2 deletions

View File

@ -195,6 +195,80 @@ func (ht *HashTrieMap[K, V]) expand(oldEntry, newEntry *entry[K, V], newHash uin
return &top.node
}
// Swap swaps the value for a key and returns the previous value if any.
// The loaded result reports whether the key was present.
func (ht *HashTrieMap[K, V]) Swap(key K, new V) (previous V, loaded bool) {
ht.init()
hash := ht.keyHash(abi.NoEscape(unsafe.Pointer(&key)), ht.seed)
var i *indirect[K, V]
var hashShift uint
var slot *atomic.Pointer[node[K, V]]
var n *node[K, V]
for {
// Find the key or a candidate location for insertion.
i = ht.root
hashShift = 8 * goarch.PtrSize
haveInsertPoint := false
for hashShift != 0 {
hashShift -= nChildrenLog2
slot = &i.children[(hash>>hashShift)&nChildrenMask]
n = slot.Load()
if n == nil || n.isEntry {
// We found a nil slot which is a candidate for insertion,
// or an existing entry that we'll replace.
haveInsertPoint = true
break
}
i = n.indirect()
}
if !haveInsertPoint {
panic("internal/concurrent.HashMapTrie: ran out of hash bits while iterating")
}
// Grab the lock and double-check what we saw.
i.mu.Lock()
n = slot.Load()
if (n == nil || n.isEntry) && !i.dead.Load() {
// What we saw is still true, so we can continue with the insert.
break
}
// We have to start over.
i.mu.Unlock()
}
// N.B. This lock is held from when we broke out of the outer loop above.
// We specifically break this out so that we can use defer here safely.
// One option is to break this out into a new function instead, but
// there's so much local iteration state used below that this turns out
// to be cleaner.
defer i.mu.Unlock()
var zero V
var oldEntry *entry[K, V]
if n != nil {
// Swap if the keys compare.
oldEntry = n.entry()
newEntry, old, swapped := oldEntry.swap(key, new)
if swapped {
slot.Store(&newEntry.node)
return old, true
}
}
// The keys didn't compare, so we're doing an insertion.
newEntry := newEntryNode(key, new)
if oldEntry == nil {
// Easy case: create a new entry and store it.
slot.Store(&newEntry.node)
} else {
// We possibly need to expand the entry already there into one or more new nodes.
//
// Publish the node last, which will make both oldEntry and newEntry visible. We
// don't want readers to be able to observe that oldEntry isn't in the tree.
slot.Store(ht.expand(oldEntry, newEntry, hash, hashShift, i))
}
return zero, false
}
// CompareAndSwap swaps the old and new values for key
// if the value stored in the map is equal to old.
// The value type must be of a comparable type, otherwise CompareAndSwap will panic.
@ -439,6 +513,35 @@ func (e *entry[K, V]) lookupWithValue(key K, value V, valEqual equalFunc) (V, bo
return *new(V), false
}
// swap replaces an entry in the overflow chain if keys compare equal. Returns the new entry chain,
// the old value, and whether or not anything was swapped.
//
// swap must be called under the mutex of the indirect node which e is a child of.
func (head *entry[K, V]) swap(key K, new V) (*entry[K, V], V, bool) {
if head.key == key {
// Return the new head of the list.
e := newEntryNode(key, new)
if chain := head.overflow.Load(); chain != nil {
e.overflow.Store(chain)
}
return e, head.value, true
}
i := &head.overflow
e := i.Load()
for e != nil {
if e.key == key {
eNew := newEntryNode(key, new)
eNew.overflow.Store(e.overflow.Load())
i.Store(eNew)
return head, e.value, true
}
i = &e.overflow
e = e.overflow.Load()
}
var zero V
return head, zero, false
}
// compareAndSwap replaces an entry in the overflow chain if both the key and value compare
// equal. Returns the new entry chain and whether or not anything was swapped.
//

View File

@ -266,7 +266,7 @@ func testHashTrieMap(t *testing.T, newMap func() *isync.HashTrieMap[string, int]
}
}
})
t.Run("ConcurrentLifecycleSwapUnsharedKeys", func(t *testing.T) {
t.Run("ConcurrentLifecycleCompareAndSwapUnsharedKeys", func(t *testing.T) {
m := newMap()
gmp := runtime.GOMAXPROCS(-1)
@ -300,7 +300,7 @@ func testHashTrieMap(t *testing.T, newMap func() *isync.HashTrieMap[string, int]
}
wg.Wait()
})
t.Run("ConcurrentLifecycleSwapAndDeleteUnsharedKeys", func(t *testing.T) {
t.Run("ConcurrentLifecycleCompareAndSwapAndDeleteUnsharedKeys", func(t *testing.T) {
m := newMap()
gmp := runtime.GOMAXPROCS(-1)
@ -365,6 +365,161 @@ func testHashTrieMap(t *testing.T, newMap func() *isync.HashTrieMap[string, int]
}
wg.Wait()
})
t.Run("SwapAll", func(t *testing.T) {
m := newMap()
for i, s := range testData {
expectMissing(t, s, 0)(m.Load(s))
expectNotLoadedFromSwap(t, s, i)(m.Swap(s, i))
expectPresent(t, s, i)(m.Load(s))
expectLoadedFromSwap(t, s, i, i)(m.Swap(s, i))
}
for j := range 3 {
for i, s := range testData {
expectPresent(t, s, i+j)(m.Load(s))
expectLoadedFromSwap(t, s, i+j, i+j+1)(m.Swap(s, i+j+1))
expectPresent(t, s, i+j+1)(m.Load(s))
}
}
for i, s := range testData {
expectLoadedFromSwap(t, s, i+3, i+3)(m.Swap(s, i+3))
}
})
t.Run("SwapOne", func(t *testing.T) {
m := newMap()
for i, s := range testData {
expectMissing(t, s, 0)(m.Load(s))
expectNotLoadedFromSwap(t, s, i)(m.Swap(s, i))
expectPresent(t, s, i)(m.Load(s))
expectLoadedFromSwap(t, s, i, i)(m.Swap(s, i))
}
expectLoadedFromSwap(t, testData[15], 15, 16)(m.Swap(testData[15], 16))
for i, s := range testData {
if i == 15 {
expectPresent(t, s, 16)(m.Load(s))
} else {
expectPresent(t, s, i)(m.Load(s))
}
}
})
t.Run("SwapMultiple", func(t *testing.T) {
m := newMap()
for i, s := range testData {
expectMissing(t, s, 0)(m.Load(s))
expectNotLoadedFromSwap(t, s, i)(m.Swap(s, i))
expectPresent(t, s, i)(m.Load(s))
expectLoadedFromSwap(t, s, i, i)(m.Swap(s, i))
}
for _, i := range []int{1, 105, 6, 85} {
expectLoadedFromSwap(t, testData[i], i, i+1)(m.Swap(testData[i], i+1))
}
for i, s := range testData {
if i == 1 || i == 105 || i == 6 || i == 85 {
expectPresent(t, s, i+1)(m.Load(s))
} else {
expectPresent(t, s, i)(m.Load(s))
}
}
})
t.Run("ConcurrentLifecycleSwapUnsharedKeys", func(t *testing.T) {
m := newMap()
gmp := runtime.GOMAXPROCS(-1)
var wg sync.WaitGroup
for i := range gmp {
wg.Add(1)
go func(id int) {
defer wg.Done()
makeKey := func(s string) string {
return s + "-" + strconv.Itoa(id)
}
for _, s := range testData {
key := makeKey(s)
expectMissing(t, key, 0)(m.Load(key))
expectNotLoadedFromSwap(t, key, id)(m.Swap(key, id))
expectPresent(t, key, id)(m.Load(key))
expectLoadedFromSwap(t, key, id, id)(m.Swap(key, id))
}
for _, s := range testData {
key := makeKey(s)
expectPresent(t, key, id)(m.Load(key))
expectLoadedFromSwap(t, key, id, id+1)(m.Swap(key, id+1))
expectPresent(t, key, id+1)(m.Load(key))
}
for _, s := range testData {
key := makeKey(s)
expectPresent(t, key, id+1)(m.Load(key))
}
}(i)
}
wg.Wait()
})
t.Run("ConcurrentLifecycleSwapAndDeleteUnsharedKeys", func(t *testing.T) {
m := newMap()
gmp := runtime.GOMAXPROCS(-1)
var wg sync.WaitGroup
for i := range gmp {
wg.Add(1)
go func(id int) {
defer wg.Done()
makeKey := func(s string) string {
return s + "-" + strconv.Itoa(id)
}
for _, s := range testData {
key := makeKey(s)
expectMissing(t, key, 0)(m.Load(key))
expectNotLoadedFromSwap(t, key, id)(m.Swap(key, id))
expectPresent(t, key, id)(m.Load(key))
expectLoadedFromSwap(t, key, id, id)(m.Swap(key, id))
}
for _, s := range testData {
key := makeKey(s)
expectPresent(t, key, id)(m.Load(key))
expectLoadedFromSwap(t, key, id, id+1)(m.Swap(key, id+1))
expectPresent(t, key, id+1)(m.Load(key))
expectDeleted(t, key, id+1)(m.CompareAndDelete(key, id+1))
expectNotLoadedFromSwap(t, key, id+2)(m.Swap(key, id+2))
expectPresent(t, key, id+2)(m.Load(key))
}
for _, s := range testData {
key := makeKey(s)
expectPresent(t, key, id+2)(m.Load(key))
}
}(i)
}
wg.Wait()
})
t.Run("ConcurrentSwapSharedKeys", func(t *testing.T) {
m := newMap()
// Load up the map.
for i, s := range testData {
expectMissing(t, s, 0)(m.Load(s))
expectStored(t, s, i)(m.LoadOrStore(s, i))
}
gmp := runtime.GOMAXPROCS(-1)
var wg sync.WaitGroup
for i := range gmp {
wg.Add(1)
go func(id int) {
defer wg.Done()
for i, s := range testData {
m.Swap(s, i+1)
expectPresent(t, s, i+1)(m.Load(s))
}
for i, s := range testData {
expectPresent(t, s, i+1)(m.Load(s))
}
}(i)
}
wg.Wait()
})
}
func testAll[K, V comparable](t *testing.T, m *isync.HashTrieMap[K, V], testData map[K]V, yield func(K, V) bool) {
@ -497,6 +652,30 @@ func expectNotSwapped[K, V comparable](t *testing.T, key K, old, new V) func(swa
}
}
func expectLoadedFromSwap[K, V comparable](t *testing.T, key K, want, new V) func(got V, loaded bool) {
t.Helper()
return func(got V, loaded bool) {
t.Helper()
if !loaded {
t.Errorf("expected key %v to be in map and for %v to have been swapped for %v", key, want, new)
} else if want != got {
t.Errorf("key %v had its value %v swapped for %v, but expected it to have value %v", key, got, new, want)
}
}
}
func expectNotLoadedFromSwap[K, V comparable](t *testing.T, key K, new V) func(old V, loaded bool) {
t.Helper()
return func(old V, loaded bool) {
t.Helper()
if loaded {
t.Errorf("expected key %v to not be in map, but found value %v for it", key, old)
}
}
}
func testDataMap(data []string) map[string]int {
m := make(map[string]int)
for i, s := range data {