From d498867cf94ee8774da5d5b8924cc660b41c936b Mon Sep 17 00:00:00 2001 From: qiulaidongfeng <2645477756@qq.com> Date: Sun, 1 Sep 2024 15:57:39 +0800 Subject: [PATCH] fix bug Change-Id: I30d5017391a36d503844ff9bf8649a361e032ee9 --- src/hash/maphash/maphash.go | 43 ++++++++++++++++ src/hash/maphash/maphash_purego.go | 45 ----------------- src/hash/maphash/maphash_runtime.go | 24 ++++++--- src/hash/maphash/maphash_test.go | 77 +++++++++++++++++++++++++++-- 4 files changed, 134 insertions(+), 55 deletions(-) diff --git a/src/hash/maphash/maphash.go b/src/hash/maphash/maphash.go index 38d6c3de23..d7cf4a57a0 100644 --- a/src/hash/maphash/maphash.go +++ b/src/hash/maphash/maphash.go @@ -13,8 +13,10 @@ package maphash import ( + "fmt" "internal/abi" "internal/byteorder" + "reflect" ) // A Seed is a random value that selects the specific hash function @@ -301,3 +303,44 @@ func WriteComparable[T comparable](h *Hash, x T) { byteorder.LePutUint64(buf[:], v) h.Write(buf[:]) } + +func appendT(buf []byte, v reflect.Value) []byte { + switch v.Kind() { + case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int: + return byteorder.LeAppendUint64(buf, uint64(v.Int())) + case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint, reflect.Uintptr: + return byteorder.LeAppendUint64(buf, v.Uint()) + case reflect.Array: + for i := range v.Len() { + buf = appendT(buf, v.Index(i)) + } + return buf + case reflect.String: + return append(buf, v.String()...) + case reflect.Struct: + for i := range v.NumField() { + buf = appendT(buf, v.Field(i)) + } + return buf + case reflect.Complex64, reflect.Complex128: + c := v.Complex() + buf = byteorder.LeAppendUint64(buf, uint64(real(c))) + return byteorder.LeAppendUint64(buf, uint64(imag(c))) + case reflect.Float32, reflect.Float64: + return byteorder.LeAppendUint64(buf, uint64(v.Float())) + case reflect.Bool: + return byteorder.LeAppendUint16(buf, btoi(v.Bool())) + case reflect.UnsafePointer, reflect.Pointer: + return byteorder.LeAppendUint64(buf, uint64(v.Pointer())) + case reflect.Interface: + return appendT(buf, v.Elem()) + } + panic(fmt.Errorf("hash/maphash: %s not comparable", v.Type().String())) +} + +func btoi(b bool) uint16 { + if b { + return 1 + } + return 0 +} diff --git a/src/hash/maphash/maphash_purego.go b/src/hash/maphash/maphash_purego.go index 83d7bd5d51..7000c86914 100644 --- a/src/hash/maphash/maphash_purego.go +++ b/src/hash/maphash/maphash_purego.go @@ -107,48 +107,3 @@ func comparableF[T comparable](seed uint64, v T, t *abi.Type) uint64 { buf = appendT(buf, vv) return wyhash(buf, seed, uint64(len(buf))) } - -func appendT(buf []byte, v reflect.Value) []byte { - switch v.Kind() { - case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int: - return byteorder.LeAppendUint64(buf, uint64(v.Int())) - case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint, reflect.Uintptr: - return byteorder.LeAppendUint64(buf, v.Uint()) - case reflect.Array: - for i := range v.Len() { - buf = appendT(buf, v.Index(i)) - } - return buf - case reflect.String: - return append(buf, v.String()...) - case reflect.Struct: - for i := range v.NumField() { - buf = appendT(buf, v.Field(i)) - } - return buf - case reflect.Complex64, reflect.Complex128: - c := v.Complex() - buf = byteorder.LeAppendUint64(buf, uint64(real(c))) - return byteorder.LeAppendUint64(buf, uint64(imag(c))) - case reflect.Float32, reflect.Float64: - return byteorder.LeAppendUint64(buf, uint64(v.Float())) - case reflect.Bool: - return byteorder.LeAppendUint16(buf, btoi(v.Bool())) - case reflect.UnsafePointer, reflect.Pointer: - return byteorder.LeAppendUint64(buf, uint64(v.Pointer())) - case reflect.Interface: - // There is no need to check comparability here, - // because !purego also treats it as a piece of memory. - a := v.InterfaceData() - buf = byteorder.LeAppendUint64(buf, uint64(a[0])) - return byteorder.LeAppendUint64(buf, uint64(a[1])) - } - panic("unreachable") -} - -func btoi(b bool) uint16 { - if b { - return 1 - } - return 0 -} diff --git a/src/hash/maphash/maphash_runtime.go b/src/hash/maphash/maphash_runtime.go index 2f1b75f140..2305f56653 100644 --- a/src/hash/maphash/maphash_runtime.go +++ b/src/hash/maphash/maphash_runtime.go @@ -9,6 +9,7 @@ package maphash import ( "internal/abi" "internal/unsafeheader" + "reflect" "unsafe" ) @@ -45,17 +46,28 @@ func randUint64() uint64 { } func comparableF[T comparable](seed uint64, v T, t *abi.Type) uint64 { - k := t.Kind() - len := t.Size() ptr := unsafe.Pointer(&v) + l := t.Size() + k := t.Kind() if k == abi.String { - len = uintptr(((*unsafeheader.String)(unsafe.Pointer(&v))).Len) + l = uintptr(((*unsafeheader.String)(unsafe.Pointer(&v))).Len) ptr = ((*unsafeheader.String)(unsafe.Pointer(&v))).Data + } else if t.TFlag&abi.TFlagRegularMemory == 0 { + // Note: if T like struct {s string} + // str value equal but ptr not equal, + // if think of it as a contiguous piece of memory, + // hash it, that happen v1 == v2 + // Comparable(s, v1) != Comparable(s, v2). + vv := reflect.ValueOf(v) + buf := make([]byte, 0, vv.Type().Size()) + buf = appendT(buf, vv) + ptr = unsafe.Pointer(&buf[0]) + l = uintptr(len(buf)) } if unsafe.Sizeof(uintptr(0)) == 8 { - return uint64(runtime_memhash(ptr, uintptr(seed), len)) + return uint64(runtime_memhash(ptr, uintptr(seed), l)) } - lo := runtime_memhash(ptr, uintptr(seed), len) - hi := runtime_memhash(ptr, uintptr(seed>>32), len) + lo := runtime_memhash(ptr, uintptr(seed), l) + hi := runtime_memhash(ptr, uintptr(seed>>32), l) return uint64(hi)<<32 | uint64(lo) } diff --git a/src/hash/maphash/maphash_test.go b/src/hash/maphash/maphash_test.go index 8428f9a711..29afe23ed4 100644 --- a/src/hash/maphash/maphash_test.go +++ b/src/hash/maphash/maphash_test.go @@ -6,10 +6,13 @@ package maphash import ( "bytes" + "crypto/rand" "fmt" "hash" "reflect" + "strings" "testing" + "unsafe" ) func TestUnseededHash(t *testing.T) { @@ -224,11 +227,43 @@ func TestComparable(t *testing.T) { i int f float64 }{i: 9, f: 9.9}) + type S struct { + s string + } + s1 := S{s: heapStr(t)} + s2 := S{s: heapStr(t)} + if unsafe.StringData(s1.s) == unsafe.StringData(s2.s) { + t.Fatalf("unexpected two heapStr ptr equal") + } + if s1.s != s2.s { + t.Fatalf("unexpected two heapStr value not equal") + } + testComparable(t, s1, s2) } -func testComparable[T comparable](t *testing.T, v T) { - t.Run(reflect.TypeFor[T]().Name(), func(t *testing.T) { +var heapStrValue []byte + +//go:noinline +func heapStr(t *testing.T) string { + s := make([]byte, 10) + if heapStrValue != nil { + copy(s, heapStrValue) + } else { + _, err := rand.Read(s) + if err != nil { + t.Fatal(err) + } + heapStrValue = s + } + return string(s) +} + +func testComparable[T comparable](t *testing.T, v T, v2 ...T) { + t.Run(reflect.TypeFor[T]().String(), func(t *testing.T) { var a, b T = v, v + if len(v2) != 0 { + b = v2[0] + } var pa *T = &a seed := MakeSeed() if Comparable(seed, a) != Comparable(seed, b) { @@ -266,11 +301,26 @@ func TestWriteComparable(t *testing.T) { i int f float64 }{i: 9, f: 9.9}) + type S struct { + s string + } + s1 := S{s: heapStr(t)} + s2 := S{s: heapStr(t)} + if unsafe.StringData(s1.s) == unsafe.StringData(s2.s) { + t.Fatalf("unexpected two heapStr ptr equal") + } + if s1.s != s2.s { + t.Fatalf("unexpected two heapStr value not equal") + } + testWriteComparable(t, s1, s2) } -func testWriteComparable[T comparable](t *testing.T, v T) { - t.Run(reflect.TypeFor[T]().Name(), func(t *testing.T) { +func testWriteComparable[T comparable](t *testing.T, v T, v2 ...T) { + t.Run(reflect.TypeFor[T]().String(), func(t *testing.T) { var a, b T = v, v + if len(v2) != 0 { + b = v2[0] + } var pa *T = &a h1 := Hash{} h2 := Hash{} @@ -292,6 +342,25 @@ func testWriteComparable[T comparable](t *testing.T, v T) { }) } +func TestComparableShouldPanic(t *testing.T) { + s := []byte("s") + a := any(s) + defer func() { + err := recover() + if err == nil { + t.Fatalf("hash any([]byte) should panic(error) in maphash.appendT") + } + e, ok := err.(error) + if !ok { + t.Fatalf("hash any([]byte) should panic(error) in maphash.appendT") + } + if !strings.Contains(e.Error(), "comparable") { + t.Fatalf("hash any([]byte) should panic(error) in maphash.appendT") + } + }() + Comparable(MakeSeed(), a) +} + // Make sure a Hash implements the hash.Hash and hash.Hash64 interfaces. var _ hash.Hash = &Hash{} var _ hash.Hash64 = &Hash{}