diff --git a/src/sync/atomic/value.go b/src/sync/atomic/value.go index eab7e70c9b..61f81d8fd3 100644 --- a/src/sync/atomic/value.go +++ b/src/sync/atomic/value.go @@ -25,7 +25,7 @@ type ifaceWords struct { // Load returns the value set by the most recent Store. // It returns nil if there has been no call to Store for this Value. -func (v *Value) Load() (x interface{}) { +func (v *Value) Load() (val interface{}) { vp := (*ifaceWords)(unsafe.Pointer(v)) typ := LoadPointer(&vp.typ) if typ == nil || uintptr(typ) == ^uintptr(0) { @@ -33,21 +33,21 @@ func (v *Value) Load() (x interface{}) { return nil } data := LoadPointer(&vp.data) - xp := (*ifaceWords)(unsafe.Pointer(&x)) - xp.typ = typ - xp.data = data + vlp := (*ifaceWords)(unsafe.Pointer(&val)) + vlp.typ = typ + vlp.data = data return } // Store sets the value of the Value to x. // All calls to Store for a given Value must use values of the same concrete type. // Store of an inconsistent type panics, as does Store(nil). -func (v *Value) Store(x interface{}) { - if x == nil { +func (v *Value) Store(val interface{}) { + if val == nil { panic("sync/atomic: store of nil value into Value") } vp := (*ifaceWords)(unsafe.Pointer(v)) - xp := (*ifaceWords)(unsafe.Pointer(&x)) + vlp := (*ifaceWords)(unsafe.Pointer(&val)) for { typ := LoadPointer(&vp.typ) if typ == nil { @@ -61,8 +61,8 @@ func (v *Value) Store(x interface{}) { continue } // Complete first store. - StorePointer(&vp.data, xp.data) - StorePointer(&vp.typ, xp.typ) + StorePointer(&vp.data, vlp.data) + StorePointer(&vp.typ, vlp.typ) runtime_procUnpin() return } @@ -73,14 +73,121 @@ func (v *Value) Store(x interface{}) { continue } // First store completed. Check type and overwrite data. - if typ != xp.typ { + if typ != vlp.typ { panic("sync/atomic: store of inconsistently typed value into Value") } - StorePointer(&vp.data, xp.data) + StorePointer(&vp.data, vlp.data) return } } +// Swap stores new into Value and returns the previous value. It returns nil if +// the Value is empty. +// +// All calls to Swap for a given Value must use values of the same concrete +// type. Swap of an inconsistent type panics, as does Swap(nil). +func (v *Value) Swap(new interface{}) (old interface{}) { + if new == nil { + panic("sync/atomic: swap of nil value into Value") + } + vp := (*ifaceWords)(unsafe.Pointer(v)) + np := (*ifaceWords)(unsafe.Pointer(&new)) + for { + typ := LoadPointer(&vp.typ) + if typ == nil { + // Attempt to start first store. + // Disable preemption so that other goroutines can use + // active spin wait to wait for completion; and so that + // GC does not see the fake type accidentally. + runtime_procPin() + if !CompareAndSwapPointer(&vp.typ, nil, unsafe.Pointer(^uintptr(0))) { + runtime_procUnpin() + continue + } + // Complete first store. + StorePointer(&vp.data, np.data) + StorePointer(&vp.typ, np.typ) + runtime_procUnpin() + return nil + } + if uintptr(typ) == ^uintptr(0) { + // First store in progress. Wait. + // Since we disable preemption around the first store, + // we can wait with active spinning. + continue + } + // First store completed. Check type and overwrite data. + if typ != np.typ { + panic("sync/atomic: swap of inconsistently typed value into Value") + } + op := (*ifaceWords)(unsafe.Pointer(&old)) + op.typ, op.data = np.typ, SwapPointer(&vp.data, np.data) + return old + } +} + +// CompareAndSwapPointer executes the compare-and-swap operation for the Value. +// +// All calls to CompareAndSwap for a given Value must use values of the same +// concrete type. CompareAndSwap of an inconsistent type panics, as does +// CompareAndSwap(old, nil). +func (v *Value) CompareAndSwap(old, new interface{}) (swapped bool) { + if new == nil { + panic("sync/atomic: compare and swap of nil value into Value") + } + vp := (*ifaceWords)(unsafe.Pointer(v)) + np := (*ifaceWords)(unsafe.Pointer(&new)) + op := (*ifaceWords)(unsafe.Pointer(&old)) + if op.typ != nil && np.typ != op.typ { + panic("sync/atomic: compare and swap of inconsistently typed values") + } + for { + typ := LoadPointer(&vp.typ) + if typ == nil { + if old != nil { + return false + } + // Attempt to start first store. + // Disable preemption so that other goroutines can use + // active spin wait to wait for completion; and so that + // GC does not see the fake type accidentally. + runtime_procPin() + if !CompareAndSwapPointer(&vp.typ, nil, unsafe.Pointer(^uintptr(0))) { + runtime_procUnpin() + continue + } + // Complete first store. + StorePointer(&vp.data, np.data) + StorePointer(&vp.typ, np.typ) + runtime_procUnpin() + return true + } + if uintptr(typ) == ^uintptr(0) { + // First store in progress. Wait. + // Since we disable preemption around the first store, + // we can wait with active spinning. + continue + } + // First store completed. Check type and overwrite data. + if typ != np.typ { + panic("sync/atomic: compare and swap of inconsistently typed value into Value") + } + // Compare old and current via runtime equality check. + // This allows value types to be compared, something + // not offered by the package functions. + // CompareAndSwapPointer below only ensures vp.data + // has not changed since LoadPointer. + data := LoadPointer(&vp.data) + var i interface{} + (*ifaceWords)(unsafe.Pointer(&i)).typ = typ + (*ifaceWords)(unsafe.Pointer(&i)).data = data + if i != old { + return false + } + return CompareAndSwapPointer(&vp.data, data, np.data) + } +} + // Disable/enable preemption, implemented in runtime. func runtime_procPin() func runtime_procUnpin() diff --git a/src/sync/atomic/value_test.go b/src/sync/atomic/value_test.go index f289766340..a5e717d6e0 100644 --- a/src/sync/atomic/value_test.go +++ b/src/sync/atomic/value_test.go @@ -7,6 +7,9 @@ package atomic_test import ( "math/rand" "runtime" + "strconv" + "sync" + "sync/atomic" . "sync/atomic" "testing" ) @@ -133,3 +136,139 @@ func BenchmarkValueRead(b *testing.B) { } }) } + +var Value_SwapTests = []struct { + init interface{} + new interface{} + want interface{} + err interface{} +}{ + {init: nil, new: nil, err: "sync/atomic: swap of nil value into Value"}, + {init: nil, new: true, want: nil, err: nil}, + {init: true, new: "", err: "sync/atomic: swap of inconsistently typed value into Value"}, + {init: true, new: false, want: true, err: nil}, +} + +func TestValue_Swap(t *testing.T) { + for i, tt := range Value_SwapTests { + t.Run(strconv.Itoa(i), func(t *testing.T) { + var v Value + if tt.init != nil { + v.Store(tt.init) + } + defer func() { + err := recover() + switch { + case tt.err == nil && err != nil: + t.Errorf("should not panic, got %v", err) + case tt.err != nil && err == nil: + t.Errorf("should panic %v, got ", tt.err) + } + }() + if got := v.Swap(tt.new); got != tt.want { + t.Errorf("got %v, want %v", got, tt.want) + } + if got := v.Load(); got != tt.new { + t.Errorf("got %v, want %v", got, tt.new) + } + }) + } +} + +func TestValueSwapConcurrent(t *testing.T) { + var v Value + var count uint64 + var g sync.WaitGroup + var m, n uint64 = 10000, 10000 + if testing.Short() { + m = 1000 + n = 1000 + } + for i := uint64(0); i < m*n; i += n { + i := i + g.Add(1) + go func() { + var c uint64 + for new := i; new < i+n; new++ { + if old := v.Swap(new); old != nil { + c += old.(uint64) + } + } + atomic.AddUint64(&count, c) + g.Done() + }() + } + g.Wait() + if want, got := (m*n-1)*(m*n)/2, count+v.Load().(uint64); got != want { + t.Errorf("sum from 0 to %d was %d, want %v", m*n-1, got, want) + } +} + +var heapA, heapB = struct{ uint }{0}, struct{ uint }{0} + +var Value_CompareAndSwapTests = []struct { + init interface{} + new interface{} + old interface{} + want bool + err interface{} +}{ + {init: nil, new: nil, old: nil, err: "sync/atomic: compare and swap of nil value into Value"}, + {init: nil, new: true, old: "", err: "sync/atomic: compare and swap of inconsistently typed values into Value"}, + {init: nil, new: true, old: true, want: false, err: nil}, + {init: nil, new: true, old: nil, want: true, err: nil}, + {init: true, new: "", err: "sync/atomic: compare and swap of inconsistently typed value into Value"}, + {init: true, new: true, old: false, want: false, err: nil}, + {init: true, new: true, old: true, want: true, err: nil}, + {init: heapA, new: struct{ uint }{1}, old: heapB, want: true, err: nil}, +} + +func TestValue_CompareAndSwap(t *testing.T) { + for i, tt := range Value_CompareAndSwapTests { + t.Run(strconv.Itoa(i), func(t *testing.T) { + var v Value + if tt.init != nil { + v.Store(tt.init) + } + defer func() { + err := recover() + switch { + case tt.err == nil && err != nil: + t.Errorf("got %v, wanted no panic", err) + case tt.err != nil && err == nil: + t.Errorf("did not panic, want %v", tt.err) + } + }() + if got := v.CompareAndSwap(tt.old, tt.new); got != tt.want { + t.Errorf("got %v, want %v", got, tt.want) + } + }) + } +} + +func TestValueCompareAndSwapConcurrent(t *testing.T) { + var v Value + var w sync.WaitGroup + v.Store(0) + m, n := 1000, 100 + if testing.Short() { + m = 100 + n = 100 + } + for i := 0; i < m; i++ { + i := i + w.Add(1) + go func() { + for j := i; j < m*n; runtime.Gosched() { + if v.CompareAndSwap(j, j+1) { + j += m + } + } + w.Done() + }() + } + w.Wait() + if stop := v.Load().(int); stop != m*n { + t.Errorf("did not get to %v, stopped at %v", m*n, stop) + } +}