internal/lsp/lsprpc: use middleware for the V2 handshaking

The number of concerns satisfied by the lsprpc package is proving to be
a barrier to both factoring out the debug and event package, and use of
service discovery.

The new Binder API admits a nice abstraction that can help make the
lsprpc package more modular, or perhaps even unnecessary: a binder
middleware that can instrument all aspects of the connection lifecycle.
In this CL, this pattern is used to decouple the server handshake from
the actual forwarding setup. Later CLs will implement additional
functionality using this pattern.

The TestEnv helper is refactored to be more scriptable.

Change-Id: I6060bc4bba57e4ee7e161a5d6edbc40c6fccbaa8
Reviewed-on: https://go-review.googlesource.com/c/tools/+/331369
Trust: Robert Findley <rfindley@google.com>
Run-TryBot: Robert Findley <rfindley@google.com>
gopls-CI: kokoro <noreply+kokoro@google.com>
TryBot-Result: Go Bot <gobot@golang.org>
Reviewed-by: Ian Cottrell <iancottrell@google.com>
This commit is contained in:
Rob Findley 2021-06-28 10:04:03 -04:00 committed by Robert Findley
parent 20dafe5d60
commit ea370293d7
5 changed files with 347 additions and 101 deletions

View File

@ -16,8 +16,20 @@ import (
errors "golang.org/x/xerrors"
)
// The BinderFunc type adapts a bind function to implement the jsonrpc2.Binder
// interface.
type BinderFunc func(ctx context.Context, conn *jsonrpc2_v2.Connection) (jsonrpc2_v2.ConnectionOptions, error)
func (f BinderFunc) Bind(ctx context.Context, conn *jsonrpc2_v2.Connection) (jsonrpc2_v2.ConnectionOptions, error) {
return f(ctx, conn)
}
// Middleware defines a transformation of jsonrpc2 Binders, that may be
// composed to build jsonrpc2 servers.
type Middleware func(jsonrpc2_v2.Binder) jsonrpc2_v2.Binder
// A ServerFunc is used to construct an LSP server for a given client.
type ServerFunc func(context.Context, protocol.ClientCloser) protocol.Server
type ClientFunc func(context.Context, protocol.Server) protocol.Client
// ServerBinder binds incoming connections to a new server.
type ServerBinder struct {
@ -25,7 +37,7 @@ type ServerBinder struct {
}
func NewServerBinder(newServer ServerFunc) *ServerBinder {
return &ServerBinder{newServer}
return &ServerBinder{newServer: newServer}
}
func (b *ServerBinder) Bind(ctx context.Context, conn *jsonrpc2_v2.Connection) (jsonrpc2_v2.ConnectionOptions, error) {
@ -74,6 +86,7 @@ func (c *canceler) Preempt(ctx context.Context, req *jsonrpc2_v2.Request) (inter
type ForwardBinder struct {
dialer jsonrpc2_v2.Dialer
onBind func(*jsonrpc2_v2.Connection)
}
func NewForwardBinder(dialer jsonrpc2_v2.Dialer) *ForwardBinder {
@ -89,6 +102,9 @@ func (b *ForwardBinder) Bind(ctx context.Context, conn *jsonrpc2_v2.Connection)
if err != nil {
return opts, err
}
if b.onBind != nil {
b.onBind(serverConn)
}
server := protocol.ServerDispatcherV2(serverConn)
preempter := &canceler{
conn: conn,
@ -106,6 +122,10 @@ func (b *ForwardBinder) Bind(ctx context.Context, conn *jsonrpc2_v2.Connection)
}, nil
}
// A ClientFunc is used to construct an LSP client for a given server.
type ClientFunc func(context.Context, protocol.Server) protocol.Client
// ClientBinder binds an LSP client to an incoming connection.
type ClientBinder struct {
newClient ClientFunc
}

View File

@ -2,10 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// TODO(rFindley): move this to the lsprpc_test package once it no longer
// shares with lsprpc_test.go.
package lsprpc
package lsprpc_test
import (
"context"
@ -16,85 +13,71 @@ import (
jsonrpc2_v2 "golang.org/x/tools/internal/jsonrpc2_v2"
"golang.org/x/tools/internal/lsp/protocol"
. "golang.org/x/tools/internal/lsp/lsprpc"
)
type testEnv struct {
listener jsonrpc2_v2.Listener
server *jsonrpc2_v2.Server
// non-nil if constructed with forwarded=true
fwdListener jsonrpc2_v2.Listener
fwdServer *jsonrpc2_v2.Server
// the ingoing connection, either to the forwarder or server
conn *jsonrpc2_v2.Connection
type TestEnv struct {
Listeners []jsonrpc2_v2.Listener
Conns []*jsonrpc2_v2.Connection
Servers []*jsonrpc2_v2.Server
}
func (e testEnv) Shutdown(t *testing.T) {
if err := e.listener.Close(); err != nil {
t.Error(err)
}
if e.fwdListener != nil {
if err := e.fwdListener.Close(); err != nil {
func (e *TestEnv) Shutdown(t *testing.T) {
for _, l := range e.Listeners {
if err := l.Close(); err != nil {
t.Error(err)
}
}
if err := e.conn.Close(); err != nil {
t.Error(err)
}
if e.fwdServer != nil {
if err := e.fwdServer.Wait(); err != nil {
for _, c := range e.Conns {
if err := c.Close(); err != nil {
t.Error(err)
}
}
if err := e.server.Wait(); err != nil {
t.Error(err)
for _, s := range e.Servers {
if err := s.Wait(); err != nil {
t.Error(err)
}
}
}
func startServing(ctx context.Context, t *testing.T, server protocol.Server, client protocol.Client, forwarded bool) testEnv {
listener, err := jsonrpc2_v2.NetPipeListener(ctx)
func (e *TestEnv) serve(ctx context.Context, t *testing.T, server jsonrpc2_v2.Binder) (jsonrpc2_v2.Listener, *jsonrpc2_v2.Server) {
l, err := jsonrpc2_v2.NetPipeListener(ctx)
if err != nil {
t.Fatal(err)
}
newServer := func(ctx context.Context, client protocol.ClientCloser) protocol.Server {
e.Listeners = append(e.Listeners, l)
s, err := jsonrpc2_v2.Serve(ctx, l, server)
if err != nil {
t.Fatal(err)
}
e.Servers = append(e.Servers, s)
return l, s
}
func (e *TestEnv) dial(ctx context.Context, t *testing.T, dialer jsonrpc2_v2.Dialer, client jsonrpc2_v2.Binder, forwarded bool) *jsonrpc2_v2.Connection {
if forwarded {
l, _ := e.serve(ctx, t, NewForwardBinder(dialer))
dialer = l.Dialer()
}
conn, err := jsonrpc2_v2.Dial(ctx, dialer, client)
if err != nil {
t.Fatal(err)
}
e.Conns = append(e.Conns, conn)
return conn
}
func staticClientBinder(client protocol.Client) jsonrpc2_v2.Binder {
f := func(context.Context, protocol.Server) protocol.Client { return client }
return NewClientBinder(f)
}
func staticServerBinder(server protocol.Server) jsonrpc2_v2.Binder {
f := func(ctx context.Context, client protocol.ClientCloser) protocol.Server {
return server
}
serverBinder := NewServerBinder(newServer)
rpcServer, err := jsonrpc2_v2.Serve(ctx, listener, serverBinder)
if err != nil {
t.Fatal(err)
}
env := testEnv{
listener: listener,
server: rpcServer,
}
clientBinder := NewClientBinder(func(context.Context, protocol.Server) protocol.Client { return client })
if forwarded {
fwdListener, err := jsonrpc2_v2.NetPipeListener(ctx)
if err != nil {
t.Fatal(err)
}
fwdBinder := NewForwardBinder(listener.Dialer())
fwdServer, err := jsonrpc2_v2.Serve(ctx, fwdListener, fwdBinder)
if err != nil {
t.Fatal(err)
}
conn, err := jsonrpc2_v2.Dial(ctx, fwdListener.Dialer(), clientBinder)
if err != nil {
t.Fatal(err)
}
env.fwdListener = fwdListener
env.fwdServer = fwdServer
env.conn = conn
} else {
conn, err := jsonrpc2_v2.Dial(ctx, listener.Dialer(), clientBinder)
if err != nil {
t.Fatal(err)
}
env.conn = conn
}
return env
return NewServerBinder(f)
}
func TestClientLoggingV2(t *testing.T) {
@ -105,14 +88,17 @@ func TestClientLoggingV2(t *testing.T) {
"standalone": false,
} {
t.Run(name, func(t *testing.T) {
client := fakeClient{logs: make(chan string, 10)}
env := startServing(ctx, t, pingServer{}, client, forwarded)
client := FakeClient{Logs: make(chan string, 10)}
env := new(TestEnv)
defer env.Shutdown(t)
if err := protocol.ServerDispatcherV2(env.conn).DidOpen(ctx, &protocol.DidOpenTextDocumentParams{}); err != nil {
l, _ := env.serve(ctx, t, staticServerBinder(PingServer{}))
conn := env.dial(ctx, t, l.Dialer(), staticClientBinder(client), forwarded)
if err := protocol.ServerDispatcherV2(conn).DidOpen(ctx, &protocol.DidOpenTextDocumentParams{}); err != nil {
t.Errorf("DidOpen: %v", err)
}
select {
case got := <-client.logs:
case got := <-client.Logs:
want := "ping"
matched, err := regexp.MatchString(want, got)
if err != nil {
@ -136,15 +122,17 @@ func TestRequestCancellationV2(t *testing.T) {
"standalone": false,
} {
t.Run(name, func(t *testing.T) {
server := waitableServer{
started: make(chan struct{}),
completed: make(chan error),
server := WaitableServer{
Started: make(chan struct{}),
Completed: make(chan error),
}
client := fakeClient{logs: make(chan string, 10)}
env := startServing(ctx, t, server, client, forwarded)
env := new(TestEnv)
defer env.Shutdown(t)
l, _ := env.serve(ctx, t, staticServerBinder(server))
client := FakeClient{Logs: make(chan string, 10)}
conn := env.dial(ctx, t, l.Dialer(), staticClientBinder(client), forwarded)
sd := protocol.ServerDispatcherV2(env.conn)
sd := protocol.ServerDispatcherV2(conn)
ctx, cancel := context.WithCancel(ctx)
result := make(chan error)
@ -153,12 +141,12 @@ func TestRequestCancellationV2(t *testing.T) {
result <- err
}()
// Wait for the Hover request to start.
<-server.started
<-server.Started
cancel()
if err := <-result; err == nil {
t.Error("nil error for cancelled Hover(), want non-nil")
}
if err := <-server.completed; err == nil || !strings.Contains(err.Error(), "cancelled hover") {
if err := <-server.Completed; err == nil || !strings.Contains(err.Error(), "cancelled hover") {
t.Errorf("Hover(): unexpected server-side error %v", err)
}
})

View File

@ -22,14 +22,14 @@ import (
"golang.org/x/tools/internal/testenv"
)
type fakeClient struct {
type FakeClient struct {
protocol.Client
logs chan string
Logs chan string
}
func (c fakeClient) LogMessage(ctx context.Context, params *protocol.LogMessageParams) error {
c.logs <- params.Message
func (c FakeClient) LogMessage(ctx context.Context, params *protocol.LogMessageParams) error {
c.Logs <- params.Message
return nil
}
@ -43,9 +43,9 @@ func (fakeServer) Shutdown(ctx context.Context) error {
return nil
}
type pingServer struct{ fakeServer }
type PingServer struct{ fakeServer }
func (s pingServer) DidOpen(ctx context.Context, params *protocol.DidOpenTextDocumentParams) error {
func (s PingServer) DidOpen(ctx context.Context, params *protocol.DidOpenTextDocumentParams) error {
event.Log(ctx, "ping")
return nil
}
@ -54,8 +54,8 @@ func TestClientLogging(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
server := pingServer{}
client := fakeClient{logs: make(chan string, 10)}
server := PingServer{}
client := FakeClient{Logs: make(chan string, 10)}
ctx = debug.WithInstance(ctx, "", "")
ss := NewStreamServer(cache.New(nil), false)
@ -70,7 +70,7 @@ func TestClientLogging(t *testing.T) {
}
select {
case got := <-client.logs:
case got := <-client.Logs:
want := "ping"
matched, err := regexp.MatchString(want, got)
if err != nil {
@ -84,20 +84,20 @@ func TestClientLogging(t *testing.T) {
}
}
// waitableServer instruments LSP request so that we can control their timing.
// WaitableServer instruments LSP request so that we can control their timing.
// The requests chosen are arbitrary: we simply needed one that blocks, and
// another that doesn't.
type waitableServer struct {
type WaitableServer struct {
fakeServer
started chan struct{}
completed chan error
Started chan struct{}
Completed chan error
}
func (s waitableServer) Hover(ctx context.Context, _ *protocol.HoverParams) (_ *protocol.Hover, err error) {
s.started <- struct{}{}
func (s WaitableServer) Hover(ctx context.Context, _ *protocol.HoverParams) (_ *protocol.Hover, err error) {
s.Started <- struct{}{}
defer func() {
s.completed <- err
s.Completed <- err
}()
select {
case <-ctx.Done():
@ -107,7 +107,7 @@ func (s waitableServer) Hover(ctx context.Context, _ *protocol.HoverParams) (_ *
return &protocol.Hover{}, nil
}
func (s waitableServer) Resolve(_ context.Context, item *protocol.CompletionItem) (*protocol.CompletionItem, error) {
func (s WaitableServer) Resolve(_ context.Context, item *protocol.CompletionItem) (*protocol.CompletionItem, error) {
return item, nil
}
@ -136,9 +136,9 @@ func setupForwarding(ctx context.Context, t *testing.T, s protocol.Server) (dire
func TestRequestCancellation(t *testing.T) {
ctx := context.Background()
server := waitableServer{
started: make(chan struct{}),
completed: make(chan error),
server := WaitableServer{
Started: make(chan struct{}),
Completed: make(chan error),
}
tsDirect, tsForwarded, cleanup := setupForwarding(ctx, t, server)
defer cleanup()
@ -167,12 +167,12 @@ func TestRequestCancellation(t *testing.T) {
result <- err
}()
// Wait for the Hover request to start.
<-server.started
<-server.Started
cancel()
if err := <-result; err == nil {
t.Error("nil error for cancelled Hover(), want non-nil")
}
if err := <-server.completed; err == nil || !strings.Contains(err.Error(), "cancelled hover") {
if err := <-server.Completed; err == nil || !strings.Contains(err.Error(), "cancelled hover") {
t.Errorf("Hover(): unexpected server-side error %v", err)
}
})

View File

@ -0,0 +1,145 @@
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package lsprpc
import (
"context"
"encoding/json"
"sync"
"golang.org/x/tools/internal/event"
jsonrpc2_v2 "golang.org/x/tools/internal/jsonrpc2_v2"
"golang.org/x/xerrors"
)
// Metadata holds arbitrary data transferred between jsonrpc2 peers.
type Metadata map[string]interface{}
// PeerInfo holds information about a peering between jsonrpc2 servers.
type PeerInfo struct {
// RemoteID is the identity of the current server on its peer.
RemoteID int64
// LocalID is the identity of the peer on the server.
LocalID int64
// IsClient reports whether the peer is a client. If false, the peer is a
// server.
IsClient bool
// Metadata holds arbitrary information provided by the peer.
Metadata Metadata
}
// Handshaker handles both server and client handshaking over jsonrpc2. To
// instrument server-side handshaking, use Handshaker.Middleware. To instrument
// client-side handshaking, call Handshaker.ClientHandshake for any new
// client-side connections.
type Handshaker struct {
// Metadata will be shared with peers via handshaking.
Metadata Metadata
mu sync.Mutex
prevID int64
peers map[int64]PeerInfo
}
// Peers returns the peer info this handshaker knows about by way of either the
// server-side handshake middleware, or client-side handshakes.
func (h *Handshaker) Peers() []PeerInfo {
h.mu.Lock()
defer h.mu.Unlock()
var c []PeerInfo
for _, v := range h.peers {
c = append(c, v)
}
return c
}
// Middleware is a jsonrpc2 middleware function to augment connection binding
// to handle the handshake method, and record disconnections.
func (h *Handshaker) Middleware(inner jsonrpc2_v2.Binder) jsonrpc2_v2.Binder {
return BinderFunc(func(ctx context.Context, conn *jsonrpc2_v2.Connection) (jsonrpc2_v2.ConnectionOptions, error) {
opts, err := inner.Bind(ctx, conn)
if err != nil {
return opts, err
}
localID := h.nextID()
info := &PeerInfo{
RemoteID: localID,
Metadata: h.Metadata,
}
// Wrap the delegated handler to accept the handshake.
delegate := opts.Handler
opts.Handler = jsonrpc2_v2.HandlerFunc(func(ctx context.Context, req *jsonrpc2_v2.Request) (interface{}, error) {
if req.Method == handshakeMethod {
var peerInfo PeerInfo
if err := json.Unmarshal(req.Params, &peerInfo); err != nil {
return nil, xerrors.Errorf("%w: unmarshaling client info: %v", jsonrpc2_v2.ErrInvalidParams, err)
}
peerInfo.LocalID = localID
peerInfo.IsClient = true
h.recordPeer(peerInfo)
return info, nil
}
return delegate.Handle(ctx, req)
})
// Record the dropped client.
go h.cleanupAtDisconnect(conn, localID)
return opts, nil
})
}
// ClientHandshake performs a client-side handshake with the server at the
// other end of conn, recording the server's peer info and watching for conn's
// disconnection.
func (h *Handshaker) ClientHandshake(ctx context.Context, conn *jsonrpc2_v2.Connection) {
localID := h.nextID()
info := &PeerInfo{
RemoteID: localID,
Metadata: h.Metadata,
}
call := conn.Call(ctx, handshakeMethod, info)
var serverInfo PeerInfo
if err := call.Await(ctx, &serverInfo); err != nil {
event.Error(ctx, "performing handshake", err)
return
}
serverInfo.LocalID = localID
h.recordPeer(serverInfo)
go h.cleanupAtDisconnect(conn, localID)
}
func (h *Handshaker) nextID() int64 {
h.mu.Lock()
defer h.mu.Unlock()
h.prevID++
return h.prevID
}
func (h *Handshaker) cleanupAtDisconnect(conn *jsonrpc2_v2.Connection, peerID int64) {
conn.Wait()
h.mu.Lock()
defer h.mu.Unlock()
delete(h.peers, peerID)
}
func (h *Handshaker) recordPeer(info PeerInfo) {
h.mu.Lock()
defer h.mu.Unlock()
if h.peers == nil {
h.peers = make(map[int64]PeerInfo)
}
h.peers[info.LocalID] = info
}

View File

@ -0,0 +1,93 @@
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package lsprpc_test
import (
"context"
"errors"
"fmt"
"testing"
"time"
jsonrpc2_v2 "golang.org/x/tools/internal/jsonrpc2_v2"
. "golang.org/x/tools/internal/lsp/lsprpc"
)
var noopBinder = BinderFunc(func(context.Context, *jsonrpc2_v2.Connection) (jsonrpc2_v2.ConnectionOptions, error) {
return jsonrpc2_v2.ConnectionOptions{}, nil
})
func TestHandshakeMiddleware(t *testing.T) {
sh := &Handshaker{
Metadata: Metadata{
"answer": 42,
},
}
ctx := context.Background()
env := new(TestEnv)
defer env.Shutdown(t)
l, _ := env.serve(ctx, t, sh.Middleware(noopBinder))
conn := env.dial(ctx, t, l.Dialer(), noopBinder, false)
ch := &Handshaker{
Metadata: Metadata{
"question": 6 * 9,
},
}
check := func(connected bool) error {
clients := sh.Peers()
servers := ch.Peers()
want := 0
if connected {
want = 1
}
if got := len(clients); got != want {
return fmt.Errorf("got %d clients on the server, want %d", got, want)
}
if got := len(servers); got != want {
return fmt.Errorf("got %d servers on the client, want %d", got, want)
}
if !connected {
return nil
}
client := clients[0]
server := servers[0]
if _, ok := client.Metadata["question"]; !ok {
return errors.New("no client metadata")
}
if _, ok := server.Metadata["answer"]; !ok {
return errors.New("no server metadata")
}
if client.LocalID != server.RemoteID {
return fmt.Errorf("client.LocalID == %d, server.PeerID == %d", client.LocalID, server.RemoteID)
}
if client.RemoteID != server.LocalID {
return fmt.Errorf("client.PeerID == %d, server.LocalID == %d", client.RemoteID, server.LocalID)
}
return nil
}
if err := check(false); err != nil {
t.Fatalf("before handshake: %v", err)
}
ch.ClientHandshake(ctx, conn)
if err := check(true); err != nil {
t.Fatalf("after handshake: %v", err)
}
conn.Close()
// Wait for up to ~2s for connections to get cleaned up.
delay := 25 * time.Millisecond
for retries := 3; retries >= 0; retries-- {
time.Sleep(delay)
err := check(false)
if err == nil {
return
}
if retries == 0 {
t.Fatalf("after closing connection: %v", err)
}
delay *= 4
}
}