diff --git a/src/go/build/deps_test.go b/src/go/build/deps_test.go index 503de8f927..84b0096c77 100644 --- a/src/go/build/deps_test.go +++ b/src/go/build/deps_test.go @@ -58,7 +58,6 @@ var depsRules = ` internal/nettrace, internal/platform, internal/profilerecord, - internal/runtime/exithook, internal/trace/traceviewer/format, log/internal, math/bits, @@ -79,7 +78,6 @@ var depsRules = ` internal/goexperiment, internal/goos, internal/profilerecord, - internal/runtime/exithook, math/bits < internal/bytealg < internal/stringslite @@ -88,6 +86,7 @@ var depsRules = ` < runtime/internal/sys < internal/runtime/syscall < internal/runtime/atomic + < internal/runtime/exithook < runtime/internal/math < runtime < sync/atomic diff --git a/src/internal/runtime/exithook/hooks.go b/src/internal/runtime/exithook/hooks.go index 931154c45d..eb8aa1ce0a 100644 --- a/src/internal/runtime/exithook/hooks.go +++ b/src/internal/runtime/exithook/hooks.go @@ -13,6 +13,11 @@ // restricted dialects used for the trickier parts of the runtime. package exithook +import ( + "internal/runtime/atomic" + _ "unsafe" // for linkname +) + // A Hook is a function to be run at program termination // (when someone invokes os.Exit, or when main.main returns). // Hooks are run in reverse order of registration: @@ -23,40 +28,56 @@ type Hook struct { } var ( + locked atomic.Int32 + runGoid atomic.Uint64 hooks []Hook running bool + + // runtime sets these for us + Gosched func() + Goid func() uint64 + Throw func(string) ) // Add adds a new exit hook. func Add(h Hook) { + for !locked.CompareAndSwap(0, 1) { + Gosched() + } hooks = append(hooks, h) + locked.Store(0) } // Run runs the exit hooks. -// It returns an error if Run is already running or -// if one of the hooks panics. -func Run(code int) (err error) { - if running { - return exitError("exit hook invoked exit") +// +// If an exit hook panics, Run will throw with the panic on the stack. +// If an exit hook invokes exit in the same goroutine, the goroutine will throw. +// If an exit hook invokes exit in another goroutine, that exit will block. +func Run(code int) { + for !locked.CompareAndSwap(0, 1) { + if Goid() == runGoid.Load() { + Throw("exit hook invoked exit") + } + Gosched() } - running = true + defer locked.Store(0) + runGoid.Store(Goid()) + defer runGoid.Store(0) defer func() { - if x := recover(); x != nil { - err = exitError("exit hook invoked panic") + if e := recover(); e != nil { + Throw("exit hook invoked panic") } }() - local := hooks - hooks = nil - for i := len(local) - 1; i >= 0; i-- { - h := local[i] - if code == 0 || h.RunOnFailure { - h.F() + for len(hooks) > 0 { + h := hooks[len(hooks)-1] + hooks = hooks[:len(hooks)-1] + if code != 0 && !h.RunOnFailure { + continue } + h.F() } - running = false - return nil } type exitError string diff --git a/src/runtime/ehooks_test.go b/src/runtime/ehooks_test.go index 2265256a0b..4beb20b0be 100644 --- a/src/runtime/ehooks_test.go +++ b/src/runtime/ehooks_test.go @@ -28,32 +28,36 @@ func TestExitHooks(t *testing.T) { scenarios := []struct { mode string expected string - musthave string + musthave []string }{ { mode: "simple", expected: "bar foo", - musthave: "", }, { mode: "goodexit", expected: "orange apple", - musthave: "", }, { mode: "badexit", expected: "blub blix", - musthave: "", }, { - mode: "panics", - expected: "", - musthave: "fatal error: exit hook invoked panic", + mode: "panics", + musthave: []string{ + "fatal error: exit hook invoked panic", + "main.testPanics", + }, }, { - mode: "callsexit", + mode: "callsexit", + musthave: []string{ + "fatal error: exit hook invoked exit", + }, + }, + { + mode: "exit2", expected: "", - musthave: "fatal error: exit hook invoked exit", }, } @@ -71,20 +75,18 @@ func TestExitHooks(t *testing.T) { out, _ := cmd.CombinedOutput() outs := strings.ReplaceAll(string(out), "\n", " ") outs = strings.TrimSpace(outs) - if s.expected != "" { - if s.expected != outs { - t.Logf("raw output: %q", outs) - t.Errorf("failed%s mode %s: wanted %q got %q", bt, - s.mode, s.expected, outs) + if s.expected != "" && s.expected != outs { + t.Fatalf("failed%s mode %s: wanted %q\noutput:\n%s", bt, + s.mode, s.expected, outs) + } + for _, need := range s.musthave { + if !strings.Contains(outs, need) { + t.Fatalf("failed mode %s: output does not contain %q\noutput:\n%s", + s.mode, need, outs) } - } else if s.musthave != "" { - if !strings.Contains(outs, s.musthave) { - t.Logf("raw output: %q", outs) - t.Errorf("failed mode %s: output does not contain %q", - s.mode, s.musthave) - } - } else { - panic("badly written scenario") + } + if s.expected == "" && s.musthave == nil && outs != "" { + t.Errorf("failed mode %s: wanted no output\noutput:\n%s", s.mode, outs) } } } diff --git a/src/runtime/proc.go b/src/runtime/proc.go index c5bf537a75..17b2e4d9c2 100644 --- a/src/runtime/proc.go +++ b/src/runtime/proc.go @@ -310,10 +310,14 @@ func os_beforeExit(exitCode int) { } } +func init() { + exithook.Gosched = Gosched + exithook.Goid = func() uint64 { return getg().goid } + exithook.Throw = throw +} + func runExitHooks(code int) { - if err := exithook.Run(code); err != nil { - throw(err.Error()) - } + exithook.Run(code) } // start forcegc helper goroutine diff --git a/src/runtime/testdata/testexithooks/testexithooks.go b/src/runtime/testdata/testexithooks/testexithooks.go index 151b5dc62b..d734aacb2d 100644 --- a/src/runtime/testdata/testexithooks/testexithooks.go +++ b/src/runtime/testdata/testexithooks/testexithooks.go @@ -6,8 +6,9 @@ package main import ( "flag" - "os" "internal/runtime/exithook" + "os" + "time" _ "unsafe" ) @@ -26,6 +27,8 @@ func main() { testPanics() case "callsexit": testHookCallsExit() + case "exit2": + testExit2() default: panic("unknown mode") } @@ -81,3 +84,12 @@ func testHookCallsExit() { exithook.Add(exithook.Hook{F: f3, RunOnFailure: true}) os.Exit(1) } + +func testExit2() { + f1 := func() { time.Sleep(100 * time.Millisecond) } + exithook.Add(exithook.Hook{F: f1}) + for range 10 { + go os.Exit(0) + } + os.Exit(0) +}