diff --git a/internal/jsonrpc2/handler.go b/internal/jsonrpc2/handler.go index ad7fb9c4a7..860ea1f52a 100644 --- a/internal/jsonrpc2/handler.go +++ b/internal/jsonrpc2/handler.go @@ -74,7 +74,7 @@ func MethodNotFound(ctx context.Context, r *Request) error { func MustReply(handler Handler) Handler { return func(ctx context.Context, req *Request) error { err := handler(ctx, req) - if req.state < requestReplied { + if req.done != nil { panic(fmt.Errorf("request %q was never replied to", req.Method)) } return err diff --git a/internal/jsonrpc2/jsonrpc2.go b/internal/jsonrpc2/jsonrpc2.go index 441fe3f46c..c67a72ac96 100644 --- a/internal/jsonrpc2/jsonrpc2.go +++ b/internal/jsonrpc2/jsonrpc2.go @@ -38,22 +38,13 @@ type Conn struct { handling map[ID]*Request } -type requestState int - -const ( - requestWaiting = requestState(iota) - requestSerial - requestParallel - requestReplied - requestDone -) - // Request is sent to a server to represent a Call or Notify operaton. type Request struct { - conn *Conn - cancel context.CancelFunc - state requestState - nextRequest chan struct{} + conn *Conn + cancel context.CancelFunc + // done holds set of callbacks added by OnReply, and is set back to nil if + // Reply has been called. + done []func() // The Wire values of the request. WireRequest @@ -222,16 +213,17 @@ func (r *Request) IsNotify() bool { // This will mark the request as done, triggering any done // handlers func (r *Request) Reply(ctx context.Context, result interface{}, err error) error { - if r.state >= requestReplied { + if r.done == nil { return fmt.Errorf("reply invoked more than once") } - if r.state < requestParallel { - r.state = requestParallel - close(r.nextRequest) - } - r.state = requestReplied - recordStatus(ctx, nil) + defer func() { + recordStatus(ctx, err) + for i := len(r.done); i > 0; i-- { + r.done[i-1]() + } + r.done = nil + }() if r.IsNotify() { return nil @@ -280,6 +272,17 @@ func setHandling(r *Request, active bool) { } } +// OnReply adds a done callback to the request. +// All added callbacks are invoked during the one required call to Reply, and +// then dropped. +// It is an error to call this after Reply. +// This call is not safe for concurrent use, but should only be invoked by +// handlers and in general only one handler should be working on a request +// at any time. +func (r *Request) OnReply(do func()) { + r.done = append(r.done, do) +} + // combined has all the fields of both Request and Response. // We can decode this and then work out which it is. type combined struct { @@ -322,12 +325,12 @@ func (c *Conn) Run(runCtx context.Context, handler Handler) error { case msg.Method != "": // If method is set it must be a request. reqCtx, cancelReq := context.WithCancel(runCtx) - thisRequest := nextRequest + waitForPrevious := nextRequest nextRequest = make(chan struct{}) + unlockNext := nextRequest req := &Request{ - conn: c, - cancel: cancelReq, - nextRequest: nextRequest, + conn: c, + cancel: cancelReq, WireRequest: WireRequest{ VersionTag: msg.VersionTag, Method: msg.Method, @@ -335,6 +338,9 @@ func (c *Conn) Run(runCtx context.Context, handler Handler) error { ID: msg.ID, }, } + req.OnReply(func() { + close(unlockNext) + }) if c.LegacyHooks != nil { reqCtx = c.LegacyHooks.Request(reqCtx, c, Receive, &req.WireRequest) } @@ -349,9 +355,8 @@ func (c *Conn) Run(runCtx context.Context, handler Handler) error { setHandling(req, true) _, queueDone := event.StartSpan(reqCtx, "queued") go func() { - <-thisRequest + <-waitForPrevious queueDone() - req.state = requestSerial defer func() { setHandling(req, false) done()