diff --git a/gossip/client_test.go b/gossip/client_test.go index 455257e64ad9..145627d245e5 100644 --- a/gossip/client_test.go +++ b/gossip/client_test.go @@ -29,7 +29,6 @@ import ( "github.com/cockroachdb/cockroach/roachpb" "github.com/cockroachdb/cockroach/rpc" "github.com/cockroachdb/cockroach/util" - "github.com/cockroachdb/cockroach/util/grpcutil" "github.com/cockroachdb/cockroach/util/hlc" "github.com/cockroachdb/cockroach/util/leaktest" "github.com/cockroachdb/cockroach/util/stop" @@ -46,7 +45,7 @@ func startGossip(nodeID roachpb.NodeID, stopper *stop.Stopper, t *testing.T) *Go if err != nil { t.Fatal(err) } - ln, err := grpcutil.ListenAndServeGRPC(stopper, server, addr, tlsConfig) + ln, err := util.ListenAndServe(stopper, server, addr, tlsConfig) if err != nil { t.Fatal(err) } @@ -111,7 +110,7 @@ func startFakeServerGossips(t *testing.T) (local *Gossip, remote *fakeGossipServ if err != nil { t.Fatal(err) } - lln, err := grpcutil.ListenAndServeGRPC(stopper, lserver, laddr, lTLSConfig) + lln, err := util.ListenAndServe(stopper, lserver, laddr, lTLSConfig) if err != nil { t.Fatal(err) } @@ -127,7 +126,7 @@ func startFakeServerGossips(t *testing.T) (local *Gossip, remote *fakeGossipServ if err != nil { t.Fatal(err) } - rln, err := grpcutil.ListenAndServeGRPC(stopper, rserver, raddr, rTLSConfig) + rln, err := util.ListenAndServe(stopper, rserver, raddr, rTLSConfig) if err != nil { t.Fatal(err) } @@ -333,7 +332,7 @@ func TestClientRegisterWithInitNodeID(t *testing.T) { if err != nil { t.Fatal(err) } - ln, err := grpcutil.ListenAndServeGRPC(stopper, server, addr, TLSConfig) + ln, err := util.ListenAndServe(stopper, server, addr, TLSConfig) if err != nil { t.Fatal(err) } diff --git a/gossip/simulation/network.go b/gossip/simulation/network.go index 443ff19a5ddb..544f69dea56d 100644 --- a/gossip/simulation/network.go +++ b/gossip/simulation/network.go @@ -30,7 +30,6 @@ import ( "github.com/cockroachdb/cockroach/rpc" "github.com/cockroachdb/cockroach/util" "github.com/cockroachdb/cockroach/util/encoding" - "github.com/cockroachdb/cockroach/util/grpcutil" "github.com/cockroachdb/cockroach/util/hlc" "github.com/cockroachdb/cockroach/util/log" "github.com/cockroachdb/cockroach/util/stop" @@ -93,7 +92,7 @@ func NewNetwork(nodeCount int) *Network { func (n *Network) CreateNode() (*Node, error) { server := grpc.NewServer() testAddr := util.CreateTestAddr("tcp") - ln, err := grpcutil.ListenAndServeGRPC(n.Stopper, server, testAddr, n.tlsConfig) + ln, err := util.ListenAndServe(n.Stopper, server, testAddr, n.tlsConfig) if err != nil { return nil, err } diff --git a/storage/raft_transport.go b/storage/raft_transport.go index 37b4599d99b0..3e4bbd54c31c 100644 --- a/storage/raft_transport.go +++ b/storage/raft_transport.go @@ -100,7 +100,7 @@ func (lt *localRPCTransport) Listen(id roachpb.StoreID, handler RaftMessageHandl RegisterMultiRaftServer(grpcServer, handler) addr := util.CreateTestAddr("tcp") - ln, err := grpcutil.ListenAndServeGRPC(lt.stopper, grpcServer, addr, nil) + ln, err := util.ListenAndServe(lt.stopper, grpcServer, addr, nil) if err != nil { return err } diff --git a/util/grpcutil/grpc_util.go b/util/grpcutil/grpc_util.go index 06c8e17715ee..5b6e802383d9 100644 --- a/util/grpcutil/grpc_util.go +++ b/util/grpcutil/grpc_util.go @@ -17,8 +17,6 @@ package grpcutil import ( - "crypto/tls" - "net" "net/http" "strings" @@ -29,34 +27,9 @@ import ( "google.golang.org/grpc/transport" "github.com/cockroachdb/cockroach/util" - "github.com/cockroachdb/cockroach/util/log" "github.com/cockroachdb/cockroach/util/stop" ) -// ListenAndServeGRPC creates a listener and serves server on it, closing -// the listener when signalled by the stopper. -func ListenAndServeGRPC(stopper *stop.Stopper, server *grpc.Server, addr net.Addr, config *tls.Config) (net.Listener, error) { - ln, err := util.Listen(addr, config) - if err != nil { - return nil, err - } - - stopper.RunWorker(func() { - if err := server.Serve(ln); err != nil && !util.IsClosedConnection(err) { - log.Fatal(err) - } - }) - - stopper.RunWorker(func() { - <-stopper.ShouldDrain() - if err := ln.Close(); err != nil { - log.Fatal(err) - } - }) - - return ln, nil -} - // GRPCHandlerFunc returns an http.Handler that delegates to grpcServer on incoming gRPC // connections or otherHandler otherwise. func GRPCHandlerFunc(grpcServer *grpc.Server, otherHandler http.Handler) http.Handler { diff --git a/util/net.go b/util/net.go index b14f78b97094..50aadac18ed1 100644 --- a/util/net.go +++ b/util/net.go @@ -17,7 +17,9 @@ package util import ( + "bytes" "crypto/tls" + "io" "log" "net" "net/http" @@ -29,13 +31,51 @@ import ( "github.com/cockroachdb/cockroach/util/stop" ) +type replayableConn struct { + net.Conn + buf bytes.Buffer + reader io.Reader +} + +// Do not call `replay` more than once, bad things will happen. +func (bc *replayableConn) replay() *replayableConn { + bc.reader = io.MultiReader(&bc.buf, bc.Conn) + return bc +} + +func (bc *replayableConn) Read(b []byte) (int, error) { + return bc.reader.Read(b) +} + +func newBufferedConn(conn net.Conn) *replayableConn { + bc := replayableConn{Conn: conn} + bc.reader = io.TeeReader(conn, &bc.buf) + return &bc +} + +type replayableConnListener struct { + net.Listener +} + +func (ml *replayableConnListener) Accept() (net.Conn, error) { + conn, err := ml.Listener.Accept() + if err == nil { + conn = newBufferedConn(conn) + } + return conn, err +} + // Listen delegates to `net.Listen` and, if tlsConfig is not nil, to `tls.NewListener`. // The returned listener's Addr() method will return an address with the hostname unresovled, // which means it can be used to initiate TLS connections. func Listen(addr net.Addr, tlsConfig *tls.Config) (net.Listener, error) { ln, err := net.Listen(addr.Network(), addr.String()) - if err == nil && tlsConfig != nil { - ln = tls.NewListener(ln, tlsConfig) + if err == nil { + if tlsConfig != nil { + ln = tls.NewListener(ln, tlsConfig) + } else { + ln = &replayableConnListener{ln} + } } return ln, err @@ -66,7 +106,29 @@ func ListenAndServe(stopper *stop.Stopper, handler http.Handler, addr net.Addr, mu.Unlock() }, } - if err := http2.ConfigureServer(&httpServer, nil); err != nil { + + var http2Server http2.Server + + if tlsConfig == nil { + connOpts := http2.ServeConnOpts{ + BaseConfig: &httpServer, + Handler: handler, + } + + httpServer.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.ProtoMajor == 2 { + if conn, _, err := w.(http.Hijacker).Hijack(); err == nil { + http2Server.ServeConn(conn.(*replayableConn).replay(), &connOpts) + } else { + log.Fatal(err) + } + } else { + handler.ServeHTTP(w, r) + } + }) + } + + if err := http2.ConfigureServer(&httpServer, &http2Server); err != nil { return nil, err }