diff --git a/internal/jsonrpc2_v2/conn.go b/internal/jsonrpc2_v2/conn.go
index 7c48e2ec61..74f1de1535 100644
--- a/internal/jsonrpc2_v2/conn.go
+++ b/internal/jsonrpc2_v2/conn.go
@@ -82,10 +82,12 @@ type Connection struct {
 // inFlightState records the state of the incoming and outgoing calls on a
 // Connection.
 type inFlightState struct {
-	closing bool // disallow enqueuing further requests, and close the Closer when transitioning to idle
-	readErr error
+	closing  bool // disallow enqueuing further requests, and close the Closer when transitioning to idle
+	readErr  error
+	writeErr error
 
-	outgoing map[ID]*AsyncCall // calls only
+	outgoingCalls         map[ID]*AsyncCall // calls only
+	outgoingNotifications int               // # of notifications awaiting "write"
 
 	// incoming stores the total number of incoming calls and notifications
 	// that have not yet written or processed a result.
@@ -104,7 +106,7 @@ type inFlightState struct {
 
 // updateInFlight locks the state of the connection's in-flight requests, allows
 // f to mutate that state, and closes the connection if it is idle and either
-// is closing or has a read error.
+// is closing or has a read or write error.
 func (c *Connection) updateInFlight(f func(*inFlightState)) {
 	c.stateMu.Lock()
 	defer c.stateMu.Unlock()
@@ -113,8 +115,8 @@ func (c *Connection) updateInFlight(f func(*inFlightState)) {
 
 	f(s)
 
-	idle := s.incoming == 0 && len(s.outgoing) == 0 && !s.handlerRunning
-	if idle && (s.closing || s.readErr != nil) && !s.closed {
+	idle := len(s.outgoingCalls) == 0 && s.outgoingNotifications == 0 && s.incoming == 0 && !s.handlerRunning
+	if idle && (s.closing || s.readErr != nil || s.writeErr != nil) && !s.closed {
 		c.closeErr <- c.closer.Close()
 		if c.onDone != nil {
 			c.onDone()
@@ -181,20 +183,42 @@ func newConnection(bindCtx context.Context, rwc io.ReadWriteCloser, binder Binde
 // Notify invokes the target method but does not wait for a response.
 // The params will be marshaled to JSON before sending over the wire, and will
 // be handed to the method invoked.
-func (c *Connection) Notify(ctx context.Context, method string, params interface{}) error {
-	notify, err := NewNotification(method, params)
-	if err != nil {
-		return fmt.Errorf("marshaling notify parameters: %v", err)
-	}
+func (c *Connection) Notify(ctx context.Context, method string, params interface{}) (err error) {
 	ctx, done := event.Start(ctx, method,
 		tag.Method.Of(method),
 		tag.RPCDirection.Of(tag.Outbound),
 	)
+	attempted := false
+
+	defer func() {
+		labelStatus(ctx, err)
+		done()
+		if attempted {
+			c.updateInFlight(func(s *inFlightState) {
+				s.outgoingNotifications--
+			})
+		}
+	}()
+
+	c.updateInFlight(func(s *inFlightState) {
+		if s.writeErr != nil {
+			err = fmt.Errorf("%w: %v", ErrClientClosing, s.writeErr)
+			return
+		}
+		s.outgoingNotifications++
+		attempted = true
+	})
+	if err != nil {
+		return err
+	}
+
+	notify, err := NewNotification(method, params)
+	if err != nil {
+		return fmt.Errorf("marshaling notify parameters: %v", err)
+	}
+
 	event.Metric(ctx, tag.Started.Of(1))
-	err = c.write(ctx, notify)
-	labelStatus(ctx, err)
-	done()
-	return err
+	return c.write(ctx, notify)
 }
 
 // Call invokes the target method and returns an object that can be used to await the response.
@@ -239,10 +263,18 @@ func (c *Connection) Call(ctx context.Context, method string, params interface{}
 			err = fmt.Errorf("%w: %v", ErrClientClosing, s.readErr)
 			return
 		}
-		if s.outgoing == nil {
-			s.outgoing = make(map[ID]*AsyncCall)
+		if s.writeErr != nil {
+			// Don't start the call if the write end has failed, either.
+			// We have reason to believe that the write would not succeed,
+			// and if we avoid adding in-flight calls then eventually
+			// the connection will go idle and be closed.
+			err = fmt.Errorf("%w: %v", ErrClientClosing, s.writeErr)
+			return
 		}
-		s.outgoing[ac.id] = ac
+		if s.outgoingCalls == nil {
+			s.outgoingCalls = make(map[ID]*AsyncCall)
+		}
+		s.outgoingCalls[ac.id] = ac
 	})
 	if err != nil {
 		ac.retire(&Response{ID: id, Error: err})
@@ -254,8 +286,8 @@ func (c *Connection) Call(ctx context.Context, method string, params interface{}
 		// Sending failed. We will never get a response, so deliver a fake one if it
 		// wasn't already retired by the connection breaking.
 		c.updateInFlight(func(s *inFlightState) {
-			if s.outgoing[ac.id] == ac {
-				delete(s.outgoing, ac.id)
+			if s.outgoingCalls[ac.id] == ac {
+				delete(s.outgoingCalls, ac.id)
 				ac.retire(&Response{ID: id, Error: err})
 			} else {
 				// ac was already retired by the readIncoming goroutine:
@@ -405,8 +437,8 @@ func (c *Connection) readIncoming(ctx context.Context, reader Reader, preempter
 
 		case *Response:
 			c.updateInFlight(func(s *inFlightState) {
-				if ac, ok := s.outgoing[msg.ID]; ok {
-					delete(s.outgoing, msg.ID)
+				if ac, ok := s.outgoingCalls[msg.ID]; ok {
+					delete(s.outgoingCalls, msg.ID)
 					ac.retire(msg)
 				} else {
 					// TODO: How should we report unexpected responses?
@@ -423,10 +455,10 @@ func (c *Connection) readIncoming(ctx context.Context, reader Reader, preempter
 
 		// Retire any outgoing requests that were still in flight: with the Reader no
 		// longer being processed, they necessarily cannot receive a response.
-		for id, ac := range s.outgoing {
+		for id, ac := range s.outgoingCalls {
 			ac.retire(&Response{ID: id, Error: err})
 		}
-		s.outgoing = nil
+		s.outgoingCalls = nil
 	})
 }
 
@@ -482,6 +514,14 @@ func (c *Connection) acceptRequest(ctx context.Context, msg *Request, msgBytes i
 				err = ErrServerClosing
 				return
 			}
+
+			if s.writeErr != nil {
+				// The write side of the connection appears to be broken,
+				// so we won't be able to write a response to this request.
+				// Avoid unnecessary work to compute it.
+				err = fmt.Errorf("%w: %v", ErrServerClosing, s.writeErr)
+				return
+			}
 		}
 	})
 	if err != nil {
@@ -557,12 +597,22 @@ func (c *Connection) handleAsync() {
 			return
 		}
 
-		var result interface{}
-		err := req.ctx.Err()
-		if err == nil {
-			// Only deliver to the Handler if not already cancelled.
-			result, err = c.handler.Handle(req.ctx, req.Request)
+		// Only deliver to the Handler if not already canceled.
+		if err := req.ctx.Err(); err != nil {
+			c.updateInFlight(func(s *inFlightState) {
+				if s.writeErr != nil {
+					// Assume that req.ctx was canceled due to s.writeErr.
+					// TODO(#51365): use a Context API to plumb this through req.ctx.
+					err = fmt.Errorf("%w: %v", ErrServerClosing, s.writeErr)
+				}
+			})
+			c.processResult("handleAsync", req, nil, err)
+			// The request has already been completed with the cancellation error:
+			// do not fall through to the Handler and a second processResult call.
+			continue
 		}
+
+		result, err := c.handler.Handle(req.ctx, req.Request)
 		c.processResult(c.handler, req, result, err)
 	}
 }
@@ -646,12 +696,24 @@ func (c *Connection) write(ctx context.Context, msg Message) error {
 	n, err := writer.Write(ctx, msg)
 	event.Metric(ctx, tag.SentBytes.Of(n))
 
-	// TODO: if err != nil, that suggests that future writes will not succeed,
-	// so we cannot possibly write the results of incoming Call requests.
-	// If the read side of the connection is also broken, we also might not have
-	// a way to receive cancellation notifications.
-	//
-	// Should we cancel the pending calls implicitly?
+	if err != nil && ctx.Err() == nil {
+		// The call to Write failed, and since ctx.Err() is nil we can't attribute
+		// the failure (even indirectly) to Context cancellation. The writer appears
+		// to be broken, and future writes are likely to also fail.
+		//
+		// If the read side of the connection is also broken, we might not even be
+		// able to receive cancellation notifications. Since we can't reliably write
+		// the results of incoming calls and can't receive explicit cancellations,
+		// cancel the calls now.
+		c.updateInFlight(func(s *inFlightState) {
+			if s.writeErr == nil {
+				s.writeErr = err
+				for _, r := range s.incomingByID {
+					r.cancel()
+				}
+			}
+		})
+	}
 
 	return err
 }