diff --git a/src/cmd/compile/internal/ssa/_gen/generic.rules b/src/cmd/compile/internal/ssa/_gen/generic.rules
index 0ae05ec641..aeda62591a 100644
--- a/src/cmd/compile/internal/ssa/_gen/generic.rules
+++ b/src/cmd/compile/internal/ssa/_gen/generic.rules
@@ -981,7 +981,7 @@
 		(ConstNil <typ.Uintptr>)
 		(ConstNil <typ.BytePtr>))

-(NilCheck (GetG mem) mem) => mem
+(NilCheck ptr:(GetG mem) mem) => ptr

 (If (Not cond) yes no) => (If cond no yes)
 (If (ConstBool [c]) yes no) && c => (First yes no)
@@ -2055,19 +2055,19 @@
 	&& isSameCall(call.Aux, "runtime.newobject")
 	=> mem

-(NilCheck (SelectN [0] call:(StaticLECall _ _)) _)
+(NilCheck ptr:(SelectN [0] call:(StaticLECall _ _)) _)
 	&& isSameCall(call.Aux, "runtime.newobject")
 	&& warnRule(fe.Debug_checknil(), v, "removed nil check")
-	=> (Invalid)
+	=> ptr

-(NilCheck (OffPtr (SelectN [0] call:(StaticLECall _ _))) _)
+(NilCheck ptr:(OffPtr (SelectN [0] call:(StaticLECall _ _))) _)
 	&& isSameCall(call.Aux, "runtime.newobject")
 	&& warnRule(fe.Debug_checknil(), v, "removed nil check")
-	=> (Invalid)
+	=> ptr

 // Addresses of globals are always non-nil.
-(NilCheck (Addr {_} (SB)) _) => (Invalid)
-(NilCheck (Convert (Addr {_} (SB)) _) _) => (Invalid)
+(NilCheck ptr:(Addr {_} (SB)) _) => ptr
+(NilCheck ptr:(Convert (Addr {_} (SB)) _) _) => ptr

 // for late-expanded calls, recognize memequal applied to a single constant byte
 // Support is limited to 1, 2, 4, 8 byte sizes
diff --git a/src/cmd/compile/internal/ssa/_gen/genericOps.go b/src/cmd/compile/internal/ssa/_gen/genericOps.go
index a182afbaa8..69eb48ce44 100644
--- a/src/cmd/compile/internal/ssa/_gen/genericOps.go
+++ b/src/cmd/compile/internal/ssa/_gen/genericOps.go
@@ -477,7 +477,7 @@ var genericOps = []opData{
 	{name: "IsNonNil", argLength: 1, typ: "Bool"},        // arg0 != nil
 	{name: "IsInBounds", argLength: 2, typ: "Bool"},      // 0 <= arg0 < arg1. arg1 is guaranteed >= 0.
 	{name: "IsSliceInBounds", argLength: 2, typ: "Bool"}, // 0 <= arg0 <= arg1. arg1 is guaranteed >= 0.
-	{name: "NilCheck", argLength: 2, typ: "Void"},        // arg0=ptr, arg1=mem. Panics if arg0 is nil. Returns void.
+	{name: "NilCheck", argLength: 2, nilCheck: true},     // arg0=ptr, arg1=mem. Panics if arg0 is nil. Returns the ptr unmodified.

 	// Pseudo-ops
 	{name: "GetG", argLength: 1, zeroWidth: true}, // runtime.getg() (read g pointer). arg0=mem
diff --git a/src/cmd/compile/internal/ssa/check.go b/src/cmd/compile/internal/ssa/check.go
index f34b907419..bbfdaceaad 100644
--- a/src/cmd/compile/internal/ssa/check.go
+++ b/src/cmd/compile/internal/ssa/check.go
@@ -317,7 +317,28 @@ func checkFunc(f *Func) {
 			if !v.Aux.(*ir.Name).Type().HasPointers() {
 				f.Fatalf("vardef must have pointer type %s", v.Aux.(*ir.Name).Type().String())
 			}
-
+		case OpNilCheck:
+			// nil checks have pointer type before scheduling, and
+			// void type after scheduling.
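+			// The pointer-typed result is what orders uses of the checked
+			// pointer after the check; scheduling makes that ordering
+			// explicit and then drops the result (see schedule.go).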
+			if f.scheduled {
+				if v.Uses != 0 {
+					f.Fatalf("nilcheck must have 0 uses %d", v.Uses)
+				}
+				if !v.Type.IsVoid() {
+					f.Fatalf("nilcheck must have void type %s", v.Type.String())
+				}
+			} else {
+				if !v.Type.IsPtrShaped() && !v.Type.IsUintptr() {
+					f.Fatalf("nilcheck must have pointer type %s", v.Type.String())
+				}
+			}
+			if !v.Args[0].Type.IsPtrShaped() && !v.Args[0].Type.IsUintptr() {
+				f.Fatalf("nilcheck must have argument of pointer type %s", v.Args[0].Type.String())
+			}
+			if !v.Args[1].Type.IsMemory() {
+				f.Fatalf("bad arg 1 type to %s: want mem, have %s",
+					v.Op, v.Args[1].Type.String())
+			}
 		}

 		// TODO: check for cycles in values
diff --git a/src/cmd/compile/internal/ssa/deadcode.go b/src/cmd/compile/internal/ssa/deadcode.go
index 52cc7f2ca7..ae9fd2ef24 100644
--- a/src/cmd/compile/internal/ssa/deadcode.go
+++ b/src/cmd/compile/internal/ssa/deadcode.go
@@ -110,16 +110,15 @@ func liveValues(f *Func, reachable []bool) (live []bool, liveOrderStmts []*Value) {
 			}
 		}
 		for _, v := range b.Values {
-			if (opcodeTable[v.Op].call || opcodeTable[v.Op].hasSideEffects) && !live[v.ID] {
+			if (opcodeTable[v.Op].call || opcodeTable[v.Op].hasSideEffects || opcodeTable[v.Op].nilCheck) && !live[v.ID] {
 				live[v.ID] = true
 				q = append(q, v)
 				if v.Pos.IsStmt() != src.PosNotStmt {
 					liveOrderStmts = append(liveOrderStmts, v)
 				}
 			}
-			if v.Type.IsVoid() && !live[v.ID] {
-				// The only Void ops are nil checks and inline marks. We must keep these.
-				if v.Op == OpInlMark && !liveInlIdx[int(v.AuxInt)] {
+			if v.Op == OpInlMark {
+				if !liveInlIdx[int(v.AuxInt)] {
 					// We don't need marks for bodies that
 					// have been completely optimized away.
 					// TODO: save marks only for bodies which
diff --git a/src/cmd/compile/internal/ssa/deadstore.go b/src/cmd/compile/internal/ssa/deadstore.go
index 7656e45cb9..cb3427103c 100644
--- a/src/cmd/compile/internal/ssa/deadstore.go
+++ b/src/cmd/compile/internal/ssa/deadstore.go
@@ -249,7 +249,7 @@ func elimDeadAutosGeneric(f *Func) {
 		}
 		if v.Uses == 0 && v.Op != OpNilCheck && !v.Op.IsCall() && !v.Op.HasSideEffects() || len(args) == 0 {
-			// Nil check has no use, but we need to keep it.
+			// We need to keep nil checks even if they have no use.
 			// Also keep calls and values that have side effects.
 			return
 		}
diff --git a/src/cmd/compile/internal/ssa/fuse.go b/src/cmd/compile/internal/ssa/fuse.go
index 6d3fb70780..68defde7b4 100644
--- a/src/cmd/compile/internal/ssa/fuse.go
+++ b/src/cmd/compile/internal/ssa/fuse.go
@@ -169,7 +169,7 @@ func fuseBlockIf(b *Block) bool {

 // There may be false positives.
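+// A value with no uses does not make the block empty if it is a nil
+// check: the check must be kept for its panicking side effect.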
 func isEmpty(b *Block) bool {
 	for _, v := range b.Values {
-		if v.Uses > 0 || v.Op.IsCall() || v.Op.HasSideEffects() || v.Type.IsVoid() {
+		if v.Uses > 0 || v.Op.IsCall() || v.Op.HasSideEffects() || v.Type.IsVoid() || opcodeTable[v.Op].nilCheck {
 			return false
 		}
 	}
diff --git a/src/cmd/compile/internal/ssa/fuse_test.go b/src/cmd/compile/internal/ssa/fuse_test.go
index fa7921a18f..2f89938d1d 100644
--- a/src/cmd/compile/internal/ssa/fuse_test.go
+++ b/src/cmd/compile/internal/ssa/fuse_test.go
@@ -254,7 +254,7 @@ func TestFuseSideEffects(t *testing.T) {
 			Valu("p", OpArg, c.config.Types.IntPtr, 0, nil),
 			If("c1", "z0", "exit")),
 		Bloc("z0",
-			Valu("nilcheck", OpNilCheck, types.TypeVoid, 0, nil, "p", "mem"),
+			Valu("nilcheck", OpNilCheck, c.config.Types.IntPtr, 0, nil, "p", "mem"),
 			Goto("exit")),
 		Bloc("exit",
 			Exit("mem"),
diff --git a/src/cmd/compile/internal/ssa/nilcheck.go b/src/cmd/compile/internal/ssa/nilcheck.go
index 4f797a473f..c69cd8c32e 100644
--- a/src/cmd/compile/internal/ssa/nilcheck.go
+++ b/src/cmd/compile/internal/ssa/nilcheck.go
@@ -38,11 +38,14 @@ func nilcheckelim(f *Func) {
 	work := make([]bp, 0, 256)
 	work = append(work, bp{block: f.Entry})

-	// map from value ID to bool indicating if value is known to be non-nil
-	// in the current dominator path being walked. This slice is updated by
+	// map from value ID to known non-nil version of that value ID
+	// (in the current dominator path being walked). This slice is updated by
 	// walkStates to maintain the known non-nil values.
-	nonNilValues := f.Cache.allocBoolSlice(f.NumValues())
-	defer f.Cache.freeBoolSlice(nonNilValues)
+	// If there is extrinsic information about non-nil-ness, this map
+	// points a value to itself. If a value is known non-nil because we
+	// already did a nil check on it, it points to the nil check operation.
+	nonNilValues := f.Cache.allocValueSlice(f.NumValues())
+	defer f.Cache.freeValueSlice(nonNilValues)

 	// make an initial pass identifying any non-nil values
 	for _, b := range f.Blocks {
@@ -54,7 +57,7 @@
 			// We assume that SlicePtr is non-nil because we do a bounds check
 			// before the slice access (and all cap>0 slices have a non-nil ptr). See #30366.
 			if v.Op == OpAddr || v.Op == OpLocalAddr || v.Op == OpAddPtr || v.Op == OpOffPtr || v.Op == OpAdd32 || v.Op == OpAdd64 || v.Op == OpSub32 || v.Op == OpSub64 || v.Op == OpSlicePtr {
-				nonNilValues[v.ID] = true
+				nonNilValues[v.ID] = v
 			}
 		}
 	}
@@ -68,16 +71,16 @@
 			if v.Op == OpPhi {
 				argsNonNil := true
 				for _, a := range v.Args {
-					if !nonNilValues[a.ID] {
+					if nonNilValues[a.ID] == nil {
 						argsNonNil = false
 						break
 					}
 				}
 				if argsNonNil {
-					if !nonNilValues[v.ID] {
+					if nonNilValues[v.ID] == nil {
 						changed = true
 					}
-					nonNilValues[v.ID] = true
+					nonNilValues[v.ID] = v
 				}
 			}
 		}
@@ -103,8 +106,8 @@
 			if len(b.Preds) == 1 {
 				p := b.Preds[0].b
 				if p.Kind == BlockIf && p.Controls[0].Op == OpIsNonNil && p.Succs[0].b == b {
-					if ptr := p.Controls[0].Args[0]; !nonNilValues[ptr.ID] {
-						nonNilValues[ptr.ID] = true
+					if ptr := p.Controls[0].Args[0]; nonNilValues[ptr.ID] == nil {
+						nonNilValues[ptr.ID] = ptr
 						work = append(work, bp{op: ClearPtr, ptr: ptr})
 					}
 				}
@@ -117,14 +120,11 @@
 			pendingLines.clear()

 			// Next, process values in the block.
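+			// Redundant nil checks below are rewritten into copies of
+			// the dominating check's result rather than deleted, so
+			// b.Values no longer needs compacting here.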
-			i := 0
 			for _, v := range b.Values {
-				b.Values[i] = v
-				i++
 				switch v.Op {
 				case OpIsNonNil:
 					ptr := v.Args[0]
-					if nonNilValues[ptr.ID] {
+					if nonNilValues[ptr.ID] != nil {
 						if v.Pos.IsStmt() == src.PosIsStmt { // Boolean true is a terrible statement boundary.
 							pendingLines.add(v.Pos)
 							v.Pos = v.Pos.WithNotStmt()
 						}
@@ -135,7 +135,7 @@
 					}
 				case OpNilCheck:
 					ptr := v.Args[0]
-					if nonNilValues[ptr.ID] {
+					if nilCheck := nonNilValues[ptr.ID]; nilCheck != nil {
 						// This is a redundant implicit nil check.
 						// Logging in the style of the former compiler -- and omit line 1,
 						// which is usually in generated code.
@@ -145,14 +145,13 @@
 						if v.Pos.IsStmt() == src.PosIsStmt { // About to lose a statement boundary
 							pendingLines.add(v.Pos)
 						}
-						v.reset(OpUnknown)
-						f.freeValue(v)
-						i--
+						v.Op = OpCopy
+						v.SetArgs1(nilCheck)
 						continue
 					}
 					// Record the fact that we know ptr is non nil, and remember to
 					// undo that information when this dominator subtree is done.
-					nonNilValues[ptr.ID] = true
+					nonNilValues[ptr.ID] = v
 					work = append(work, bp{op: ClearPtr, ptr: ptr})
 					fallthrough // a non-eliminated nil check might be a good place for a statement boundary.
 				default:
@@ -163,7 +162,7 @@
 				}
 			}
 			// This reduces the lost statement count in "go" by 5 (out of 500 total).
-			for j := 0; j < i; j++ { // is this an ordering problem?
+			for j := range b.Values { // is this an ordering problem?
 				v := b.Values[j]
 				if v.Pos.IsStmt() != src.PosNotStmt && !isPoorStatementOp(v.Op) && pendingLines.contains(v.Pos) {
 					v.Pos = v.Pos.WithIsStmt()
@@ -174,7 +173,6 @@
 				b.Pos = b.Pos.WithIsStmt()
 				pendingLines.remove(b.Pos)
 			}
-			b.truncateValues(i)

 			// Add all dominated blocks to the work list.
 			for w := sdom[node.block.ID].child; w != nil; w = sdom[w.ID].sibling {
@@ -182,7 +180,7 @@
 			}

 		case ClearPtr:
-			nonNilValues[node.ptr.ID] = false
+			nonNilValues[node.ptr.ID] = nil
 			continue
 		}
 	}
diff --git a/src/cmd/compile/internal/ssa/opGen.go b/src/cmd/compile/internal/ssa/opGen.go
index 6b2320f44c..011bf94f72 100644
--- a/src/cmd/compile/internal/ssa/opGen.go
+++ b/src/cmd/compile/internal/ssa/opGen.go
@@ -39823,9 +39823,10 @@ var opcodeTable = [...]opInfo{
 		generic: true,
 	},
 	{
-		name:    "NilCheck",
-		argLen:  2,
-		generic: true,
+		name:     "NilCheck",
+		argLen:   2,
+		nilCheck: true,
+		generic:  true,
 	},
 	{
 		name: "GetG",
diff --git a/src/cmd/compile/internal/ssa/rewrite.go b/src/cmd/compile/internal/ssa/rewrite.go
index 09f588068e..c94a4202ec 100644
--- a/src/cmd/compile/internal/ssa/rewrite.go
+++ b/src/cmd/compile/internal/ssa/rewrite.go
@@ -859,6 +859,9 @@ func disjoint(p1 *Value, n1 int64, p2 *Value, n2 int64) bool {
 			offset += base.AuxInt
 			base = base.Args[0]
 		}
+		if opcodeTable[base.Op].nilCheck {
+			base = base.Args[0]
+		}
 		return base, offset
 	}
 	p1, off1 := baseAndOffset(p1)
diff --git a/src/cmd/compile/internal/ssa/rewritegeneric.go b/src/cmd/compile/internal/ssa/rewritegeneric.go
index 6dc87f411a..a018ca04b6 100644
--- a/src/cmd/compile/internal/ssa/rewritegeneric.go
+++ b/src/cmd/compile/internal/ssa/rewritegeneric.go
@@ -18968,79 +18968,84 @@ func rewriteValuegeneric_OpNilCheck(v *Value) bool {
 	v_0 := v.Args[0]
 	b := v.Block
 	fe := b.Func.fe
-	// match: (NilCheck (GetG mem) mem)
-	// result: mem
+	// match: (NilCheck ptr:(GetG mem) mem)
+	// result: ptr
 	for {
-		if v_0.Op != OpGetG {
+		ptr := v_0
+		if ptr.Op != OpGetG {
 			break
 		}
-		mem := v_0.Args[0]
+		mem := ptr.Args[0]
 		if mem != v_1 {
 			break
 		}
-		v.copyOf(mem)
+		v.copyOf(ptr)
 		return true
 	}
-	// match: (NilCheck (SelectN [0] call:(StaticLECall _ _)) _)
+	// match: (NilCheck ptr:(SelectN [0] call:(StaticLECall _ _)) _)
 	// cond: isSameCall(call.Aux, "runtime.newobject") && warnRule(fe.Debug_checknil(), v, "removed nil check")
-	// result: (Invalid)
+	// result: ptr
 	for {
-		if v_0.Op != OpSelectN || auxIntToInt64(v_0.AuxInt) != 0 {
+		ptr := v_0
+		if ptr.Op != OpSelectN || auxIntToInt64(ptr.AuxInt) != 0 {
 			break
 		}
-		call := v_0.Args[0]
+		call := ptr.Args[0]
 		if call.Op != OpStaticLECall || len(call.Args) != 2 || !(isSameCall(call.Aux, "runtime.newobject") && warnRule(fe.Debug_checknil(), v, "removed nil check")) {
 			break
 		}
-		v.reset(OpInvalid)
+		v.copyOf(ptr)
 		return true
 	}
-	// match: (NilCheck (OffPtr (SelectN [0] call:(StaticLECall _ _))) _)
+	// match: (NilCheck ptr:(OffPtr (SelectN [0] call:(StaticLECall _ _))) _)
 	// cond: isSameCall(call.Aux, "runtime.newobject") && warnRule(fe.Debug_checknil(), v, "removed nil check")
-	// result: (Invalid)
+	// result: ptr
 	for {
-		if v_0.Op != OpOffPtr {
+		ptr := v_0
+		if ptr.Op != OpOffPtr {
 			break
 		}
-		v_0_0 := v_0.Args[0]
-		if v_0_0.Op != OpSelectN || auxIntToInt64(v_0_0.AuxInt) != 0 {
+		ptr_0 := ptr.Args[0]
+		if ptr_0.Op != OpSelectN || auxIntToInt64(ptr_0.AuxInt) != 0 {
 			break
 		}
-		call := v_0_0.Args[0]
+		call := ptr_0.Args[0]
 		if call.Op != OpStaticLECall || len(call.Args) != 2 || !(isSameCall(call.Aux, "runtime.newobject") && warnRule(fe.Debug_checknil(), v, "removed nil check")) {
 			break
 		}
-		v.reset(OpInvalid)
+		v.copyOf(ptr)
 		return true
 	}
-	// match: (NilCheck (Addr {_} (SB)) _)
-	// result: (Invalid)
+	// match: (NilCheck ptr:(Addr {_} (SB)) _)
+	// result: ptr
 	for {
-		if v_0.Op != OpAddr {
+		ptr := v_0
+		if ptr.Op != OpAddr {
 			break
 		}
-		v_0_0 := v_0.Args[0]
-		if v_0_0.Op != OpSB {
+		ptr_0 := ptr.Args[0]
+		if ptr_0.Op != OpSB {
 			break
 		}
-		v.reset(OpInvalid)
+		v.copyOf(ptr)
 		return true
 	}
-	// match: (NilCheck (Convert (Addr {_} (SB)) _) _)
-	// result: (Invalid)
+	// match: (NilCheck ptr:(Convert (Addr {_} (SB)) _) _)
+	// result: ptr
 	for {
-		if v_0.Op != OpConvert {
+		ptr := v_0
+		if ptr.Op != OpConvert {
 			break
 		}
-		v_0_0 := v_0.Args[0]
-		if v_0_0.Op != OpAddr {
+		ptr_0 := ptr.Args[0]
+		if ptr_0.Op != OpAddr {
 			break
 		}
-		v_0_0_0 := v_0_0.Args[0]
-		if v_0_0_0.Op != OpSB {
+		ptr_0_0 := ptr_0.Args[0]
+		if ptr_0_0.Op != OpSB {
 			break
 		}
-		v.reset(OpInvalid)
+		v.copyOf(ptr)
 		return true
 	}
 	return false
diff --git a/src/cmd/compile/internal/ssa/schedule.go b/src/cmd/compile/internal/ssa/schedule.go
index 13efb6ee70..fb38f40d63 100644
--- a/src/cmd/compile/internal/ssa/schedule.go
+++ b/src/cmd/compile/internal/ssa/schedule.go
@@ -307,14 +307,21 @@ func schedule(f *Func) {
 	}

 	// Remove SPanchored now that we've scheduled.
+	// Also unlink nil checks now that ordering is assured
+	// between the nil check and the uses of the nil-checked pointer.
 	for _, b := range f.Blocks {
 		for _, v := range b.Values {
 			for i, a := range v.Args {
-				if a.Op == OpSPanchored {
+				if a.Op == OpSPanchored || opcodeTable[a.Op].nilCheck {
 					v.SetArg(i, a.Args[0])
 				}
 			}
 		}
+		for i, c := range b.ControlValues() {
+			if c.Op == OpSPanchored || opcodeTable[c.Op].nilCheck {
+				b.ReplaceControl(i, c.Args[0])
+			}
+		}
 	}
 	for _, b := range f.Blocks {
 		i := 0
@@ -327,6 +334,15 @@
 				v.resetArgs()
 				f.freeValue(v)
 			} else {
+				if opcodeTable[v.Op].nilCheck {
+					if v.Uses != 0 {
+						base.Fatalf("nilcheck still has %d uses", v.Uses)
+					}
+					// We can't delete the nil check, but we mark
+					// it as having void type so regalloc won't
+					// try to allocate a register for it.
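+					// (The check itself stays, since it is what raises
+					// the panic when the pointer is nil.)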
+					v.Type = types.TypeVoid
+				}
 				b.Values[i] = v
 				i++
 			}
 		}
diff --git a/src/cmd/compile/internal/ssa/value.go b/src/cmd/compile/internal/ssa/value.go
index 1b33b1a1bb..4eaab40354 100644
--- a/src/cmd/compile/internal/ssa/value.go
+++ b/src/cmd/compile/internal/ssa/value.go
@@ -552,7 +552,11 @@ func (v *Value) LackingPos() bool {
 // if its use count drops to 0.
 func (v *Value) removeable() bool {
 	if v.Type.IsVoid() {
-		// Void ops, like nil pointer checks, must stay.
+		// Void ops (inline marks) must stay.
+		return false
+	}
+	if opcodeTable[v.Op].nilCheck {
+		// Nil pointer checks must stay.
 		return false
 	}
 	if v.Type.IsMemory() {
diff --git a/src/cmd/compile/internal/ssagen/ssa.go b/src/cmd/compile/internal/ssagen/ssa.go
index 5d5c79e581..24b82cffcd 100644
--- a/src/cmd/compile/internal/ssagen/ssa.go
+++ b/src/cmd/compile/internal/ssagen/ssa.go
@@ -2107,7 +2107,8 @@ func (s *state) stmt(n ir.Node) {
 	case ir.OCHECKNIL:
 		n := n.(*ir.UnaryExpr)
 		p := s.expr(n.X)
-		s.nilCheck(p)
+		_ = s.nilCheck(p)
+		// TODO: check that throwing away the nilcheck result is ok.

 	case ir.OINLMARK:
 		n := n.(*ir.InlineMarkStmt)
@@ -5729,18 +5730,20 @@ func (s *state) exprPtr(n ir.Node, bounded bool, lineno src.XPos) *ssa.Value {
 		}
 		return p
 	}
-	s.nilCheck(p)
+	p = s.nilCheck(p)
 	return p
 }

 // nilCheck generates nil pointer checking code.
 // Used only for automatically inserted nil checks,
 // not for user code like 'x != nil'.
-func (s *state) nilCheck(ptr *ssa.Value) {
+// Returns a "definitely not nil" copy of ptr to ensure proper ordering
+// of the uses of the post-nilcheck pointer.
+func (s *state) nilCheck(ptr *ssa.Value) *ssa.Value {
 	if base.Debug.DisableNil != 0 || s.curfn.NilCheckDisabled() {
-		return
+		return ptr
 	}
-	s.newValue2(ssa.OpNilCheck, types.TypeVoid, ptr, s.mem())
+	return s.newValue2(ssa.OpNilCheck, ptr.Type, ptr, s.mem())
 }

 // boundsCheck generates bounds checking code. Checks if 0 <= idx <[=] len, branches to exit if not.
@@ -6092,8 +6095,8 @@
 		if !t.Elem().IsArray() {
 			s.Fatalf("bad ptr to array in slice %v\n", t)
 		}
-		s.nilCheck(v)
-		ptr = s.newValue1(ssa.OpCopy, types.NewPtr(t.Elem().Elem()), v)
+		nv := s.nilCheck(v)
+		ptr = s.newValue1(ssa.OpCopy, types.NewPtr(t.Elem().Elem()), nv)
 		len = s.constInt(types.Types[types.TINT], t.Elem().NumElem())
 		cap = len
 	default:
diff --git a/test/fixedbugs/issue63657.go b/test/fixedbugs/issue63657.go
new file mode 100644
index 0000000000..e32a4a34fb
--- /dev/null
+++ b/test/fixedbugs/issue63657.go
@@ -0,0 +1,48 @@
+// run
+
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Make sure address calculations don't float up before
+// the corresponding nil check.
+
+package main
+
+type T struct {
+	a, b int
+}
+
+//go:noinline
+func f(x *T, p *bool, n int) {
+	*p = n != 0
+	useStack(1000)
+	g(&x.b)
+}
+
+//go:noinline
+func g(p *int) {
+}
+
+func useStack(n int) {
+	if n == 0 {
+		return
+	}
+	useStack(n - 1)
+}
+
+func main() {
+	mustPanic(func() {
+		var b bool
+		f(nil, &b, 3)
+	})
+}
+
+func mustPanic(f func()) {
+	defer func() {
+		if recover() == nil {
+			panic("expected panic, got nil")
+		}
+	}()
+	f()
+}
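+
+// SSA-level effect of the fix, sketched with made-up value numbers
+// (not actual compiler output): &x.b is an OffPtr of x.
+//
+//	before: v2 = NilCheck <void> v1 mem   // no result; nothing uses it
+//	        v3 = OffPtr <*int> [8] v1     // depends only on v1, may float
+//	                                      // above the check
+//	after:  v2 = NilCheck <*T> v1 mem     // returns the checked pointer
+//	        v3 = OffPtr <*int> [8] v2     // data-dependent on the check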