Skip to content

Commit

Permalink
go/runtime/host/protocol/connection: Cancel call if connection is closed
Browse files Browse the repository at this point in the history
  • Loading branch information
peternose committed Nov 13, 2023
1 parent 622d0bd commit 82e0afd
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 39 deletions.
1 change: 1 addition & 0 deletions .changelog/5442.bugfix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
go/runtime/host/protocol/connection: Cancel call if connection is closed
73 changes: 34 additions & 39 deletions go/runtime/host/protocol/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ type connection struct { // nolint: maligned
handler Handler

state state
pendingRequests map[uint64]chan *Body
pendingRequests map[uint64]chan<- *Body
nextRequestID uint64

info *RuntimeInfoResponse
Expand Down Expand Up @@ -293,51 +293,39 @@ func (c *connection) call(ctx context.Context, body *Body) (result *Body, err er
}
}()

respCh, err := c.makeRequest(ctx, body)
if err != nil {
return nil, err
}

select {
case resp, ok := <-respCh:
if !ok {
return nil, fmt.Errorf("channel closed")
}

if resp.Error != nil {
// Decode error.
err = errors.FromCode(resp.Error.Module, resp.Error.Code, resp.Error.Message)
return nil, err
}

return resp, nil
case <-ctx.Done():
return nil, ctx.Err()
}
}

func (c *connection) makeRequest(ctx context.Context, body *Body) (<-chan *Body, error) {
// Create channel for sending the response and grab next request identifier.
ch := make(chan *Body, 1)
respCh := make(chan *Body, 1)

c.Lock()
id := c.nextRequestID
c.nextRequestID++
c.pendingRequests[id] = ch
c.pendingRequests[id] = respCh
c.Unlock()

defer func() {
c.Lock()
defer c.Unlock()
delete(c.pendingRequests, id)
}()

msg := Message{
ID: id,
MessageType: MessageRequest,
Body: *body,
}

// Queue the message.
if err := c.sendMessage(ctx, &msg); err != nil {
if err = c.sendMessage(ctx, &msg); err != nil {
return nil, fmt.Errorf("failed to send message: %w", err)
}

return ch, nil
// Await a response.
resp, err := c.readResponse(ctx, respCh)
if err != nil {
return nil, fmt.Errorf("failed to read response: %w", err)
}

return resp, nil
}

func (c *connection) sendMessage(ctx context.Context, msg *Message) error {
Expand All @@ -351,6 +339,22 @@ func (c *connection) sendMessage(ctx context.Context, msg *Message) error {
}
}

func (c *connection) readResponse(ctx context.Context, respCh <-chan *Body) (*Body, error) {
select {
case resp := <-respCh:
if resp.Error != nil {
// Decode error.
return nil, errors.FromCode(resp.Error.Module, resp.Error.Code, resp.Error.Message)
}

return resp, nil
case <-c.closeCh:
return nil, fmt.Errorf("connection closed")
case <-ctx.Done():
return nil, ctx.Err()
}
}

func (c *connection) workerOutgoing() {
defer c.quitWg.Done()

Expand Down Expand Up @@ -450,7 +454,6 @@ func (c *connection) handleMessage(ctx context.Context, message *Message) {
}

respCh <- &message.Body
close(respCh)
default:
c.logger.Warn("received a malformed message from worker, ignoring",
"message", fmt.Sprintf("%+v", message),
Expand All @@ -468,14 +471,6 @@ func (c *connection) workerIncoming() {
// Cancel all request handlers.
cancel()

// Close all pending request channels.
c.Lock()
for id, ch := range c.pendingRequests {
close(ch)
delete(c.pendingRequests, id)
}
c.Unlock()

c.quitWg.Done()
}()

Expand Down Expand Up @@ -583,7 +578,7 @@ func NewConnection(logger *logging.Logger, runtimeID common.Namespace, handler H
runtimeID: runtimeID,
handler: handler,
state: stateUninitialized,
pendingRequests: make(map[uint64]chan *Body),
pendingRequests: make(map[uint64]chan<- *Body),
outCh: make(chan *Message),
closeCh: make(chan struct{}),
logger: logger,
Expand Down

0 comments on commit 82e0afd

Please sign in to comment.