diff --git a/server/server.go b/server/server.go index e9521487..e056990b 100644 --- a/server/server.go +++ b/server/server.go @@ -262,7 +262,7 @@ func (s *server) RPCVersion(context.Context, *rpcpb.RPCVersionRequest) (*rpcpb.R return &rpcpb.RPCVersionResponse{Version: RPCVersion}, nil } -func (s *server) Start(_ context.Context, req *rpcpb.StartRequest) (*rpcpb.StartResponse, error) { +func (s *server) Start(callContext context.Context, req *rpcpb.StartRequest) (*rpcpb.StartResponse, error) { s.mu.Lock() defer s.mu.Unlock() @@ -383,16 +383,15 @@ func (s *server) Start(_ context.Context, req *rpcpb.StartRequest) (*rpcpb.Start zap.String("global-node-config", globalNodeConfig), ) - ctx, cancel := context.WithTimeout(context.Background(), s.network.GetWaitForHealthyTimeout()) + ctx, cancel := s.getContext(callContext) defer cancel() + if err := s.network.Start(ctx); err != nil { s.log.Warn("start failed to complete", zap.Error(err)) s.stopAndRemoveNetwork(nil) return nil, err } - ctx, cancel = context.WithTimeout(context.Background(), s.network.GetWaitForHealthyTimeout()) - defer cancel() chainIDs, err := s.network.CreateChains(ctx, chainSpecs) if err != nil { s.log.Error("network never became healthy", zap.Error(err)) @@ -439,7 +438,7 @@ func (s *server) updateClusterInfo() { // - timeout expires // - network operation terminates with error // - network operation terminates successfully by setting CustomChainsHealthy -func (s *server) WaitForHealthy(ctx context.Context, _ *rpcpb.WaitForHealthyRequest) (*rpcpb.WaitForHealthyResponse, error) { +func (s *server) WaitForHealthy(callContext context.Context, _ *rpcpb.WaitForHealthyRequest) (*rpcpb.WaitForHealthyResponse, error) { s.log.Debug("WaitForHealthy") s.mu.RLock() @@ -447,7 +446,7 @@ func (s *server) WaitForHealthy(ctx context.Context, _ *rpcpb.WaitForHealthyRequ s.mu.RUnlock() return nil, ErrNotBootstrapped } - ctx, cancel := context.WithTimeout(ctx, s.network.GetWaitForHealthyTimeout()) + ctx, cancel := s.getContext(callContext) defer cancel() s.mu.RUnlock() @@ -555,7 +554,7 @@ func (s *server) CreateBlockchains( } func (s *server) AddPermissionlessDelegator( - _ context.Context, + callContext context.Context, req *rpcpb.AddPermissionlessDelegatorRequest, ) (*rpcpb.AddPermissionlessDelegatorResponse, error) { s.mu.Lock() @@ -592,7 +591,7 @@ func (s *server) AddPermissionlessDelegator( } } - ctx, cancel := context.WithTimeout(context.Background(), s.network.GetWaitForHealthyTimeout()) + ctx, cancel := s.getContext(callContext) defer cancel() err := s.network.AddPermissionlessDelegators(ctx, delegatorSpecList) if err != nil { @@ -609,7 +608,7 @@ func (s *server) AddPermissionlessDelegator( } func (s *server) AddPermissionlessValidator( - _ context.Context, + callContext context.Context, req *rpcpb.AddPermissionlessValidatorRequest, ) (*rpcpb.AddPermissionlessValidatorResponse, error) { s.mu.Lock() @@ -649,7 +648,7 @@ func (s *server) AddPermissionlessValidator( s.clusterInfo.Healthy = false s.clusterInfo.CustomChainsHealthy = false - ctx, cancel := context.WithTimeout(context.Background(), s.network.GetWaitForHealthyTimeout()) + ctx, cancel := s.getContext(callContext) defer cancel() err := s.network.AddPermissionlessValidators(ctx, validatorSpecList) @@ -670,7 +669,7 @@ func (s *server) AddPermissionlessValidator( } func (s *server) AddSubnetValidators( - _ context.Context, + callContext context.Context, req *rpcpb.AddSubnetValidatorsRequest, ) (*rpcpb.AddSubnetValidatorsResponse, error) { s.mu.Lock() @@ -707,7 +706,7 @@ func (s *server) AddSubnetValidators( s.clusterInfo.Healthy = false s.clusterInfo.CustomChainsHealthy = false - ctx, cancel := context.WithTimeout(context.Background(), s.network.GetWaitForHealthyTimeout()) + ctx, cancel := s.getContext(callContext) defer cancel() err := s.network.AddSubnetValidators(ctx, validatorSpecList) @@ -728,7 +727,7 @@ func (s *server) AddSubnetValidators( } func (s *server) RemoveSubnetValidator( - _ context.Context, + callContext context.Context, req *rpcpb.RemoveSubnetValidatorRequest, ) (*rpcpb.RemoveSubnetValidatorResponse, error) { s.mu.Lock() @@ -765,7 +764,7 @@ func (s *server) RemoveSubnetValidator( s.clusterInfo.Healthy = false s.clusterInfo.CustomChainsHealthy = false - ctx, cancel := context.WithTimeout(context.Background(), s.network.GetWaitForHealthyTimeout()) + ctx, cancel := s.getContext(callContext) defer cancel() err := s.network.RemoveSubnetValidator(ctx, validatorSpecList) @@ -786,7 +785,7 @@ func (s *server) RemoveSubnetValidator( } func (s *server) TransformElasticSubnets( - _ context.Context, + callContext context.Context, req *rpcpb.TransformElasticSubnetsRequest, ) (*rpcpb.TransformElasticSubnetsResponse, error) { s.mu.Lock() @@ -823,7 +822,7 @@ func (s *server) TransformElasticSubnets( s.clusterInfo.Healthy = false s.clusterInfo.CustomChainsHealthy = false - ctx, cancel := context.WithTimeout(context.Background(), s.network.GetWaitForHealthyTimeout()) + ctx, cancel := s.getContext(callContext) defer cancel() txIDs, assetIDs, err := s.network.TransformSubnets(ctx, elasticSubnetSpecList) @@ -853,7 +852,10 @@ func (s *server) TransformElasticSubnets( return &rpcpb.TransformElasticSubnetsResponse{ClusterInfo: clusterInfo, TxIds: strTXIDs, AssetIds: strAssetIDs}, nil } -func (s *server) CreateSubnets(_ context.Context, req *rpcpb.CreateSubnetsRequest) (*rpcpb.CreateSubnetsResponse, error) { +func (s *server) CreateSubnets( + callContext context.Context, + req *rpcpb.CreateSubnetsRequest, +) (*rpcpb.CreateSubnetsResponse, error) { s.mu.Lock() defer s.mu.Unlock() @@ -874,7 +876,7 @@ func (s *server) CreateSubnets(_ context.Context, req *rpcpb.CreateSubnetsReques s.clusterInfo.Healthy = false s.clusterInfo.CustomChainsHealthy = false - ctx, cancel := context.WithTimeout(context.Background(), s.network.GetWaitForHealthyTimeout()) + ctx, cancel := s.getContext(callContext) defer cancel() subnetIDs, err := s.network.CreateSubnets(ctx, subnetSpecs) if err != nil { @@ -1343,7 +1345,10 @@ func (s *server) SendOutboundMessage(ctx context.Context, req *rpcpb.SendOutboun return &rpcpb.SendOutboundMessageResponse{Sent: sent}, err } -func (s *server) LoadSnapshot(_ context.Context, req *rpcpb.LoadSnapshotRequest) (*rpcpb.LoadSnapshotResponse, error) { +func (s *server) LoadSnapshot( + callContext context.Context, + req *rpcpb.LoadSnapshotRequest, +) (*rpcpb.LoadSnapshotResponse, error) { s.mu.Lock() defer s.mu.Unlock() @@ -1404,7 +1409,7 @@ func (s *server) LoadSnapshot(_ context.Context, req *rpcpb.LoadSnapshotRequest) return nil, err } - ctx, cancel := context.WithTimeout(context.Background(), s.network.GetWaitForHealthyTimeout()) + ctx, cancel := s.getContext(callContext) defer cancel() err = s.network.AwaitHealthyAndUpdateNetworkInfo(ctx) if err != nil { @@ -1789,3 +1794,23 @@ func applyDefaultPluginDir(userGivenPluginDir string) string { } return os.Getenv(constants.DefaultPluginDirEnvVar) } + +// Assumes [s.mu] is held. +func (s *server) getContext(callContext context.Context) (context.Context, context.CancelFunc) { + timeout := 3 * time.Minute + if s.network != nil { + timeout = s.network.GetWaitForHealthyTimeout() + } + var ( + ctx context.Context + cancel context.CancelFunc + ) + _, ok := callContext.Deadline() + if ok { + ctx = callContext + cancel = func() {} + } else { + ctx, cancel = context.WithTimeout(context.Background(), timeout) + } + return ctx, cancel +}