diff --git a/internal/jsonrpc2/jsonrpc2.go b/internal/jsonrpc2/jsonrpc2.go index 963f818539..6cbd60762c 100644 --- a/internal/jsonrpc2/jsonrpc2.go +++ b/internal/jsonrpc2/jsonrpc2.go @@ -147,14 +147,15 @@ func (c *Conn) Call(ctx context.Context, method string, params, result interface for _, h := range c.handlers { ctx = h.Request(ctx, c, Send, request) } - // we have to add ourselves to the pending map before we send, otherwise we - // are racing the response - rchan := make(chan *WireResponse) + // We have to add ourselves to the pending map before we send, otherwise we + // are racing the response. Also add a buffer to rchan, so that if we get a + // wire response between the time this call is cancelled and id is deleted + // from c.pending, the send to rchan will not block. + rchan := make(chan *WireResponse, 1) c.pendingMu.Lock() c.pending[id] = rchan c.pendingMu.Unlock() defer func() { - // clean up the pending response handler on the way out c.pendingMu.Lock() delete(c.pending, id) c.pendingMu.Unlock() @@ -189,7 +190,7 @@ func (c *Conn) Call(ctx context.Context, method string, params, result interface } return nil case <-ctx.Done(): - // allow the handler to propagate the cancel + // Allow the handler to propagate the cancel. cancelled := false for _, h := range c.handlers { if h.Cancel(ctx, c, id, cancelled) { @@ -328,10 +329,10 @@ func (c *Conn) Run(runCtx context.Context) error { } continue } - // work out which kind of message we have + // Work out whether this is a request or response. switch { case msg.Method != "": - // if method is set it must be a request + // If method is set it must be a request. reqCtx, cancelReq := context.WithCancel(runCtx) thisRequest := nextRequest nextRequest = make(chan struct{}) @@ -373,21 +374,19 @@ func (c *Conn) Run(runCtx context.Context) error { } }() case msg.ID != nil: - // we have a response, get the pending entry from the map + // If method is not set, this should be a response, in which case we must + // have an id to send the response back to the caller. c.pendingMu.Lock() - rchan := c.pending[*msg.ID] - if rchan != nil { - delete(c.pending, *msg.ID) - } + rchan, ok := c.pending[*msg.ID] c.pendingMu.Unlock() - // and send the reply to the channel - response := &WireResponse{ - Result: msg.Result, - Error: msg.Error, - ID: msg.ID, + if ok { + response := &WireResponse{ + Result: msg.Result, + Error: msg.Error, + ID: msg.ID, + } + rchan <- response } - rchan <- response - close(rchan) default: for _, h := range c.handlers { h.Error(runCtx, fmt.Errorf("message not a call, notify or response, ignoring")) diff --git a/internal/lsp/lsprpc/lsprpc_test.go b/internal/lsp/lsprpc/lsprpc_test.go index a36affc0c9..b7c20fe476 100644 --- a/internal/lsp/lsprpc/lsprpc_test.go +++ b/internal/lsp/lsprpc/lsprpc_test.go @@ -7,6 +7,7 @@ package lsprpc import ( "context" "regexp" + "sync" "testing" "time" @@ -61,40 +62,37 @@ func TestClientLogging(t *testing.T) { if !matched { t.Errorf("got log %q, want a log containing %q", got, want) } - case <-time.After(1000 * time.Second): + case <-time.After(1 * time.Second): t.Error("timeout waiting for client log") } } +// waitableServer instruments LSP request so that we can control their timing. +// The requests chosen are arbitrary: we simply needed one that blocks, and +// another that doesn't. type waitableServer struct { protocol.Server started chan struct{} - // finished records whether the request ended with a cancellation or not - // (true means the request was cancelled). - finished chan bool } -func (s waitableServer) CodeLens(ctx context.Context, params *protocol.CodeLensParams) ([]protocol.CodeLens, error) { +func (s waitableServer) Hover(ctx context.Context, _ *protocol.HoverParams) (*protocol.Hover, error) { s.started <- struct{}{} - cancelled := false - defer func() { - s.finished <- cancelled - }() select { case <-ctx.Done(): - cancelled = true return nil, ctx.Err() - case <-time.After(1 * time.Second): - cancelled = false + case <-time.After(200 * time.Millisecond): } - return []protocol.CodeLens{}, nil + return &protocol.Hover{}, nil +} + +func (s waitableServer) Resolve(_ context.Context, item *protocol.CompletionItem) (*protocol.CompletionItem, error) { + return item, nil } func TestRequestCancellation(t *testing.T) { server := waitableServer{ - started: make(chan struct{}), - finished: make(chan bool), + started: make(chan struct{}), } ss := &StreamServer{ accept: func(c protocol.Client) protocol.Server { @@ -119,14 +117,33 @@ func TestRequestCancellation(t *testing.T) { t.Run(test.serverType, func(t *testing.T) { cc := test.ts.Connect(ctx) cc.AddHandler(protocol.Canceller{}) - lensCtx, cancelLens := context.WithCancel(context.Background()) + ctx := context.Background() + ctx1, cancel1 := context.WithCancel(ctx) + var ( + err1, err2 error + wg sync.WaitGroup + ) + wg.Add(2) go func() { - protocol.ServerDispatcher(cc).CodeLens(lensCtx, &protocol.CodeLensParams{}) + defer wg.Done() + _, err1 = protocol.ServerDispatcher(cc).Hover(ctx1, &protocol.HoverParams{}) }() + go func() { + defer wg.Done() + _, err2 = protocol.ServerDispatcher(cc).Resolve(ctx, &protocol.CompletionItem{}) + }() + // Wait for the Hover request to start. <-server.started - cancelLens() - if got, want := <-server.finished, true; got != want { - t.Errorf("CodeLens was cancelled: %t, want %t", got, want) + cancel1() + wg.Wait() + if err1 == nil { + t.Errorf("cancelled Hover(): got nil err") + } + if err2 != nil { + t.Errorf("uncancelled Hover(): err: %v", err2) + } + if _, err := protocol.ServerDispatcher(cc).Resolve(ctx, &protocol.CompletionItem{}); err != nil { + t.Errorf("subsequent Hover(): %v", err) } }) }