From aa8ec88cccb3b3bb5bb5d031b779bb2aeb6bc3f1 Mon Sep 17 00:00:00 2001 From: Derek Parker Date: Wed, 3 May 2023 13:16:52 -0700 Subject: [PATCH] ensure proper evaluation order if compare can panic --- src/cmd/compile/internal/compare/compare.go | 17 +++++++++++++---- src/cmd/compile/internal/reflectdata/alg.go | 10 +++++----- src/cmd/compile/internal/walk/compare.go | 2 +- test/fixedbugs/issue8606.go | 5 +++++ 4 files changed, 24 insertions(+), 10 deletions(-) diff --git a/src/cmd/compile/internal/compare/compare.go b/src/cmd/compile/internal/compare/compare.go index d8ae7bf24a..e12b679717 100644 --- a/src/cmd/compile/internal/compare/compare.go +++ b/src/cmd/compile/internal/compare/compare.go @@ -166,10 +166,14 @@ func calculateCostForType(t *types.Type) int64 { // It works by building a list of boolean conditions to satisfy. // Conditions must be evaluated in the returned order and // properly short-circuited by the caller. -func EqStruct(t *types.Type, np, nq ir.Node) []ir.Node { +// The first return value is the flattened list of conditions, +// the second value is a boolean indicating whether any of the +// comparisons could panic. +func EqStruct(t *types.Type, np, nq ir.Node) ([]ir.Node, bool) { // The conditions are a list-of-lists. Conditions are reorderable // within each inner list. The outer lists must be evaluated in order. var conds [][]ir.Node + var canPanic bool conds = append(conds, []ir.Node{}) and := func(n ir.Node) { i := len(conds) - 1 @@ -187,9 +191,14 @@ func EqStruct(t *types.Type, np, nq ir.Node) []ir.Node { continue } + typeCanPanic := EqCanPanic(f.Type) + if !canPanic { + canPanic = typeCanPanic + } + // Compare non-memory fields with field equality. if !IsRegularMemory(f.Type) { - if EqCanPanic(f.Type) { + if typeCanPanic { // Enforce ordering by starting a new set of reorderable conditions. conds = append(conds, []ir.Node{}) } @@ -203,7 +212,7 @@ func EqStruct(t *types.Type, np, nq ir.Node) []ir.Node { default: and(ir.NewBinaryExpr(base.Pos, ir.OEQ, p, q)) } - if EqCanPanic(f.Type) { + if typeCanPanic { // Also enforce ordering after something that can panic. conds = append(conds, []ir.Node{}) } @@ -238,7 +247,7 @@ func EqStruct(t *types.Type, np, nq ir.Node) []ir.Node { }) flatConds = append(flatConds, c...) } - return flatConds + return flatConds, canPanic } // EqString returns the nodes diff --git a/src/cmd/compile/internal/reflectdata/alg.go b/src/cmd/compile/internal/reflectdata/alg.go index dc36965968..69de685ca0 100644 --- a/src/cmd/compile/internal/reflectdata/alg.go +++ b/src/cmd/compile/internal/reflectdata/alg.go @@ -530,7 +530,7 @@ func eqFunc(t *types.Type) *ir.Func { qi := ir.NewIndexExpr(tmpPos, nq, ir.NewInt(tmpPos, 0)) qi.SetBounded(true) qi.SetType(t.Elem()) - flatConds := compare.EqStruct(t.Elem(), pi, qi) + flatConds, canPanic := compare.EqStruct(t.Elem(), pi, qi) for _, c := range flatConds { if isCall(c) { hasCallExprs = true @@ -538,7 +538,7 @@ func eqFunc(t *types.Type) *ir.Func { allCallExprs = false } } - if !hasCallExprs || allCallExprs { + if !hasCallExprs || allCallExprs || canPanic { checkAll(1, true, func(pi, qi ir.Node) ir.Node { // p[i] == q[i] return ir.NewBinaryExpr(base.Pos, ir.OEQ, pi, qi) @@ -546,7 +546,7 @@ func eqFunc(t *types.Type) *ir.Func { } else { checkAll(4, false, func(pi, qi ir.Node) ir.Node { expr = nil - flatConds := compare.EqStruct(t.Elem(), pi, qi) + flatConds, _ := compare.EqStruct(t.Elem(), pi, qi) if len(flatConds) == 0 { return ir.NewBool(base.Pos, true) } @@ -559,7 +559,7 @@ func eqFunc(t *types.Type) *ir.Func { }) checkAll(2, true, func(pi, qi ir.Node) ir.Node { expr = nil - flatConds := compare.EqStruct(t.Elem(), pi, qi) + flatConds, _ := compare.EqStruct(t.Elem(), pi, qi) for _, c := range flatConds { if isCall(c) { and(c) @@ -576,7 +576,7 @@ func eqFunc(t *types.Type) *ir.Func { } case types.TSTRUCT: - flatConds := compare.EqStruct(t, np, nq) + flatConds, _ := compare.EqStruct(t, np, nq) if len(flatConds) == 0 { fn.Body.Append(ir.NewAssignStmt(base.Pos, nr, ir.NewBool(base.Pos, true))) } else { diff --git a/src/cmd/compile/internal/walk/compare.go b/src/cmd/compile/internal/walk/compare.go index 58d6b57496..625cfecee0 100644 --- a/src/cmd/compile/internal/walk/compare.go +++ b/src/cmd/compile/internal/walk/compare.go @@ -228,7 +228,7 @@ func walkCompare(n *ir.BinaryExpr, init *ir.Nodes) ir.Node { cmpl = safeExpr(cmpl, init) cmpr = safeExpr(cmpr, init) if t.IsStruct() { - conds := compare.EqStruct(t, cmpl, cmpr) + conds, _ := compare.EqStruct(t, cmpl, cmpr) if n.Op() == ir.OEQ { for _, cond := range conds { and(cond) diff --git a/test/fixedbugs/issue8606.go b/test/fixedbugs/issue8606.go index 2900406e31..6bac02a1da 100644 --- a/test/fixedbugs/issue8606.go +++ b/test/fixedbugs/issue8606.go @@ -34,6 +34,10 @@ func main() { f any i int } + type S4 struct { + a [1000]byte + b any + } b := []byte{1} s1 := S3{func() {}, 0} s2 := S3{func() {}, 1} @@ -72,6 +76,7 @@ func main() { {false, T3{s: "fooz", j: b}, T3{s: "bar", j: b}}, {true, A{s1, s2}, A{s2, s1}}, {true, s1, s2}, + {false, S4{[1000]byte{0}, func() {}}, S4{[1000]byte{1}, func() {}}}, } { f := func() { defer func() {