diff --git a/.changelog/5442.bugfix.md b/.changelog/5442.bugfix.md new file mode 100644 index 00000000000..fcdcc7b9a40 --- /dev/null +++ b/.changelog/5442.bugfix.md @@ -0,0 +1 @@ +go/runtime/host/protocol/connection: Cancel call if connection is closed diff --git a/go/runtime/host/protocol/connection.go b/go/runtime/host/protocol/connection.go index 832235eaa23..b4737469442 100644 --- a/go/runtime/host/protocol/connection.go +++ b/go/runtime/host/protocol/connection.go @@ -200,7 +200,7 @@ type connection struct { // nolint: maligned handler Handler state state - pendingRequests map[uint64]chan *Body + pendingRequests map[uint64]chan<- *Body nextRequestID uint64 info *RuntimeInfoResponse @@ -293,39 +293,21 @@ func (c *connection) call(ctx context.Context, body *Body) (result *Body, err er } }() - respCh, err := c.makeRequest(ctx, body) - if err != nil { - return nil, err - } - - select { - case resp, ok := <-respCh: - if !ok { - return nil, fmt.Errorf("channel closed") - } - - if resp.Error != nil { - // Decode error. - err = errors.FromCode(resp.Error.Module, resp.Error.Code, resp.Error.Message) - return nil, err - } - - return resp, nil - case <-ctx.Done(): - return nil, ctx.Err() - } -} - -func (c *connection) makeRequest(ctx context.Context, body *Body) (<-chan *Body, error) { // Create channel for sending the response and grab next request identifier. - ch := make(chan *Body, 1) + respCh := make(chan *Body, 1) c.Lock() id := c.nextRequestID c.nextRequestID++ - c.pendingRequests[id] = ch + c.pendingRequests[id] = respCh c.Unlock() + defer func() { + c.Lock() + defer c.Unlock() + delete(c.pendingRequests, id) + }() + msg := Message{ ID: id, MessageType: MessageRequest, @@ -333,11 +315,17 @@ func (c *connection) makeRequest(ctx context.Context, body *Body) (<-chan *Body, } // Queue the message. - if err := c.sendMessage(ctx, &msg); err != nil { + if err = c.sendMessage(ctx, &msg); err != nil { return nil, fmt.Errorf("failed to send message: %w", err) } - return ch, nil + // Await a response. + resp, err := c.readResponse(ctx, respCh) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + return resp, nil } func (c *connection) sendMessage(ctx context.Context, msg *Message) error { @@ -351,9 +339,23 @@ func (c *connection) sendMessage(ctx context.Context, msg *Message) error { } } -func (c *connection) workerOutgoing() { - defer c.quitWg.Done() +func (c *connection) readResponse(ctx context.Context, respCh <-chan *Body) (*Body, error) { + select { + case resp := <-respCh: + if resp.Error != nil { + // Decode error. + return nil, errors.FromCode(resp.Error.Module, resp.Error.Code, resp.Error.Message) + } + + return resp, nil + case <-c.closeCh: + return nil, fmt.Errorf("connection closed") + case <-ctx.Done(): + return nil, ctx.Err() + } +} +func (c *connection) workerOutgoing() { for { select { case msg := <-c.outCh: @@ -450,7 +452,6 @@ func (c *connection) handleMessage(ctx context.Context, message *Message) { } respCh <- &message.Body - close(respCh) default: c.logger.Warn("received a malformed message from worker, ignoring", "message", fmt.Sprintf("%+v", message), @@ -459,24 +460,18 @@ func (c *connection) handleMessage(ctx context.Context, message *Message) { } func (c *connection) workerIncoming() { + // Wait for request handlers to finish. + var wg sync.WaitGroup + defer wg.Wait() + + // Cancel all request handlers. ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + defer func() { // Close connection and signal that connection is closed. _ = c.conn.Close() close(c.closeCh) - - // Cancel all request handlers. - cancel() - - // Close all pending request channels. - c.Lock() - for id, ch := range c.pendingRequests { - close(ch) - delete(c.pendingRequests, id) - } - c.Unlock() - - c.quitWg.Done() }() for { @@ -491,7 +486,11 @@ func (c *connection) workerIncoming() { } // Handle message in a separate goroutine. - go c.handleMessage(ctx, &message) + wg.Add(1) + go func() { + defer wg.Done() + c.handleMessage(ctx, &message) + }() } } @@ -507,8 +506,14 @@ func (c *connection) initConn(conn net.Conn) { c.codec = cbor.NewMessageCodec(conn, moduleName) c.quitWg.Add(2) - go c.workerIncoming() - go c.workerOutgoing() + go func() { + defer c.quitWg.Done() + c.workerIncoming() + }() + go func() { + defer c.quitWg.Done() + c.workerOutgoing() + }() // Change protocol state to Initializing so that some of the requests are allowed. c.setStateLocked(stateInitializing) @@ -583,7 +588,7 @@ func NewConnection(logger *logging.Logger, runtimeID common.Namespace, handler H runtimeID: runtimeID, handler: handler, state: stateUninitialized, - pendingRequests: make(map[uint64]chan *Body), + pendingRequests: make(map[uint64]chan<- *Body), outCh: make(chan *Message), closeCh: make(chan struct{}), logger: logger, diff --git a/go/runtime/host/sgx/sgx.go b/go/runtime/host/sgx/sgx.go index 5e15158ed58..be04cd5bd8b 100644 --- a/go/runtime/host/sgx/sgx.go +++ b/go/runtime/host/sgx/sgx.go @@ -41,9 +41,10 @@ const ( // Runtime RAK initialization timeout. // - // This can take a long time in deployments that run multiple - // nodes on a single machine, all sharing the same EPC. - runtimeRAKTimeout = 60 * time.Second + // This can take a long time in deployments that run multiple nodes on a single machine, all + // sharing the same EPC. Additionally, this includes time to do the initial consensus light + // client sync and freshness verification which can take some time. + runtimeRAKTimeout = 5 * time.Minute // Runtime attest interval. defaultRuntimeAttestInterval = 2 * time.Hour )