diff --git a/pkg/server/servemode.go b/pkg/server/grpc_server.go similarity index 58% rename from pkg/server/servemode.go rename to pkg/server/grpc_server.go index fe5345458cc7..af5b786aff0e 100644 --- a/pkg/server/servemode.go +++ b/pkg/server/grpc_server.go @@ -18,10 +18,31 @@ import ( "strings" "sync/atomic" + "github.com/cockroachdb/cockroach/pkg/rpc" + "google.golang.org/grpc" "google.golang.org/grpc/codes" grpcstatus "google.golang.org/grpc/status" ) +// grpcServer is a wrapper on top of a grpc.Server that includes an interceptor +// and a mode of operation that can instruct the interceptor to refuse certain +// RPCs. +type grpcServer struct { + *grpc.Server + mode serveMode +} + +func newGRPCServer(rpcCtx *rpc.Context) *grpcServer { + s := &grpcServer{} + s.mode.set(modeInitializing) + s.Server = rpc.NewServerWithInterceptor(rpcCtx, func(path string) error { + return s.intercept(path) + }) + return s +} + +type serveMode int32 + // A list of the server states for bootstrap process. const ( // modeInitializing is intended for server initialization process. @@ -36,25 +57,31 @@ const ( modeDraining ) -type serveMode int32 +func (s *grpcServer) setMode(mode serveMode) { + s.mode.set(mode) +} -// Intercept implements filtering rules for each server state. -func (s *Server) Intercept() func(string) error { - interceptors := map[string]struct{}{ - "/cockroach.rpc.Heartbeat/Ping": {}, - "/cockroach.gossip.Gossip/Gossip": {}, - "/cockroach.server.serverpb.Init/Bootstrap": {}, - "/cockroach.server.serverpb.Status/Details": {}, - } - return func(fullName string) error { - if s.serveMode.operational() { - return nil - } - if _, allowed := interceptors[fullName]; !allowed { - return WaitingForInitError(fullName) - } +func (s *grpcServer) operational() bool { + sMode := s.mode.get() + return sMode == modeOperational || sMode == modeDraining +} + +var rpcsAllowedWhileBootstrapping = map[string]struct{}{ + "/cockroach.rpc.Heartbeat/Ping": {}, + "/cockroach.gossip.Gossip/Gossip": {}, + "/cockroach.server.serverpb.Init/Bootstrap": {}, + "/cockroach.server.serverpb.Status/Details": {}, +} + +// intercept implements filtering rules for each server state. +func (s *grpcServer) intercept(fullName string) error { + if s.operational() { return nil } + if _, allowed := rpcsAllowedWhileBootstrapping[fullName]; !allowed { + return s.waitingForInitError(fullName) + } + return nil } func (s *serveMode) set(mode serveMode) { @@ -65,14 +92,9 @@ func (s *serveMode) get() serveMode { return serveMode(atomic.LoadInt32((*int32)(s))) } -func (s *serveMode) operational() bool { - sMode := s.get() - return sMode == modeOperational || sMode == modeDraining -} - -// WaitingForInitError indicates that the server cannot run the specified -// method until the node has been initialized. -func WaitingForInitError(methodName string) error { +// waitingForInitError creates an error indicating that the server cannot run +// the specified method until the node has been initialized. +func (s *grpcServer) waitingForInitError(methodName string) error { return grpcstatus.Errorf(codes.Unavailable, "node waiting for init; %s not available", methodName) } diff --git a/pkg/server/servemode_test.go b/pkg/server/servemode_test.go index 4f11224e94b6..80d02488b28f 100644 --- a/pkg/server/servemode_test.go +++ b/pkg/server/servemode_test.go @@ -25,7 +25,8 @@ import ( func TestWaitingForInitError(t *testing.T) { defer leaktest.AfterTest(t)() - if err := WaitingForInitError("foo"); !IsWaitingForInit(err) { + s := &grpcServer{} + if err := s.waitingForInitError("foo"); !IsWaitingForInit(err) { t.Errorf("WaitingForInitError() not recognized by IsWaitingForInit(): %v", err) } if err := grpcstatus.Errorf(codes.Unavailable, "foo"); IsWaitingForInit(err) { diff --git a/pkg/server/server.go b/pkg/server/server.go index 20c7f30446b1..aa5f5366b0d0 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -149,12 +149,13 @@ func (mux *safeServeMux) ServeHTTP(w http.ResponseWriter, r *http.Request) { type Server struct { nodeIDContainer base.NodeIDContainer - cfg Config - st *cluster.Settings - mux safeServeMux - clock *hlc.Clock - rpcContext *rpc.Context - grpc *grpc.Server + cfg Config + st *cluster.Settings + mux safeServeMux + clock *hlc.Clock + rpcContext *rpc.Context + // The gRPC server on which the different RPC handlers will be registered. + grpc *grpcServer gossip *gossip.Gossip nodeDialer *nodedialer.Dialer nodeLiveness *storage.NodeLiveness @@ -189,7 +190,6 @@ type Server struct { adminMemMetrics sql.MemoryMetrics // sqlMemMetrics are used to track memory usage of sql sessions. sqlMemMetrics sql.MemoryMetrics - serveMode } // NewServer creates a Server from a server.Config. @@ -212,7 +212,6 @@ func NewServer(cfg Config, stopper *stop.Stopper) (*Server, error) { cfg: cfg, registry: metric.NewRegistry(), } - s.serveMode.set(modeInitializing) // If the tracer has a Close function, call it after the server stops. if tr, ok := cfg.AmbientCtx.Tracer.(stop.Closer); ok { @@ -251,14 +250,14 @@ func NewServer(cfg Config, stopper *stop.Stopper) (*Server, error) { } } - s.grpc = rpc.NewServerWithInterceptor(s.rpcContext, s.Intercept()) + s.grpc = newGRPCServer(s.rpcContext) s.gossip = gossip.New( s.cfg.AmbientCtx, &s.rpcContext.ClusterID, &s.nodeIDContainer, s.rpcContext, - s.grpc, + s.grpc.Server, s.stopper, s.registry, s.cfg.Locality, @@ -340,7 +339,7 @@ func NewServer(cfg Config, stopper *stop.Stopper) (*Server, error) { ) s.raftTransport = storage.NewRaftTransport( - s.cfg.AmbientCtx, st, s.nodeDialer, s.grpc, s.stopper, + s.cfg.AmbientCtx, st, s.nodeDialer, s.grpc.Server, s.stopper, ) // Set up internal memory metrics for use by internal SQL executors. @@ -481,9 +480,9 @@ func NewServer(cfg Config, stopper *stop.Stopper) (*Server, error) { s.node = NewNode( storeCfg, s.recorder, s.registry, s.stopper, txnMetrics, nil /* execCfg */, &s.rpcContext.ClusterID) - roachpb.RegisterInternalServer(s.grpc, s.node) - storage.RegisterPerReplicaServer(s.grpc, s.node.perReplicaServer) - s.node.storeCfg.ClosedTimestamp.RegisterClosedTimestampServer(s.grpc) + roachpb.RegisterInternalServer(s.grpc.Server, s.node) + storage.RegisterPerReplicaServer(s.grpc.Server, s.node.perReplicaServer) + s.node.storeCfg.ClosedTimestamp.RegisterClosedTimestampServer(s.grpc.Server) s.sessionRegistry = sql.NewSessionRegistry() s.jobRegistry = jobs.MakeRegistry( @@ -535,7 +534,7 @@ func NewServer(cfg Config, stopper *stop.Stopper) (*Server, error) { } s.distSQLServer = distsqlrun.NewServer(ctx, distSQLCfg) - distsqlpb.RegisterDistSQLServer(s.grpc, s.distSQLServer) + distsqlpb.RegisterDistSQLServer(s.grpc.Server, s.distSQLServer) s.admin = newAdminServer(s) s.status = newStatusServer( @@ -555,7 +554,7 @@ func NewServer(cfg Config, stopper *stop.Stopper) (*Server, error) { ) s.authentication = newAuthenticationServer(s) for _, gw := range []grpcGatewayServer{s.admin, s.status, s.authentication, &s.tsServer} { - gw.RegisterService(s.grpc) + gw.RegisterService(s.grpc.Server) } // TODO(andrei): We're creating an initServer even through the inspection of @@ -565,7 +564,7 @@ func NewServer(cfg Config, stopper *stop.Stopper) (*Server, error) { // figure out early if our engines are bootstrapped and, if they are, create a // dummy implementation of the InitServer that rejects all RPCs. s.initServer = newInitServer(s.gossip.Connected, s.stopper.ShouldStop()) - serverpb.RegisterInitServer(s.grpc, s.initServer) + serverpb.RegisterInitServer(s.grpc.Server, s.initServer) nodeInfo := sql.NodeInfo{ AdminURL: cfg.AdminURL, @@ -1538,7 +1537,7 @@ func (s *Server) Start(ctx context.Context) error { s.distSQLServer.Start() s.pgServer.Start(ctx, s.stopper) - s.serveMode.set(modeOperational) + s.grpc.setMode(modeOperational) log.Infof(ctx, "starting %s server at %s (use: %s)", s.cfg.HTTPRequestScheme(), s.cfg.HTTPAddr, s.cfg.HTTPAdvertiseAddr) @@ -1758,7 +1757,7 @@ func (s *Server) doDrain( switch mode { case serverpb.DrainMode_CLIENT: if setTo { - s.serveMode.set(modeDraining) + s.grpc.setMode(modeDraining) // Wait for drainUnreadyWait. This will fail load balancer checks and // delay draining so that client traffic can move off this node. time.Sleep(drainWait.Get(&s.st.SV)) @@ -1766,7 +1765,7 @@ func (s *Server) doDrain( if err := func() error { if !setTo { // Execute this last. - defer func() { s.serveMode.set(modeOperational) }() + defer func() { s.grpc.setMode(modeOperational) }() } // Since enabling the lease manager's draining mode will prevent // the acquisition of new leases, the switch must be made after diff --git a/pkg/server/status.go b/pkg/server/status.go index 057546aadf23..830cece9199d 100644 --- a/pkg/server/status.go +++ b/pkg/server/status.go @@ -590,7 +590,7 @@ func (s *statusServer) Details( return resp, nil } - serveMode := s.admin.server.serveMode.get() + serveMode := s.admin.server.grpc.mode.get() if serveMode != modeOperational { return nil, grpcstatus.Error(codes.Unavailable, "node is not ready") }