diff --git a/internal/jsonrpc2_v2/conn.go b/internal/jsonrpc2_v2/conn.go index 7d99a02f74..606c3f99cb 100644 --- a/internal/jsonrpc2_v2/conn.go +++ b/internal/jsonrpc2_v2/conn.go @@ -112,6 +112,7 @@ func newConnection(ctx context.Context, rwc io.ReadWriteCloser, binder Binder) ( go c.readIncoming(ctx, reader, readToQueue) go c.manageQueue(ctx, options.Preempter, readToQueue, queueToDeliver) go c.deliverMessages(ctx, options.Handler, queueToDeliver) + // releaseing the writer must be the last thing we do in case any requests // are blocked waiting for the connection to be ready c.writerBox <- options.Framer.Writer(rwc) diff --git a/internal/jsonrpc2_v2/jsonrpc2.go b/internal/jsonrpc2_v2/jsonrpc2.go index faaf205e9a..271f42cf26 100644 --- a/internal/jsonrpc2_v2/jsonrpc2.go +++ b/internal/jsonrpc2_v2/jsonrpc2.go @@ -57,11 +57,11 @@ func (f HandlerFunc) Handle(ctx context.Context, req *Request) (interface{}, err return f(ctx, req) } -// async is a small helper for things with an asynchronous result that you can -// wait for. +// async is a small helper for operations with an asynchronous result that you +// can wait for. type async struct { - ready chan struct{} - errBox chan error + ready chan struct{} // signals that the operation has completed + errBox chan error // guards the operation result } func newAsync() *async { diff --git a/internal/jsonrpc2_v2/jsonrpc2_test.go b/internal/jsonrpc2_v2/jsonrpc2_test.go index 6d057b45ae..1157779f3b 100644 --- a/internal/jsonrpc2_v2/jsonrpc2_test.go +++ b/internal/jsonrpc2_v2/jsonrpc2_test.go @@ -126,7 +126,7 @@ func TestConnectionHeader(t *testing.T) { func testConnection(t *testing.T, framer jsonrpc2.Framer) { stacktest.NoLeak(t) ctx := eventtest.NewContext(context.Background(), t) - listener, err := jsonrpc2.NetPipe(ctx) + listener, err := jsonrpc2.NetPipeListener(ctx) if err != nil { t.Fatal(err) } diff --git a/internal/jsonrpc2_v2/net.go b/internal/jsonrpc2_v2/net.go index c8cfaab40b..0b413d8913 100644 --- a/internal/jsonrpc2_v2/net.go +++ b/internal/jsonrpc2_v2/net.go @@ -80,11 +80,11 @@ func (n *netDialer) Dial(ctx context.Context) (io.ReadWriteCloser, error) { return n.dialer.DialContext(ctx, n.network, n.address) } -// NetPipe returns a new Listener that listens using net.Pipe. +// NetPipeListener returns a new Listener that listens using net.Pipe. // It is only possibly to connect to it using the Dialier returned by the // Dialer method, each call to that method will generate a new pipe the other // side of which will be returnd from the Accept call. -func NetPipe(ctx context.Context) (Listener, error) { +func NetPipeListener(ctx context.Context) (Listener, error) { return &netPiper{ done: make(chan struct{}), dialed: make(chan io.ReadWriteCloser), diff --git a/internal/jsonrpc2_v2/serve_test.go b/internal/jsonrpc2_v2/serve_test.go index 7f1dbc3c97..26cf6a58c4 100644 --- a/internal/jsonrpc2_v2/serve_test.go +++ b/internal/jsonrpc2_v2/serve_test.go @@ -89,7 +89,7 @@ func TestServe(t *testing.T) { return jsonrpc2.NetListener(ctx, "tcp", "localhost:0", jsonrpc2.NetListenOptions{}) }}, {"pipe", func(ctx context.Context) (jsonrpc2.Listener, error) { - return jsonrpc2.NetPipe(ctx) + return jsonrpc2.NetPipeListener(ctx) }}, } diff --git a/internal/lsp/lsprpc/binder.go b/internal/lsp/lsprpc/binder.go index 3f5cb3b423..61f82dead7 100644 --- a/internal/lsp/lsprpc/binder.go +++ b/internal/lsp/lsprpc/binder.go @@ -7,9 +7,12 @@ package lsprpc import ( "context" "encoding/json" + "fmt" + "golang.org/x/tools/internal/event" jsonrpc2_v2 "golang.org/x/tools/internal/jsonrpc2_v2" "golang.org/x/tools/internal/lsp/protocol" + "golang.org/x/tools/internal/xcontext" errors "golang.org/x/xerrors" ) @@ -87,8 +90,19 @@ func (b *ForwardBinder) Bind(ctx context.Context, conn *jsonrpc2_v2.Connection) return opts, err } server := protocol.ServerDispatcherV2(serverConn) + preempter := &canceler{ + conn: conn, + } + detached := xcontext.Detach(ctx) + go func() { + conn.Wait() + if err := serverConn.Close(); err != nil { + event.Log(detached, fmt.Sprintf("closing remote connection: %v", err)) + } + }() return jsonrpc2_v2.ConnectionOptions{ - Handler: protocol.ServerHandlerV2(server), + Handler: protocol.ServerHandlerV2(server), + Preempter: preempter, }, nil } diff --git a/internal/lsp/lsprpc/binder_test.go b/internal/lsp/lsprpc/binder_test.go index d29de0f774..5cbdb2003b 100644 --- a/internal/lsp/lsprpc/binder_test.go +++ b/internal/lsp/lsprpc/binder_test.go @@ -2,8 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// TODO(rFindley): move this to lsprpc_test once it no longer shares with -// lsprpc_test.go. +// TODO(rFindley): move this to the lsprpc_test package once it no longer +// shares with lsprpc_test.go. package lsprpc @@ -19,25 +19,41 @@ import ( ) type testEnv struct { - listener jsonrpc2_v2.Listener - conn *jsonrpc2_v2.Connection - rpcServer *jsonrpc2_v2.Server + listener jsonrpc2_v2.Listener + server *jsonrpc2_v2.Server + + // non-nil if constructed with forwarded=true + fwdListener jsonrpc2_v2.Listener + fwdServer *jsonrpc2_v2.Server + + // the ingoing connection, either to the forwarder or server + conn *jsonrpc2_v2.Connection } func (e testEnv) Shutdown(t *testing.T) { if err := e.listener.Close(); err != nil { t.Error(err) } + if e.fwdListener != nil { + if err := e.fwdListener.Close(); err != nil { + t.Error(err) + } + } if err := e.conn.Close(); err != nil { t.Error(err) } - if err := e.rpcServer.Wait(); err != nil { + if e.fwdServer != nil { + if err := e.fwdServer.Wait(); err != nil { + t.Error(err) + } + } + if err := e.server.Wait(); err != nil { t.Error(err) } } -func startServing(ctx context.Context, t *testing.T, server protocol.Server, client protocol.Client) testEnv { - listener, err := jsonrpc2_v2.NetPipe(ctx) +func startServing(ctx context.Context, t *testing.T, server protocol.Server, client protocol.Client, forwarded bool) testEnv { + listener, err := jsonrpc2_v2.NetPipeListener(ctx) if err != nil { t.Fatal(err) } @@ -49,69 +65,102 @@ func startServing(ctx context.Context, t *testing.T, server protocol.Server, cli if err != nil { t.Fatal(err) } + env := testEnv{ + listener: listener, + server: rpcServer, + } clientBinder := NewClientBinder(func(context.Context, protocol.Server) protocol.Client { return client }) - conn, err := jsonrpc2_v2.Dial(ctx, listener.Dialer(), clientBinder) - if err != nil { - t.Fatal(err) - } - return testEnv{ - listener: listener, - rpcServer: rpcServer, - conn: conn, - } -} - -func TestClientLoggingV2(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - client := fakeClient{logs: make(chan string, 10)} - env := startServing(ctx, t, pingServer{}, client) - defer env.Shutdown(t) - if err := protocol.ServerDispatcherV2(env.conn).DidOpen(ctx, &protocol.DidOpenTextDocumentParams{}); err != nil { - t.Errorf("DidOpen: %v", err) - } - select { - case got := <-client.logs: - want := "ping" - matched, err := regexp.MatchString(want, got) + if forwarded { + fwdListener, err := jsonrpc2_v2.NetPipeListener(ctx) if err != nil { t.Fatal(err) } - if !matched { - t.Errorf("got log %q, want a log containing %q", got, want) + fwdBinder := NewForwardBinder(listener.Dialer()) + fwdServer, err := jsonrpc2_v2.Serve(ctx, fwdListener, fwdBinder) + if err != nil { + t.Fatal(err) } - case <-time.After(1 * time.Second): - t.Error("timeout waiting for client log") + conn, err := jsonrpc2_v2.Dial(ctx, fwdListener.Dialer(), clientBinder) + if err != nil { + t.Fatal(err) + } + env.fwdListener = fwdListener + env.fwdServer = fwdServer + env.conn = conn + } else { + conn, err := jsonrpc2_v2.Dial(ctx, listener.Dialer(), clientBinder) + if err != nil { + t.Fatal(err) + } + env.conn = conn + } + return env +} + +func TestClientLoggingV2(t *testing.T) { + ctx := context.Background() + + for name, forwarded := range map[string]bool{ + "forwarded": true, + "standalone": false, + } { + t.Run(name, func(t *testing.T) { + client := fakeClient{logs: make(chan string, 10)} + env := startServing(ctx, t, pingServer{}, client, forwarded) + defer env.Shutdown(t) + if err := protocol.ServerDispatcherV2(env.conn).DidOpen(ctx, &protocol.DidOpenTextDocumentParams{}); err != nil { + t.Errorf("DidOpen: %v", err) + } + select { + case got := <-client.logs: + want := "ping" + matched, err := regexp.MatchString(want, got) + if err != nil { + t.Fatal(err) + } + if !matched { + t.Errorf("got log %q, want a log containing %q", got, want) + } + case <-time.After(1 * time.Second): + t.Error("timeout waiting for client log") + } + }) } } func TestRequestCancellationV2(t *testing.T) { ctx := context.Background() - server := waitableServer{ - started: make(chan struct{}), - completed: make(chan error), - } - client := fakeClient{logs: make(chan string, 10)} - env := startServing(ctx, t, server, client) - defer env.Shutdown(t) + for name, forwarded := range map[string]bool{ + "forwarded": true, + "standalone": false, + } { + t.Run(name, func(t *testing.T) { + server := waitableServer{ + started: make(chan struct{}), + completed: make(chan error), + } + client := fakeClient{logs: make(chan string, 10)} + env := startServing(ctx, t, server, client, forwarded) + defer env.Shutdown(t) - sd := protocol.ServerDispatcherV2(env.conn) - ctx, cancel := context.WithCancel(ctx) + sd := protocol.ServerDispatcherV2(env.conn) + ctx, cancel := context.WithCancel(ctx) - result := make(chan error) - go func() { - _, err := sd.Hover(ctx, &protocol.HoverParams{}) - result <- err - }() - // Wait for the Hover request to start. - <-server.started - cancel() - if err := <-result; err == nil { - t.Error("nil error for cancelled Hover(), want non-nil") - } - if err := <-server.completed; err == nil || !strings.Contains(err.Error(), "cancelled hover") { - t.Errorf("Hover(): unexpected server-side error %v", err) + result := make(chan error) + go func() { + _, err := sd.Hover(ctx, &protocol.HoverParams{}) + result <- err + }() + // Wait for the Hover request to start. + <-server.started + cancel() + if err := <-result; err == nil { + t.Error("nil error for cancelled Hover(), want non-nil") + } + if err := <-server.completed; err == nil || !strings.Contains(err.Error(), "cancelled hover") { + t.Errorf("Hover(): unexpected server-side error %v", err) + } + }) } }