Skip to content

Commit

Permalink
Merge pull request #742 from ava-labs/use-client-context
Browse files Browse the repository at this point in the history
use context given by client
  • Loading branch information
felipemadero authored Nov 2, 2024
2 parents 5189aac + 1be2d61 commit 5a17aec
Showing 1 changed file with 45 additions and 20 deletions.
65 changes: 45 additions & 20 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ func (s *server) RPCVersion(context.Context, *rpcpb.RPCVersionRequest) (*rpcpb.R
return &rpcpb.RPCVersionResponse{Version: RPCVersion}, nil
}

func (s *server) Start(_ context.Context, req *rpcpb.StartRequest) (*rpcpb.StartResponse, error) {
func (s *server) Start(callContext context.Context, req *rpcpb.StartRequest) (*rpcpb.StartResponse, error) {
s.mu.Lock()
defer s.mu.Unlock()

Expand Down Expand Up @@ -383,16 +383,15 @@ func (s *server) Start(_ context.Context, req *rpcpb.StartRequest) (*rpcpb.Start
zap.String("global-node-config", globalNodeConfig),
)

ctx, cancel := context.WithTimeout(context.Background(), s.network.GetWaitForHealthyTimeout())
ctx, cancel := s.getContext(callContext)
defer cancel()

if err := s.network.Start(ctx); err != nil {
s.log.Warn("start failed to complete", zap.Error(err))
s.stopAndRemoveNetwork(nil)
return nil, err
}

ctx, cancel = context.WithTimeout(context.Background(), s.network.GetWaitForHealthyTimeout())
defer cancel()
chainIDs, err := s.network.CreateChains(ctx, chainSpecs)
if err != nil {
s.log.Error("network never became healthy", zap.Error(err))
Expand Down Expand Up @@ -439,15 +438,15 @@ func (s *server) updateClusterInfo() {
// - timeout expires
// - network operation terminates with error
// - network operation terminates successfully by setting CustomChainsHealthy
func (s *server) WaitForHealthy(ctx context.Context, _ *rpcpb.WaitForHealthyRequest) (*rpcpb.WaitForHealthyResponse, error) {
func (s *server) WaitForHealthy(callContext context.Context, _ *rpcpb.WaitForHealthyRequest) (*rpcpb.WaitForHealthyResponse, error) {
s.log.Debug("WaitForHealthy")

s.mu.RLock()
if s.network == nil {
s.mu.RUnlock()
return nil, ErrNotBootstrapped
}
ctx, cancel := context.WithTimeout(ctx, s.network.GetWaitForHealthyTimeout())
ctx, cancel := s.getContext(callContext)
defer cancel()
s.mu.RUnlock()

Expand Down Expand Up @@ -555,7 +554,7 @@ func (s *server) CreateBlockchains(
}

func (s *server) AddPermissionlessDelegator(
_ context.Context,
callContext context.Context,
req *rpcpb.AddPermissionlessDelegatorRequest,
) (*rpcpb.AddPermissionlessDelegatorResponse, error) {
s.mu.Lock()
Expand Down Expand Up @@ -592,7 +591,7 @@ func (s *server) AddPermissionlessDelegator(
}
}

ctx, cancel := context.WithTimeout(context.Background(), s.network.GetWaitForHealthyTimeout())
ctx, cancel := s.getContext(callContext)
defer cancel()
err := s.network.AddPermissionlessDelegators(ctx, delegatorSpecList)
if err != nil {
Expand All @@ -609,7 +608,7 @@ func (s *server) AddPermissionlessDelegator(
}

func (s *server) AddPermissionlessValidator(
_ context.Context,
callContext context.Context,
req *rpcpb.AddPermissionlessValidatorRequest,
) (*rpcpb.AddPermissionlessValidatorResponse, error) {
s.mu.Lock()
Expand Down Expand Up @@ -649,7 +648,7 @@ func (s *server) AddPermissionlessValidator(
s.clusterInfo.Healthy = false
s.clusterInfo.CustomChainsHealthy = false

ctx, cancel := context.WithTimeout(context.Background(), s.network.GetWaitForHealthyTimeout())
ctx, cancel := s.getContext(callContext)
defer cancel()
err := s.network.AddPermissionlessValidators(ctx, validatorSpecList)

Expand All @@ -670,7 +669,7 @@ func (s *server) AddPermissionlessValidator(
}

func (s *server) AddSubnetValidators(
_ context.Context,
callContext context.Context,
req *rpcpb.AddSubnetValidatorsRequest,
) (*rpcpb.AddSubnetValidatorsResponse, error) {
s.mu.Lock()
Expand Down Expand Up @@ -707,7 +706,7 @@ func (s *server) AddSubnetValidators(
s.clusterInfo.Healthy = false
s.clusterInfo.CustomChainsHealthy = false

ctx, cancel := context.WithTimeout(context.Background(), s.network.GetWaitForHealthyTimeout())
ctx, cancel := s.getContext(callContext)
defer cancel()
err := s.network.AddSubnetValidators(ctx, validatorSpecList)

Expand All @@ -728,7 +727,7 @@ func (s *server) AddSubnetValidators(
}

func (s *server) RemoveSubnetValidator(
_ context.Context,
callContext context.Context,
req *rpcpb.RemoveSubnetValidatorRequest,
) (*rpcpb.RemoveSubnetValidatorResponse, error) {
s.mu.Lock()
Expand Down Expand Up @@ -765,7 +764,7 @@ func (s *server) RemoveSubnetValidator(
s.clusterInfo.Healthy = false
s.clusterInfo.CustomChainsHealthy = false

ctx, cancel := context.WithTimeout(context.Background(), s.network.GetWaitForHealthyTimeout())
ctx, cancel := s.getContext(callContext)
defer cancel()
err := s.network.RemoveSubnetValidator(ctx, validatorSpecList)

Expand All @@ -786,7 +785,7 @@ func (s *server) RemoveSubnetValidator(
}

func (s *server) TransformElasticSubnets(
_ context.Context,
callContext context.Context,
req *rpcpb.TransformElasticSubnetsRequest,
) (*rpcpb.TransformElasticSubnetsResponse, error) {
s.mu.Lock()
Expand Down Expand Up @@ -823,7 +822,7 @@ func (s *server) TransformElasticSubnets(
s.clusterInfo.Healthy = false
s.clusterInfo.CustomChainsHealthy = false

ctx, cancel := context.WithTimeout(context.Background(), s.network.GetWaitForHealthyTimeout())
ctx, cancel := s.getContext(callContext)
defer cancel()
txIDs, assetIDs, err := s.network.TransformSubnets(ctx, elasticSubnetSpecList)

Expand Down Expand Up @@ -853,7 +852,10 @@ func (s *server) TransformElasticSubnets(
return &rpcpb.TransformElasticSubnetsResponse{ClusterInfo: clusterInfo, TxIds: strTXIDs, AssetIds: strAssetIDs}, nil
}

func (s *server) CreateSubnets(_ context.Context, req *rpcpb.CreateSubnetsRequest) (*rpcpb.CreateSubnetsResponse, error) {
func (s *server) CreateSubnets(
callContext context.Context,
req *rpcpb.CreateSubnetsRequest,
) (*rpcpb.CreateSubnetsResponse, error) {
s.mu.Lock()
defer s.mu.Unlock()

Expand All @@ -874,7 +876,7 @@ func (s *server) CreateSubnets(_ context.Context, req *rpcpb.CreateSubnetsReques
s.clusterInfo.Healthy = false
s.clusterInfo.CustomChainsHealthy = false

ctx, cancel := context.WithTimeout(context.Background(), s.network.GetWaitForHealthyTimeout())
ctx, cancel := s.getContext(callContext)
defer cancel()
subnetIDs, err := s.network.CreateSubnets(ctx, subnetSpecs)
if err != nil {
Expand Down Expand Up @@ -1343,7 +1345,10 @@ func (s *server) SendOutboundMessage(ctx context.Context, req *rpcpb.SendOutboun
return &rpcpb.SendOutboundMessageResponse{Sent: sent}, err
}

func (s *server) LoadSnapshot(_ context.Context, req *rpcpb.LoadSnapshotRequest) (*rpcpb.LoadSnapshotResponse, error) {
func (s *server) LoadSnapshot(
callContext context.Context,
req *rpcpb.LoadSnapshotRequest,
) (*rpcpb.LoadSnapshotResponse, error) {
s.mu.Lock()
defer s.mu.Unlock()

Expand Down Expand Up @@ -1404,7 +1409,7 @@ func (s *server) LoadSnapshot(_ context.Context, req *rpcpb.LoadSnapshotRequest)
return nil, err
}

ctx, cancel := context.WithTimeout(context.Background(), s.network.GetWaitForHealthyTimeout())
ctx, cancel := s.getContext(callContext)
defer cancel()
err = s.network.AwaitHealthyAndUpdateNetworkInfo(ctx)
if err != nil {
Expand Down Expand Up @@ -1789,3 +1794,23 @@ func applyDefaultPluginDir(userGivenPluginDir string) string {
}
return os.Getenv(constants.DefaultPluginDirEnvVar)
}

// Assumes [s.mu] is held.
func (s *server) getContext(callContext context.Context) (context.Context, context.CancelFunc) {
timeout := 3 * time.Minute
if s.network != nil {
timeout = s.network.GetWaitForHealthyTimeout()
}
var (
ctx context.Context
cancel context.CancelFunc
)
_, ok := callContext.Deadline()
if ok {
ctx = callContext
cancel = func() {}
} else {
ctx, cancel = context.WithTimeout(context.Background(), timeout)
}
return ctx, cancel
}

0 comments on commit 5a17aec

Please sign in to comment.