From b6583f36596216a3cb5c059fdf85a19512800cff Mon Sep 17 00:00:00 2001 From: Joel Hendrix Date: Wed, 29 Mar 2023 15:06:23 -0700 Subject: [PATCH] Update azeventhubs to latest go-amqp (#20387) --- .../azeventhubs/consumer_client_test.go | 32 +- .../azeventhubs/internal/amqp_fakes.go | 14 +- .../azeventhubs/internal/amqpwrap/amqpwrap.go | 10 +- sdk/messaging/azeventhubs/internal/errors.go | 49 +- .../azeventhubs/internal/errors_test.go | 26 +- .../azeventhubs/internal/go-amqp/client.go | 183 --- .../azeventhubs/internal/go-amqp/conn.go | 864 ++++++------- .../azeventhubs/internal/go-amqp/const.go | 42 +- .../{manualCreditor.go => creditor.go} | 29 +- .../azeventhubs/internal/go-amqp/doc.go | 13 - .../azeventhubs/internal/go-amqp/errors.go | 126 +- .../go-amqp/internal/bitmap/bitmap.go | 1 + .../go-amqp/internal/buffer/buffer.go | 1 + .../internal/go-amqp/internal/debug/debug.go | 20 + .../go-amqp/internal/debug/debug_debug.go | 51 + .../go-amqp/internal/encoding/decode.go | 35 +- .../go-amqp/internal/encoding/encode.go | 31 +- .../go-amqp/internal/encoding/types.go | 57 +- .../go-amqp/internal/frames/frames.go | 48 +- .../go-amqp/internal/frames/parsing.go | 32 + .../internal/go-amqp/internal/log/log.go | 11 - .../go-amqp/internal/log/log_debug.go | 31 - .../internal/go-amqp/internal/queue/queue.go | 164 +++ .../go-amqp/internal/shared/shared.go | 36 + .../azeventhubs/internal/go-amqp/link.go | 1072 +++-------------- .../internal/go-amqp/link_options.go | 99 +- .../azeventhubs/internal/go-amqp/message.go | 39 +- .../azeventhubs/internal/go-amqp/receiver.go | 793 +++++++++--- .../azeventhubs/internal/go-amqp/sasl.go | 59 +- .../azeventhubs/internal/go-amqp/sender.go | 354 +++++- .../azeventhubs/internal/go-amqp/session.go | 570 ++++++--- .../azeventhubs/internal/links_test.go | 5 +- .../azeventhubs/internal/links_unit_test.go | 16 +- .../azeventhubs/internal/mock/mock_amqp.go | 32 +- .../azeventhubs/internal/mock/mock_helpers.go | 4 +- .../azeventhubs/internal/namespace.go | 4 +- .../azeventhubs/internal/namespace_test.go | 6 +- sdk/messaging/azeventhubs/internal/rpc.go | 8 +- .../internal/utils/retrier_test.go | 8 +- sdk/messaging/azeventhubs/partition_client.go | 15 +- .../azeventhubs/partition_client_unit_test.go | 10 +- sdk/messaging/azeventhubs/producer_client.go | 7 +- 42 files changed, 2646 insertions(+), 2361 deletions(-) delete mode 100644 sdk/messaging/azeventhubs/internal/go-amqp/client.go rename sdk/messaging/azeventhubs/internal/go-amqp/{manualCreditor.go => creditor.go} (75%) delete mode 100644 sdk/messaging/azeventhubs/internal/go-amqp/doc.go create mode 100644 sdk/messaging/azeventhubs/internal/go-amqp/internal/debug/debug.go create mode 100644 sdk/messaging/azeventhubs/internal/go-amqp/internal/debug/debug_debug.go delete mode 100644 sdk/messaging/azeventhubs/internal/go-amqp/internal/log/log.go delete mode 100644 sdk/messaging/azeventhubs/internal/go-amqp/internal/log/log_debug.go create mode 100644 sdk/messaging/azeventhubs/internal/go-amqp/internal/queue/queue.go create mode 100644 sdk/messaging/azeventhubs/internal/go-amqp/internal/shared/shared.go diff --git a/sdk/messaging/azeventhubs/consumer_client_test.go b/sdk/messaging/azeventhubs/consumer_client_test.go index 5d70759c4b0a..45f27fcfeabe 100644 --- a/sdk/messaging/azeventhubs/consumer_client_test.go +++ b/sdk/messaging/azeventhubs/consumer_client_test.go @@ -22,14 +22,20 @@ import ( ) func TestConsumerClient_UsingWebSockets(t *testing.T) { - // NOTE: This error is coming from the `nhooyr.io/websocket` package. There's an - // open discussion here: - // https://github.com/nhooyr/websocket/discussions/380 - // - // The frame it's waiting for (at this point) is the other half of the websocket CLOSE handshake. - // I wireshark'd this and confirmed that the frame does arrive, it's just not read by the local - // package. In this context, since the connection has already shut down, this is harmless. - var expectedWSErr = "failed to close WebSocket: failed to read frame header: EOF" + const ( + // NOTE: This error is coming from the `nhooyr.io/websocket` package. There's an + // open discussion here: + // https://github.com/nhooyr/websocket/discussions/380 + // + // The frame it's waiting for (at this point) is the other half of the websocket CLOSE handshake. + // I wireshark'd this and confirmed that the frame does arrive, it's just not read by the local + // package. In this context, since the connection has already shut down, this is harmless. + expectedWSErr1 = "failed to close WebSocket: failed to read frame header: EOF" + + // in addition, the returned error on close doesn't implement net.ErrClosed so we can also see this. + // https://github.com/nhooyr/websocket/issues/286 + expectedWSErr2 = "failed to read: WebSocket closed: sent close frame: status = StatusNormalClosure and reason = \"\"" + ) newWebSocketConnFn := func(ctx context.Context, args azeventhubs.WebSocketConnParams) (net.Conn, error) { opts := &websocket.DialOptions{ @@ -53,7 +59,10 @@ func TestConsumerClient_UsingWebSockets(t *testing.T) { defer func() { err := producerClient.Close(context.Background()) - require.EqualError(t, err, expectedWSErr) + require.Error(t, err) + if es := err.Error(); es != expectedWSErr1 && es != expectedWSErr2 { + t.Fatalf("unexpected error %v", err) + } }() partProps, err := producerClient.GetPartitionProperties(context.Background(), "0", nil) @@ -79,7 +88,10 @@ func TestConsumerClient_UsingWebSockets(t *testing.T) { defer func() { err := consumerClient.Close(context.Background()) - require.EqualError(t, err, expectedWSErr) + require.Error(t, err) + if es := err.Error(); es != expectedWSErr1 && es != expectedWSErr2 { + t.Fatalf("unexpected error %v", err) + } }() partClient, err := consumerClient.NewPartitionClient("0", &azeventhubs.PartitionClientOptions{ diff --git a/sdk/messaging/azeventhubs/internal/amqp_fakes.go b/sdk/messaging/azeventhubs/internal/amqp_fakes.go index dc9be3828dd6..7aca36a8f05e 100644 --- a/sdk/messaging/azeventhubs/internal/amqp_fakes.go +++ b/sdk/messaging/azeventhubs/internal/amqp_fakes.go @@ -34,14 +34,14 @@ type FakeAMQPReceiver struct { amqpwrap.AMQPReceiverCloser // ActiveCredits are incremented and decremented by IssueCredit and Receive. - ActiveCredits uint32 + ActiveCredits int32 // IssuedCredit just accumulates, so we can get an idea of how many credits we issued overall. IssuedCredit []uint32 // CreditsSetFromOptions is similar to issuedCredit, but only tracks credits added in via the LinkOptions.Credit // field (ie, enabling prefetch). - CreditsSetFromOptions uint32 + CreditsSetFromOptions int32 // ManualCreditsSetFromOptions is the value of the LinkOptions.ManualCredits value. ManualCreditsSetFromOptions bool @@ -71,10 +71,10 @@ func (ns *FakeNSForPartClient) NewAMQPSession(ctx context.Context) (amqpwrap.AMQ func (sess *FakeAMQPSession) NewReceiver(ctx context.Context, source string, opts *amqp.ReceiverOptions) (amqpwrap.AMQPReceiverCloser, error) { sess.NS.NewReceiverCalled++ - sess.NS.Receiver.ManualCreditsSetFromOptions = opts.ManualCredits + sess.NS.Receiver.ManualCreditsSetFromOptions = opts.Credit == -1 sess.NS.Receiver.CreditsSetFromOptions = opts.Credit - if !opts.ManualCredits { + if opts.Credit > 0 { sess.NS.Receiver.ActiveCredits = opts.Credit } @@ -92,11 +92,11 @@ func (sess *FakeAMQPSession) Close(ctx context.Context) error { } func (r *FakeAMQPReceiver) Credits() uint32 { - return r.ActiveCredits + return uint32(r.ActiveCredits) } func (r *FakeAMQPReceiver) IssueCredit(credit uint32) error { - r.ActiveCredits += credit + r.ActiveCredits += int32(credit) r.IssuedCredit = append(r.IssuedCredit, credit) return nil } @@ -105,7 +105,7 @@ func (r *FakeAMQPReceiver) LinkName() string { return r.NameForLink } -func (r *FakeAMQPReceiver) Receive(ctx context.Context) (*amqp.Message, error) { +func (r *FakeAMQPReceiver) Receive(ctx context.Context, o *amqp.ReceiveOptions) (*amqp.Message, error) { if len(r.Messages) > 0 { r.ActiveCredits-- m := r.Messages[0] diff --git a/sdk/messaging/azeventhubs/internal/amqpwrap/amqpwrap.go b/sdk/messaging/azeventhubs/internal/amqpwrap/amqpwrap.go index 4ad735498e59..fb2c29af12e0 100644 --- a/sdk/messaging/azeventhubs/internal/amqpwrap/amqpwrap.go +++ b/sdk/messaging/azeventhubs/internal/amqpwrap/amqpwrap.go @@ -14,7 +14,7 @@ import ( // AMQPReceiver is implemented by *amqp.Receiver type AMQPReceiver interface { IssueCredit(credit uint32) error - Receive(ctx context.Context) (*amqp.Message, error) + Receive(ctx context.Context, o *amqp.ReceiveOptions) (*amqp.Message, error) Prefetched() *amqp.Message // settlement functions @@ -40,7 +40,7 @@ type AMQPReceiverCloser interface { // AMQPSender is implemented by *amqp.Sender type AMQPSender interface { - Send(ctx context.Context, msg *amqp.Message) error + Send(ctx context.Context, msg *amqp.Message, o *amqp.SendOptions) error MaxMessageSize() uint64 LinkName() string } @@ -84,7 +84,7 @@ type RPCResponse struct { // It exists only so we can return AMQPSession, which itself only exists so we can // return interfaces for AMQPSender and AMQPReceiver from AMQPSession. type AMQPClientWrapper struct { - Inner *amqp.Client + Inner *amqp.Conn } func (w *AMQPClientWrapper) Close() error { @@ -150,8 +150,8 @@ func (rw *AMQPReceiverWrapper) IssueCredit(credit uint32) error { return err } -func (rw *AMQPReceiverWrapper) Receive(ctx context.Context) (*amqp.Message, error) { - message, err := rw.inner.Receive(ctx) +func (rw *AMQPReceiverWrapper) Receive(ctx context.Context, o *amqp.ReceiveOptions) (*amqp.Message, error) { + message, err := rw.inner.Receive(ctx, o) if err != nil { return nil, err diff --git a/sdk/messaging/azeventhubs/internal/errors.go b/sdk/messaging/azeventhubs/internal/errors.go index 7a3566a346ef..1af71a84f9a3 100644 --- a/sdk/messaging/azeventhubs/internal/errors.go +++ b/sdk/messaging/azeventhubs/internal/errors.go @@ -67,7 +67,7 @@ func TransformError(err error) error { // there are a few errors that all boil down to "bad creds or unauthorized" var amqpErr *amqp.Error - if errors.As(err, &amqpErr) && amqpErr.Condition == amqp.ErrorUnauthorizedAccess { + if errors.As(err, &amqpErr) && amqpErr.Condition == amqp.ErrCondUnauthorizedAccess { return exported.NewError(exported.ErrorCodeUnauthorizedAccess, err) } @@ -97,7 +97,7 @@ func IsQuickRecoveryError(err error) bool { return false } - var de *amqp.DetachError + var de *amqp.LinkError return errors.As(err, &de) } @@ -122,30 +122,30 @@ func IsDrainingError(err error) bool { return strings.Contains(err.Error(), "link is currently draining") } -const errorConditionLockLost = amqp.ErrorCondition("com.microsoft:message-lock-lost") +const errorConditionLockLost = amqp.ErrCond("com.microsoft:message-lock-lost") -var amqpConditionsToRecoveryKind = map[amqp.ErrorCondition]RecoveryKind{ +var amqpConditionsToRecoveryKind = map[amqp.ErrCond]RecoveryKind{ // no recovery needed, these are temporary errors. - amqp.ErrorCondition("com.microsoft:server-busy"): RecoveryKindNone, - amqp.ErrorCondition("com.microsoft:timeout"): RecoveryKindNone, - amqp.ErrorCondition("com.microsoft:operation-cancelled"): RecoveryKindNone, + amqp.ErrCond("com.microsoft:server-busy"): RecoveryKindNone, + amqp.ErrCond("com.microsoft:timeout"): RecoveryKindNone, + amqp.ErrCond("com.microsoft:operation-cancelled"): RecoveryKindNone, // Link recovery needed - amqp.ErrorDetachForced: RecoveryKindLink, // "amqp:link:detach-forced" - amqp.ErrorTransferLimitExceeded: RecoveryKindLink, // "amqp:link:transfer-limit-exceeded" + amqp.ErrCondDetachForced: RecoveryKindLink, // "amqp:link:detach-forced" + amqp.ErrCondTransferLimitExceeded: RecoveryKindLink, // "amqp:link:transfer-limit-exceeded" // Connection recovery needed - amqp.ErrorConnectionForced: RecoveryKindConn, // "amqp:connection:forced" - amqp.ErrorInternalError: RecoveryKindConn, // "amqp:internal-error" + amqp.ErrCondConnectionForced: RecoveryKindConn, // "amqp:connection:forced" + amqp.ErrCondInternalError: RecoveryKindConn, // "amqp:internal-error" // No recovery possible - this operation is non retriable. - amqp.ErrorMessageSizeExceeded: RecoveryKindFatal, // "amqp:link:message-size-exceeded" - amqp.ErrorUnauthorizedAccess: RecoveryKindFatal, // creds are bad - amqp.ErrorNotFound: RecoveryKindFatal, // "amqp:not-found" - amqp.ErrorNotAllowed: RecoveryKindFatal, // "amqp:not-allowed" - amqp.ErrorCondition("com.microsoft:entity-disabled"): RecoveryKindFatal, // entity is disabled in the portal - amqp.ErrorCondition("com.microsoft:session-cannot-be-locked"): RecoveryKindFatal, - errorConditionLockLost: RecoveryKindFatal, + amqp.ErrCondMessageSizeExceeded: RecoveryKindFatal, // "amqp:link:message-size-exceeded" + amqp.ErrCondUnauthorizedAccess: RecoveryKindFatal, // creds are bad + amqp.ErrCondNotFound: RecoveryKindFatal, // "amqp:not-found" + amqp.ErrCondNotAllowed: RecoveryKindFatal, // "amqp:not-allowed" + amqp.ErrCond("com.microsoft:entity-disabled"): RecoveryKindFatal, // entity is disabled in the portal + amqp.ErrCond("com.microsoft:session-cannot-be-locked"): RecoveryKindFatal, + errorConditionLockLost: RecoveryKindFatal, } // GetRecoveryKind determines the recovery type for non-session based links. @@ -192,15 +192,16 @@ func GetRecoveryKind(err error) RecoveryKind { } // check the "special" AMQP errors that aren't condition-based. - if errors.Is(err, amqp.ErrLinkClosed) || IsQuickRecoveryError(err) { + if IsQuickRecoveryError(err) { return RecoveryKindLink } - var connErr *amqp.ConnectionError + var connErr *amqp.ConnError + var sessionErr *amqp.SessionError if errors.As(err, &connErr) || // session closures appear to leak through when the connection itself is going down. - errors.Is(err, amqp.ErrSessionClosed) { + errors.As(err, &sessionErr) { return RecoveryKindConn } @@ -333,7 +334,7 @@ func IsNotAllowedError(err error) bool { var e *amqp.Error return errors.As(err, &e) && - e.Condition == amqp.ErrorNotAllowed + e.Condition == amqp.ErrCondNotAllowed } func (e ErrConnectionClosed) Error() string { @@ -341,10 +342,10 @@ func (e ErrConnectionClosed) Error() string { } func IsOwnershipLostError(err error) bool { - var de *amqp.DetachError + var de *amqp.LinkError if errors.As(err, &de) { - return de.RemoteError != nil && de.RemoteError.Condition == "amqp:link:stolen" + return de.RemoteErr != nil && de.RemoteErr.Condition == "amqp:link:stolen" } return false diff --git a/sdk/messaging/azeventhubs/internal/errors_test.go b/sdk/messaging/azeventhubs/internal/errors_test.go index 8c1a35290231..0152186cee38 100644 --- a/sdk/messaging/azeventhubs/internal/errors_test.go +++ b/sdk/messaging/azeventhubs/internal/errors_test.go @@ -17,9 +17,9 @@ import ( ) func TestOwnershipLost(t *testing.T) { - detachErr := &amqp.DetachError{ - RemoteError: &amqp.Error{ - Condition: amqp.ErrorCondition("amqp:link:stolen"), + detachErr := &amqp.LinkError{ + RemoteErr: &amqp.Error{ + Condition: amqp.ErrCond("amqp:link:stolen"), }, } @@ -32,15 +32,15 @@ func TestOwnershipLost(t *testing.T) { require.ErrorAs(t, transformedErr, &err) require.Equal(t, exported.ErrorCodeOwnershipLost, err.Code) - require.False(t, IsOwnershipLostError(&amqp.DetachError{})) - require.False(t, IsOwnershipLostError(&amqp.ConnectionError{})) + require.False(t, IsOwnershipLostError(&amqp.LinkError{})) + require.False(t, IsOwnershipLostError(&amqp.ConnError{})) require.False(t, IsOwnershipLostError(errors.New("definitely not an ownership lost error"))) } func TestGetRecoveryKind(t *testing.T) { require.Equal(t, GetRecoveryKind(nil), RecoveryKindNone) require.Equal(t, GetRecoveryKind(errConnResetNeeded), RecoveryKindConn) - require.Equal(t, GetRecoveryKind(&amqp.DetachError{}), RecoveryKindLink) + require.Equal(t, GetRecoveryKind(&amqp.LinkError{}), RecoveryKindLink) require.Equal(t, GetRecoveryKind(context.Canceled), RecoveryKindFatal) require.Equal(t, GetRecoveryKind(RPCError{Resp: &amqpwrap.RPCResponse{Code: http.StatusUnauthorized}}), RecoveryKindFatal) require.Equal(t, GetRecoveryKind(RPCError{Resp: &amqpwrap.RPCResponse{Code: http.StatusNotFound}}), RecoveryKindFatal) @@ -49,9 +49,9 @@ func TestGetRecoveryKind(t *testing.T) { func Test_TransformError(t *testing.T) { var asExportedErr *exported.Error - err := TransformError(&amqp.DetachError{ - RemoteError: &amqp.Error{ - Condition: amqp.ErrorCondition("amqp:link:stolen"), + err := TransformError(&amqp.LinkError{ + RemoteErr: &amqp.Error{ + Condition: amqp.ErrCond("amqp:link:stolen"), }, }) require.ErrorAs(t, err, &asExportedErr) @@ -61,7 +61,7 @@ func Test_TransformError(t *testing.T) { require.ErrorAs(t, err, &asExportedErr) require.Equal(t, exported.ErrorCodeUnauthorizedAccess, asExportedErr.Code) - err = TransformError(&amqp.Error{Condition: amqp.ErrorUnauthorizedAccess}) + err = TransformError(&amqp.Error{Condition: amqp.ErrCondUnauthorizedAccess}) require.ErrorAs(t, err, &asExportedErr) require.Equal(t, exported.ErrorCodeUnauthorizedAccess, asExportedErr.Code) @@ -75,14 +75,14 @@ func Test_TransformError(t *testing.T) { require.False(t, errors.As(err, &asExportedErr)) // sanity check, an RPCError but it's not a azservicebus.Code type error. - err = TransformError(&amqp.Error{Condition: amqp.ErrorNotFound}) + err = TransformError(&amqp.Error{Condition: amqp.ErrCondNotFound}) require.False(t, errors.As(err, &asExportedErr)) - err = TransformError(amqp.ErrLinkClosed) + err = TransformError(&amqp.LinkError{}) require.ErrorAs(t, err, &asExportedErr) require.Equal(t, exported.ErrorCodeConnectionLost, asExportedErr.Code) - err = TransformError(&amqp.ConnectionError{}) + err = TransformError(&amqp.ConnError{}) require.ErrorAs(t, err, &asExportedErr) require.Equal(t, exported.ErrorCodeConnectionLost, asExportedErr.Code) diff --git a/sdk/messaging/azeventhubs/internal/go-amqp/client.go b/sdk/messaging/azeventhubs/internal/go-amqp/client.go deleted file mode 100644 index 70ec26f63c01..000000000000 --- a/sdk/messaging/azeventhubs/internal/go-amqp/client.go +++ /dev/null @@ -1,183 +0,0 @@ -// Copyright (C) 2017 Kale Blankenship -// Portions Copyright (c) Microsoft Corporation -package amqp - -import ( - "context" - "encoding/base64" - "fmt" - "math/rand" - "net" - "sync" - "time" - - "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/go-amqp/internal/encoding" - "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/go-amqp/internal/frames" - "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/go-amqp/internal/log" -) - -// Client is an AMQP client connection. -type Client struct { - conn *conn -} - -// Dial connects to an AMQP server. -// -// If the addr includes a scheme, it must be "amqp", "amqps", or "amqp+ssl". -// If no port is provided, 5672 will be used for "amqp" and 5671 for "amqps" or "amqp+ssl". -// -// If username and password information is not empty it's used as SASL PLAIN -// credentials, equal to passing ConnSASLPlain option. -// -// opts: pass nil to accept the default values. -func Dial(addr string, opts *ConnOptions) (*Client, error) { - c, err := dialConn(addr, opts) - if err != nil { - return nil, err - } - err = c.Start() - if err != nil { - return nil, err - } - return &Client{conn: c}, nil -} - -// New establishes an AMQP client connection over conn. -// opts: pass nil to accept the default values. -func New(conn net.Conn, opts *ConnOptions) (*Client, error) { - c, err := newConn(conn, opts) - if err != nil { - return nil, err - } - err = c.Start() - if err != nil { - return nil, err - } - return &Client{conn: c}, nil -} - -// Close disconnects the connection. -func (c *Client) Close() error { - return c.conn.Close() -} - -// NewSession opens a new AMQP session to the server. -// Returns ErrConnClosed if the underlying connection has been closed. -// opts: pass nil to accept the default values. -func (c *Client) NewSession(ctx context.Context, opts *SessionOptions) (*Session, error) { - // get a session allocated by Client.mux - var sResp newSessionResp - select { - case <-ctx.Done(): - return nil, ctx.Err() - case <-c.conn.Done: - return nil, c.conn.Err() - case sResp = <-c.conn.NewSession: - } - - if sResp.err != nil { - return nil, sResp.err - } - s := sResp.session - s.init(opts) - - // send Begin to server - begin := &frames.PerformBegin{ - NextOutgoingID: 0, - IncomingWindow: s.incomingWindow, - OutgoingWindow: s.outgoingWindow, - HandleMax: s.handleMax, - } - log.Debug(1, "TX (NewSession): %s", begin) - _ = s.txFrame(begin, nil) - - // wait for response - var fr frames.Frame - select { - case <-ctx.Done(): - // TODO: this will leak s - return nil, ctx.Err() - case <-c.conn.Done: - return nil, c.conn.Err() - case fr = <-s.rx: - } - log.Debug(1, "RX (NewSession): %s", fr.Body) - - begin, ok := fr.Body.(*frames.PerformBegin) - if !ok { - // this codepath is hard to hit (impossible?). if the response isn't a PerformBegin and we've not - // yet seen the remote channel number, the default clause in conn.mux will protect us from that. - // if we have seen the remote channel number then it's likely the session.mux for that channel will - // either swallow the frame or blow up in some other way, both causing this call to hang. - // deallocate session on error. we can't call - // s.Close() as the session mux hasn't started yet. - select { - case <-ctx.Done(): - // TODO: this will leak s - return nil, ctx.Err() - case c.conn.DelSession <- s: - } - return nil, fmt.Errorf("unexpected begin response: %+v", fr.Body) - } - - // start Session multiplexor - go s.mux(begin) - - return s, nil -} - -// SessionOption contains the optional settings for configuring an AMQP session. -type SessionOptions struct { - // IncomingWindow sets the maximum number of unacknowledged - // transfer frames the server can send. - IncomingWindow uint32 - - // OutgoingWindow sets the maximum number of unacknowledged - // transfer frames the client can send. - OutgoingWindow uint32 - - // MaxLinks sets the maximum number of links (Senders/Receivers) - // allowed on the session. - // - // Minimum: 1. - // Default: 4294967295. - MaxLinks uint32 -} - -// lockedRand provides a rand source that is safe for concurrent use. -type lockedRand struct { - mu sync.Mutex - src *rand.Rand -} - -func (r *lockedRand) Read(p []byte) (int, error) { - r.mu.Lock() - defer r.mu.Unlock() - return r.src.Read(p) -} - -// package scoped rand source to avoid any issues with seeding -// of the global source. -var pkgRand = &lockedRand{ - src: rand.New(rand.NewSource(time.Now().UnixNano())), -} - -// randBytes returns a base64 encoded string of n bytes. -func randString(n int) string { - b := make([]byte, n) - // from math/rand, cannot fail - _, _ = pkgRand.Read(b) - return base64.RawURLEncoding.EncodeToString(b) -} - -// linkKey uniquely identifies a link on a connection by name and direction. -// -// A link can be identified uniquely by the ordered tuple (source-container-id, target-container-id, name). -// On a single connection the container ID pairs can be abbreviated -// to a boolean flag indicating the direction of the link. -type linkKey struct { - name string - role encoding.Role // Local role: sender/receiver -} - -const maxTransferFrameHeader = 66 // determined by calcMaxTransferFrameHeader diff --git a/sdk/messaging/azeventhubs/internal/go-amqp/conn.go b/sdk/messaging/azeventhubs/internal/go-amqp/conn.go index 6f5f3a86ead8..5f6da19590c8 100644 --- a/sdk/messaging/azeventhubs/internal/go-amqp/conn.go +++ b/sdk/messaging/azeventhubs/internal/go-amqp/conn.go @@ -1,11 +1,12 @@ // Copyright (C) 2017 Kale Blankenship // Portions Copyright (c) Microsoft Corporation + package amqp import ( "bytes" + "context" "crypto/tls" - "encoding/binary" "errors" "fmt" "math" @@ -16,9 +17,10 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/go-amqp/internal/bitmap" "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/go-amqp/internal/buffer" + "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/go-amqp/internal/debug" "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/go-amqp/internal/encoding" "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/go-amqp/internal/frames" - "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/go-amqp/internal/log" + "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/go-amqp/internal/shared" ) // Default connection options @@ -39,12 +41,12 @@ type ConnOptions struct { // Open frame and TLS ServerName (if not otherwise set). HostName string - // IdleTimeout specifies the maximum period in milliseconds between + // IdleTimeout specifies the maximum period between // receiving frames from the peer. // // Specify a value less than zero to disable idle timeout. // - // Default: 1 minute. + // Default: 1 minute (60000000000). IdleTimeout time.Duration // MaxFrameSize sets the maximum frame size that @@ -62,20 +64,11 @@ type ConnOptions struct { MaxSessions uint16 // Properties sets an entry in the connection properties map sent to the server. - Properties map[string]interface{} + Properties map[string]any // SASLType contains the specified SASL authentication mechanism. SASLType SASLType - // Timeout configures how long to wait for the - // server during connection establishment. - // - // Once the connection has been established, IdleTimeout - // applies. If duration is zero, no timeout will be applied. - // - // Default: 0. - Timeout time.Duration - // TLSConfig sets the tls.Config to be used during // TLS negotiation. // @@ -87,19 +80,47 @@ type ConnOptions struct { dialer dialer } -// used to abstract the underlying dialer for testing purposes -type dialer interface { - NetDialerDial(c *conn, host, port string) error - TLSDialWithDialer(c *conn, host, port string) error +// Dial connects to an AMQP server. +// +// If the addr includes a scheme, it must be "amqp", "amqps", or "amqp+ssl". +// If no port is provided, 5672 will be used for "amqp" and 5671 for "amqps" or "amqp+ssl". +// +// If username and password information is not empty it's used as SASL PLAIN +// credentials, equal to passing ConnSASLPlain option. +// +// opts: pass nil to accept the default values. +func Dial(ctx context.Context, addr string, opts *ConnOptions) (*Conn, error) { + deadline, _ := ctx.Deadline() + c, err := dialConn(deadline, addr, opts) + if err != nil { + return nil, err + } + err = c.start(deadline) + if err != nil { + return nil, err + } + return c, nil +} + +// NewConn establishes a new AMQP client connection over conn. +// opts: pass nil to accept the default values. +func NewConn(ctx context.Context, conn net.Conn, opts *ConnOptions) (*Conn, error) { + c, err := newConn(conn, opts) + if err != nil { + return nil, err + } + deadline, _ := ctx.Deadline() + err = c.start(deadline) + if err != nil { + return nil, err + } + return c, nil } -// conn is an AMQP connection. -// only exported fields and methods are part of public surface area, -// all others are considered to be internal implementation details. -type conn struct { - net net.Conn // underlying connection - connectTimeout time.Duration // time to wait for reads/writes during conn establishment - dialer dialer // used for testing purposes, it allows faking dialing TCP/TLS endpoints +// Conn is an AMQP connection. +type Conn struct { + net net.Conn // underlying connection + dialer dialer // used for testing purposes, it allows faking dialing TCP/TLS endpoints // TLS tlsNegotiation bool // negotiate TLS @@ -111,62 +132,64 @@ type conn struct { saslComplete bool // SASL negotiation complete; internal *except* for SASL auth methods // local settings - maxFrameSize uint32 // max frame size to accept - channelMax uint16 // maximum number of channels to allow - hostname string // hostname of remote server (set explicitly or parsed from URL) - idleTimeout time.Duration // maximum period between receiving frames - properties map[encoding.Symbol]interface{} // additional properties sent upon connection open - containerID string // set explicitly or randomly generated + maxFrameSize uint32 // max frame size to accept + channelMax uint16 // maximum number of channels to allow + hostname string // hostname of remote server (set explicitly or parsed from URL) + idleTimeout time.Duration // maximum period between receiving frames + properties map[encoding.Symbol]any // additional properties sent upon connection open + containerID string // set explicitly or randomly generated // peer settings peerIdleTimeout time.Duration // maximum period between sending frames - PeerMaxFrameSize uint32 // maximum frame size peer will accept + peerMaxFrameSize uint32 // maximum frame size peer will accept // conn state - errMu sync.Mutex // mux holds errMu from start until shutdown completes; operations are sequential before mux is started - err error // error to be returned to client - Done chan struct{} // indicates the connection is done + done chan struct{} // indicates the connection has terminated + doneErr error // contains the error state returned from Close(); DO NOT TOUCH outside of conn.go until done has been closed! + + // connReader and connWriter management + rxtxExit chan struct{} // signals connReader and connWriter to exit + closeOnce sync.Once // ensures that close() is only called once - // mux - NewSession chan newSessionResp // new Sessions are requested from mux by reading off this channel - DelSession chan *Session // session completion is indicated to mux by sending the Session on this channel - connErr chan error // connReader/Writer notifications of an error - closeMux chan struct{} // indicates that the mux should stop - closeMuxOnce sync.Once + // session tracking + channels *bitmap.Bitmap + sessionsByChannel map[uint16]*Session + sessionsByChannelMu sync.RWMutex // connReader - rxProto chan protoHeader // protoHeaders received by connReader - rxFrame chan frames.Frame // AMQP frames received by connReader - rxDone chan struct{} - connReaderRun chan func() // functions to be run by conn reader (set deadline on conn to run) + rxBuf buffer.Buffer // incoming bytes buffer + rxDone chan struct{} // closed when connReader exits + rxErr error // contains last error reading from c.net; DO NOT TOUCH outside of connReader until rxDone has been closed! // connWriter txFrame chan frames.Frame // AMQP frames to be sent by connWriter txBuf buffer.Buffer // buffer for marshaling frames before transmitting - txDone chan struct{} + txDone chan struct{} // closed when connWriter exits + txErr error // contains last error writing to c.net; DO NOT TOUCH outside of connWriter until txDone has been closed! } -type newSessionResp struct { - session *Session - err error +// used to abstract the underlying dialer for testing purposes +type dialer interface { + NetDialerDial(deadline time.Time, c *Conn, host, port string) error + TLSDialWithDialer(deadline time.Time, c *Conn, host, port string) error } // implements the dialer interface type defaultDialer struct{} -func (defaultDialer) NetDialerDial(c *conn, host, port string) (err error) { - dialer := &net.Dialer{Timeout: c.connectTimeout} +func (defaultDialer) NetDialerDial(deadline time.Time, c *Conn, host, port string) (err error) { + dialer := &net.Dialer{Deadline: deadline} c.net, err = dialer.Dial("tcp", net.JoinHostPort(host, port)) return } -func (defaultDialer) TLSDialWithDialer(c *conn, host, port string) (err error) { - dialer := &net.Dialer{Timeout: c.connectTimeout} +func (defaultDialer) TLSDialWithDialer(deadline time.Time, c *Conn, host, port string) (err error) { + dialer := &net.Dialer{Deadline: deadline} c.net, err = tls.DialWithDialer(dialer, "tcp", net.JoinHostPort(host, port), c.tlsConfig) return } -func dialConn(addr string, opts *ConnOptions) (*conn, error) { +func dialConn(deadline time.Time, addr string, opts *ConnOptions) (*Conn, error) { u, err := url.Parse(addr) if err != nil { return nil, err @@ -201,11 +224,11 @@ func dialConn(addr string, opts *ConnOptions) (*conn, error) { switch u.Scheme { case "amqp", "": - err = c.dialer.NetDialerDial(c, host, port) + err = c.dialer.NetDialerDial(deadline, c, host, port) case "amqps", "amqp+ssl": c.initTLSConfig() c.tlsNegotiation = false - err = c.dialer.TLSDialWithDialer(c, host, port) + err = c.dialer.TLSDialWithDialer(deadline, c, host, port) default: err = fmt.Errorf("unsupported scheme %q", u.Scheme) } @@ -216,26 +239,21 @@ func dialConn(addr string, opts *ConnOptions) (*conn, error) { return c, nil } -func newConn(netConn net.Conn, opts *ConnOptions) (*conn, error) { - c := &conn{ - dialer: defaultDialer{}, - net: netConn, - maxFrameSize: defaultMaxFrameSize, - PeerMaxFrameSize: defaultMaxFrameSize, - channelMax: defaultMaxSessions - 1, // -1 because channel-max starts at zero - idleTimeout: defaultIdleTimeout, - containerID: randString(40), - Done: make(chan struct{}), - connErr: make(chan error, 2), // buffered to ensure connReader/Writer won't leak - closeMux: make(chan struct{}), - rxProto: make(chan protoHeader), - rxFrame: make(chan frames.Frame), - rxDone: make(chan struct{}), - connReaderRun: make(chan func(), 1), // buffered to allow queueing function before interrupt - NewSession: make(chan newSessionResp), - DelSession: make(chan *Session), - txFrame: make(chan frames.Frame), - txDone: make(chan struct{}), +func newConn(netConn net.Conn, opts *ConnOptions) (*Conn, error) { + c := &Conn{ + dialer: defaultDialer{}, + net: netConn, + maxFrameSize: defaultMaxFrameSize, + peerMaxFrameSize: defaultMaxFrameSize, + channelMax: defaultMaxSessions - 1, // -1 because channel-max starts at zero + idleTimeout: defaultIdleTimeout, + containerID: shared.RandString(40), + done: make(chan struct{}), + rxtxExit: make(chan struct{}), + rxDone: make(chan struct{}), + txFrame: make(chan frames.Frame), + txDone: make(chan struct{}), + sessionsByChannel: map[uint16]*Session{}, } // apply options @@ -267,11 +285,8 @@ func newConn(netConn net.Conn, opts *ConnOptions) (*conn, error) { return nil, err } } - if opts.Timeout > 0 { - c.connectTimeout = opts.Timeout - } if opts.Properties != nil { - c.properties = make(map[encoding.Symbol]interface{}) + c.properties = make(map[encoding.Symbol]any) for key, val := range opts.Properties { c.properties[encoding.Symbol(key)] = val } @@ -282,11 +297,10 @@ func newConn(netConn net.Conn, opts *ConnOptions) (*conn, error) { if opts.dialer != nil { c.dialer = opts.dialer } - return c, nil } -func (c *conn) initTLSConfig() { +func (c *Conn) initTLSConfig() { // create a new config if not already set if c.tlsConfig == nil { c.tlsConfig = new(tls.Config) @@ -298,11 +312,11 @@ func (c *conn) initTLSConfig() { } } -// Start establishes the connection and begins multiplexing network IO. +// start establishes the connection and begins multiplexing network IO. // It is an error to call Start() on a connection that's been closed. -func (c *conn) Start() error { - // start reader - go c.connReader() +func (c *Conn) start(deadline time.Time) error { + // set connection establishment deadline + _ = c.net.SetDeadline(deadline) // run connection establishment state machine for state := c.negotiateProto; state != nil; { @@ -311,345 +325,297 @@ func (c *conn) Start() error { // check if err occurred if err != nil { close(c.txDone) // close here since connWriter hasn't been started yet + close(c.rxDone) _ = c.Close() return err } } - // start multiplexor and writer - go c.mux() + // remove connection establishment deadline + _ = c.net.SetDeadline(time.Time{}) + + // we can't create the channel bitmap until the connection has been established. + // this is because our peer can tell us the max channels they support. + c.channels = bitmap.New(uint32(c.channelMax)) + go c.connWriter() + go c.connReader() return nil } // Close closes the connection. -func (c *conn) Close() error { - c.closeMuxOnce.Do(func() { close(c.closeMux) }) - err := c.Err() - var connErr *ConnectionError - if errors.As(err, &connErr) && connErr.inner == nil { +func (c *Conn) Close() error { + c.close() + var connErr *ConnError + if errors.As(c.doneErr, &connErr) && connErr.RemoteErr == nil && connErr.inner == nil { // an empty ConnectionError means the connection was closed by the caller - // or as requested by the peer and no error was provided in the close frame. return nil } - return err + + // there was an error during shut-down or connReader/connWriter + // experienced a terminal error + return c.doneErr } -// close should only be called by conn.mux. -func (c *conn) close() { - close(c.Done) // notify goroutines and blocked functions to exit +// close is called once, either from Close() or when connReader/connWriter exits +func (c *Conn) close() { + c.closeOnce.Do(func() { + defer close(c.done) - // wait for writing to stop, allows it to send the final close frame - <-c.txDone + close(c.rxtxExit) - // reading from connErr in mux can race with closeMux, causing - // a pending conn read/write error to be lost. now that the - // mux has exited, drain any pending error. - select { - case err := <-c.connErr: - c.err = err - default: - // no pending read/write error - } + // wait for writing to stop, allows it to send the final close frame + <-c.txDone - err := c.net.Close() - switch { - // conn.err already set - // TODO: err info is lost, log it? - case c.err != nil: + closeErr := c.net.Close() - // conn.err not set and c.net.Close() returned a non-nil error - case err != nil: - c.err = err + // check rxDone after closing net, otherwise may block + // for up to c.idleTimeout + <-c.rxDone - // no errors - default: - } + if errors.Is(c.rxErr, net.ErrClosed) { + // this is the expected error when the connection is closed, swallow it + c.rxErr = nil + } - // check rxDone after closing net, otherwise may block - // for up to c.idleTimeout - <-c.rxDone + if c.txErr == nil && c.rxErr == nil && closeErr == nil { + // if there are no errors, it means user initiated close() and we shut down cleanly + c.doneErr = &ConnError{} + } else if amqpErr, ok := c.rxErr.(*Error); ok { + // we experienced a peer-initiated close that contained an Error. return it + c.doneErr = &ConnError{RemoteErr: amqpErr} + } else if c.txErr != nil { + c.doneErr = &ConnError{inner: c.txErr} + } else if c.rxErr != nil { + c.doneErr = &ConnError{inner: c.rxErr} + } else { + c.doneErr = &ConnError{inner: closeErr} + } + }) } -// Err returns the connection's error state after it's been closed. -// Calling this on an open connection will block until the connection is closed. -func (c *conn) Err() error { - c.errMu.Lock() - defer c.errMu.Unlock() - return &ConnectionError{inner: c.err} +// NewSession starts a new session on the connection. +// - ctx controls waiting for the peer to acknowledge the session +// - opts contains optional values, pass nil to accept the defaults +// +// If the context's deadline expires or is cancelled before the operation +// completes, the application can be left in an unknown state, potentially +// resulting in connection errors. +func (c *Conn) NewSession(ctx context.Context, opts *SessionOptions) (*Session, error) { + session, err := c.newSession(opts) + if err != nil { + return nil, err + } + + if err := session.begin(ctx); err != nil { + c.deleteSession(session) + return nil, err + } + + return session, nil } -// mux is started in it's own goroutine after initial connection establishment. -// It handles muxing of sessions, keepalives, and connection errors. -func (c *conn) mux() { - var ( - // allocated channels - channels = bitmap.New(uint32(c.channelMax)) +func (c *Conn) newSession(opts *SessionOptions) (*Session, error) { + c.sessionsByChannelMu.Lock() + defer c.sessionsByChannelMu.Unlock() - // create the next session to allocate - // note that channel always start at 0, and 0 is special and can't be deleted - nextChannel, _ = channels.Next() - nextSession = newSessionResp{session: newSession(c, uint16(nextChannel))} + // create the next session to allocate + // note that channel always start at 0 + channel, ok := c.channels.Next() + if !ok { + return nil, fmt.Errorf("reached connection channel max (%d)", c.channelMax) + } + session := newSession(c, uint16(channel), opts) + c.sessionsByChannel[session.channel] = session - // map channels to sessions - sessionsByChannel = make(map[uint16]*Session) - sessionsByRemoteChannel = make(map[uint16]*Session) - ) + return session, nil +} + +func (c *Conn) deleteSession(s *Session) { + c.sessionsByChannelMu.Lock() + defer c.sessionsByChannelMu.Unlock() + + delete(c.sessionsByChannel, s.channel) + c.channels.Remove(uint32(s.channel)) +} - // hold the errMu lock until error or done - c.errMu.Lock() - defer c.errMu.Unlock() - defer c.close() // defer order is important. c.errMu unlock indicates that connection is finally complete +// connReader reads from the net.Conn, decodes frames, and either handles +// them here as appropriate or sends them to the session.rx channel. +func (c *Conn) connReader() { + defer func() { + close(c.rxDone) + c.close() + }() + var sessionsByRemoteChannel = make(map[uint16]*Session) + var err error for { - // check if last loop returned an error - if c.err != nil { + if err != nil { + debug.Log(1, "RX (connReader): terminal error: %v", err) + c.rxErr = err return } - select { - // error from connReader - case c.err = <-c.connErr: - - // new frame from connReader - case fr := <-c.rxFrame: - var ( - session *Session - ok bool - ) - - switch body := fr.Body.(type) { - // Server initiated close. - case *frames.PerformClose: - if body.Error != nil { - c.err = body.Error - } - return + var fr frames.Frame + fr, err = c.readFrame() + if err != nil { + continue + } - // RemoteChannel should be used when frame is Begin - case *frames.PerformBegin: - if body.RemoteChannel == nil { - // since we only support remotely-initiated sessions, this is an error - // TODO: it would be ideal to not have this kill the connection - c.err = fmt.Errorf("%T: nil RemoteChannel", fr.Body) - break - } - session, ok = sessionsByChannel[*body.RemoteChannel] - if !ok { - c.err = fmt.Errorf("unexpected remote channel number %d, expected %d", *body.RemoteChannel, nextChannel) - break - } - - session.remoteChannel = fr.Channel - sessionsByRemoteChannel[fr.Channel] = session - - case *frames.PerformEnd: - session, ok = sessionsByRemoteChannel[fr.Channel] - if !ok { - c.err = fmt.Errorf("%T: didn't find channel %d in sessionsByRemoteChannel (PerformEnd)", fr.Body, fr.Channel) - break - } - // we MUST remove the remote channel from our map as soon as we receive - // the ack (i.e. before passing it on to the session mux) on the session - // ending since the numbers are recycled. - delete(sessionsByRemoteChannel, fr.Channel) - - default: - // pass on performative to the correct session - session, ok = sessionsByRemoteChannel[fr.Channel] - if !ok { - c.err = fmt.Errorf("%T: didn't find channel %d in sessionsByRemoteChannel", fr.Body, fr.Channel) - } - } + debug.Log(1, "RX (connReader): %s", fr) - if !ok { - continue - } + var ( + session *Session + ok bool + ) - select { - case session.rx <- fr: - case <-c.closeMux: + switch body := fr.Body.(type) { + // Server initiated close. + case *frames.PerformClose: + // connWriter will send the close performative ack on its way out. + // it's a SHOULD though, not a MUST. + if body.Error == nil { return } + err = body.Error + continue - // new session request - // - // Continually try to send the next session on the channel, - // then add it to the sessions map. This allows us to control ID - // allocation and prevents the need to have shared map. Since new - // sessions are far less frequent than frames being sent to sessions, - // this avoids the lock/unlock for session lookup. - case c.NewSession <- nextSession: - if nextSession.err != nil { + // RemoteChannel should be used when frame is Begin + case *frames.PerformBegin: + if body.RemoteChannel == nil { + // since we only support remotely-initiated sessions, this is an error + // TODO: it would be ideal to not have this kill the connection + err = fmt.Errorf("%T: nil RemoteChannel", fr.Body) continue } - - // save session into map - ch := nextSession.session.channel - sessionsByChannel[ch] = nextSession.session - - // get next available channel - next, ok := channels.Next() + c.sessionsByChannelMu.RLock() + session, ok = c.sessionsByChannel[*body.RemoteChannel] + c.sessionsByChannelMu.RUnlock() if !ok { - nextSession = newSessionResp{err: fmt.Errorf("reached connection channel max (%d)", c.channelMax)} + // this can happen if NewSession() exits due to the context expiring/cancelled + // before the begin ack is received. + err = fmt.Errorf("unexpected remote channel number %d", *body.RemoteChannel) continue } - // create the next session to send - nextSession = newSessionResp{session: newSession(c, uint16(next))} + session.remoteChannel = fr.Channel + sessionsByRemoteChannel[fr.Channel] = session - // session deletion - case s := <-c.DelSession: - delete(sessionsByChannel, s.channel) - channels.Remove(uint32(s.channel)) - - // connection is complete - case <-c.closeMux: - return + case *frames.PerformEnd: + session, ok = sessionsByRemoteChannel[fr.Channel] + if !ok { + err = fmt.Errorf("%T: didn't find channel %d in sessionsByRemoteChannel (PerformEnd)", fr.Body, fr.Channel) + continue + } + // we MUST remove the remote channel from our map as soon as we receive + // the ack (i.e. before passing it on to the session mux) on the session + // ending since the numbers are recycled. + delete(sessionsByRemoteChannel, fr.Channel) + + default: + // pass on performative to the correct session + session, ok = sessionsByRemoteChannel[fr.Channel] + if !ok { + err = fmt.Errorf("%T: didn't find channel %d in sessionsByRemoteChannel", fr.Body, fr.Channel) + continue + } } + + q := session.rxQ.Acquire() + q.Enqueue(fr.Body) + session.rxQ.Release(q) + debug.Log(2, "RX (connReader): mux frame to session: %s", fr) } } -// connReader reads from the net.Conn, decodes frames, and passes them -// up via the conn.rxFrame and conn.rxProto channels. -func (c *conn) connReader() { - defer close(c.rxDone) +// readFrame reads a complete frame from c.net. +// it assumes that any read deadline has already been applied. +// used externally by SASL only. +func (c *Conn) readFrame() (frames.Frame, error) { + switch { + // Cheaply reuse free buffer space when fully read. + case c.rxBuf.Len() == 0: + c.rxBuf.Reset() - buf := &buffer.Buffer{} + // Prevent excessive/unbounded growth by shifting data to beginning of buffer. + case int64(c.rxBuf.Size()) > int64(c.maxFrameSize): + c.rxBuf.Reclaim() + } var ( - negotiating = true // true during conn establishment, check for protoHeaders currentHeader frames.Header // keep track of the current header, for frames split across multiple TCP packets frameInProgress bool // true if in the middle of receiving data for currentHeader ) for { - switch { - // Cheaply reuse free buffer space when fully read. - case buf.Len() == 0: - buf.Reset() - - // Prevent excessive/unbounded growth by shifting data to beginning of buffer. - case int64(buf.Size()) > int64(c.maxFrameSize): - buf.Reclaim() - } - // need to read more if buf doesn't contain the complete frame // or there's not enough in buf to parse the header - if frameInProgress || buf.Len() < frames.HeaderSize { + if frameInProgress || c.rxBuf.Len() < frames.HeaderSize { + // we MUST reset the idle timeout before each read from net.Conn if c.idleTimeout > 0 { _ = c.net.SetReadDeadline(time.Now().Add(c.idleTimeout)) } - err := buf.ReadFromOnce(c.net) + err := c.rxBuf.ReadFromOnce(c.net) if err != nil { - log.Debug(1, "connReader error: %v", err) - select { - // check if error was due to close in progress - case <-c.Done: - return - - // if there is a pending connReaderRun function, execute it - case f := <-c.connReaderRun: - f() - continue - - // send error to mux and return - default: - c.connErr <- err - return - } + return frames.Frame{}, err } } // read more if buf doesn't contain enough to parse the header - if buf.Len() < frames.HeaderSize { - continue - } - - // during negotiation, check for proto frames - if negotiating && bytes.Equal(buf.Bytes()[:4], []byte{'A', 'M', 'Q', 'P'}) { - p, err := parseProtoHeader(buf) - if err != nil { - c.connErr <- err - return - } - - // negotiation is complete once an AMQP proto frame is received - if p.ProtoID == protoAMQP { - negotiating = false - } - - // send proto header - select { - case <-c.Done: - return - case c.rxProto <- p: - } - + if c.rxBuf.Len() < frames.HeaderSize { continue } // parse the header if a frame isn't in progress if !frameInProgress { var err error - currentHeader, err = frames.ParseHeader(buf) + currentHeader, err = frames.ParseHeader(&c.rxBuf) if err != nil { - c.connErr <- err - return + return frames.Frame{}, err } frameInProgress = true } // check size is reasonable if currentHeader.Size > math.MaxInt32 { // make max size configurable - c.connErr <- errors.New("payload too large") - return + return frames.Frame{}, errors.New("payload too large") } bodySize := int64(currentHeader.Size - frames.HeaderSize) - // the full frame has been received - if int64(buf.Len()) < bodySize { + // the full frame hasn't been received, keep reading + if int64(c.rxBuf.Len()) < bodySize { continue } frameInProgress = false // check if body is empty (keepalive) if bodySize == 0 { + debug.Log(3, "RX (connReader): received keep-alive frame") continue } // parse the frame - b, ok := buf.Next(bodySize) + b, ok := c.rxBuf.Next(bodySize) if !ok { - c.connErr <- fmt.Errorf("buffer EOF; requested bytes: %d, actual size: %d", bodySize, buf.Len()) - return + return frames.Frame{}, fmt.Errorf("buffer EOF; requested bytes: %d, actual size: %d", bodySize, c.rxBuf.Len()) } parsedBody, err := frames.ParseBody(buffer.New(b)) if err != nil { - c.connErr <- err - return + return frames.Frame{}, err } - // send to mux - select { - case <-c.Done: - return - case c.rxFrame <- frames.Frame{Channel: currentHeader.Channel, Body: parsedBody}: - } + return frames.Frame{Channel: currentHeader.Channel, Body: parsedBody}, nil } } -func (c *conn) connWriter() { - defer close(c.txDone) - - // disable write timeout - if c.connectTimeout != 0 { - c.connectTimeout = 0 - _ = c.net.SetWriteDeadline(time.Time{}) - } +func (c *Conn) connWriter() { + defer func() { + close(c.txDone) + c.close() + }() var ( // keepalives are sent at a rate of 1/2 idle timeout @@ -669,14 +635,15 @@ func (c *conn) connWriter() { var err error for { if err != nil { - log.Debug(1, "connWriter error: %v", err) - c.connErr <- err + debug.Log(1, "TX (connWriter): terminal error: %v", err) + c.txErr = err return } select { // frame write request case fr := <-c.txFrame: + debug.Log(1, "TX (connWriter): %s", fr) err = c.writeFrame(fr) if err == nil && fr.Done != nil { close(fr.Done) @@ -684,7 +651,7 @@ func (c *conn) connWriter() { // keepalive timer case <-keepalive: - log.Debug(3, "sending keep-alive frame") + debug.Log(3, "TX (connWriter): sending keep-alive frame") _, err = c.net.Write(keepaliveFrame) // It would be slightly more efficient in terms of network // resources to reset the timer each time a frame is sent. @@ -695,14 +662,17 @@ func (c *conn) connWriter() { // possibly drained, then reset.) // connection complete - case <-c.Done: - // send close - cls := &frames.PerformClose{} - log.Debug(1, "TX (connWriter): %s", cls) - _ = c.writeFrame(frames.Frame{ - Type: frameTypeAMQP, - Body: cls, - }) + case <-c.rxtxExit: + // send close performative. note that the spec says we + // SHOULD wait for the ack but we don't HAVE to, in order + // to be resilient to bad actors etc. so we just send + // the close performative and exit. + fr := frames.Frame{ + Type: frames.TypeAMQP, + Body: &frames.PerformClose{}, + } + debug.Log(1, "TX (connWriter): %s", fr) + c.txErr = c.writeFrame(fr) return } } @@ -710,35 +680,31 @@ func (c *conn) connWriter() { // writeFrame writes a frame to the network. // used externally by SASL only. -func (c *conn) writeFrame(fr frames.Frame) error { - if c.connectTimeout != 0 { - _ = c.net.SetWriteDeadline(time.Now().Add(c.connectTimeout)) - } - +func (c *Conn) writeFrame(fr frames.Frame) error { // writeFrame into txBuf c.txBuf.Reset() - err := writeFrame(&c.txBuf, fr) + err := frames.Write(&c.txBuf, fr) if err != nil { return err } // validate the frame isn't exceeding peer's max frame size requiredFrameSize := c.txBuf.Len() - if uint64(requiredFrameSize) > uint64(c.PeerMaxFrameSize) { - return fmt.Errorf("%T frame size %d larger than peer's max frame size %d", fr, requiredFrameSize, c.PeerMaxFrameSize) + if uint64(requiredFrameSize) > uint64(c.peerMaxFrameSize) { + return fmt.Errorf("%T frame size %d larger than peer's max frame size %d", fr, requiredFrameSize, c.peerMaxFrameSize) } // write to network - _, err = c.net.Write(c.txBuf.Bytes()) + n, err := c.net.Write(c.txBuf.Bytes()) + if l := c.txBuf.Len(); n > 0 && n < l && err != nil { + debug.Log(1, "TX (writeFrame): wrote %d bytes less than len %d: %v", n, l, err) + } return err } // writeProtoHeader writes an AMQP protocol header to the // network -func (c *conn) writeProtoHeader(pID protoID) error { - if c.connectTimeout != 0 { - _ = c.net.SetWriteDeadline(time.Now().Add(c.connectTimeout)) - } +func (c *Conn) writeProtoHeader(pID protoID) error { _, err := c.net.Write([]byte{'A', 'M', 'Q', 'P', byte(pID), 1, 0, 0}) return err } @@ -747,12 +713,13 @@ func (c *conn) writeProtoHeader(pID protoID) error { var keepaliveFrame = []byte{0x00, 0x00, 0x00, 0x08, 0x02, 0x00, 0x00, 0x00} // SendFrame is used by sessions and links to send frames across the network. -func (c *conn) SendFrame(fr frames.Frame) error { +func (c *Conn) sendFrame(fr frames.Frame) error { select { case c.txFrame <- fr: + debug.Log(2, "TX (Conn): mux frame to connWriter: %s", fr) return nil - case <-c.Done: - return c.Err() + case <-c.done: + return c.doneErr } } @@ -764,7 +731,7 @@ type stateFunc func() (stateFunc, error) // negotiateProto determines which proto to negotiate next. // used externally by SASL only. -func (c *conn) negotiateProto() (stateFunc, error) { +func (c *Conn) negotiateProto() (stateFunc, error) { // in the order each must be negotiated switch { case c.tlsNegotiation && !c.tlsComplete: @@ -787,7 +754,7 @@ const ( // exchangeProtoHeader performs the round trip exchange of protocol // headers, validation, and returns the protoID specific next state. -func (c *conn) exchangeProtoHeader(pID protoID) (stateFunc, error) { +func (c *Conn) exchangeProtoHeader(pID protoID) (stateFunc, error) { // write the proto header if err := c.writeProtoHeader(pID); err != nil { return nil, err @@ -817,62 +784,75 @@ func (c *conn) exchangeProtoHeader(pID protoID) (stateFunc, error) { } // readProtoHeader reads a protocol header packet from c.rxProto. -func (c *conn) readProtoHeader() (protoHeader, error) { - var deadline <-chan time.Time - if c.connectTimeout != 0 { - deadline = time.After(c.connectTimeout) +func (c *Conn) readProtoHeader() (protoHeader, error) { + const protoHeaderSize = 8 + + // only read from the network once our buffer has been exhausted. + // TODO: this preserves existing behavior as some tests rely on this + // implementation detail (it lets you replay a stream of bytes). we + // might want to consider removing this and fixing the tests as the + // protocol doesn't actually work this way. + if c.rxBuf.Len() == 0 { + for { + err := c.rxBuf.ReadFromOnce(c.net) + if err != nil { + return protoHeader{}, err + } + + // read more if buf doesn't contain enough to parse the header + if c.rxBuf.Len() >= protoHeaderSize { + break + } + } } - var p protoHeader - select { - case p = <-c.rxProto: - return p, nil - case err := <-c.connErr: - return p, err - case fr := <-c.rxFrame: - return p, fmt.Errorf("readProtoHeader: unexpected frame %#v", fr) - case <-deadline: - return p, errors.New("amqp: timeout waiting for response") + + buf, ok := c.rxBuf.Next(protoHeaderSize) + if !ok { + return protoHeader{}, errors.New("invalid protoHeader") } -} + // bounds check hint to compiler; see golang.org/issue/14808 + _ = buf[protoHeaderSize-1] -// startTLS wraps the conn with TLS and returns to Client.negotiateProto -func (c *conn) startTLS() (stateFunc, error) { - c.initTLSConfig() + if !bytes.Equal(buf[:4], []byte{'A', 'M', 'Q', 'P'}) { + return protoHeader{}, fmt.Errorf("unexpected protocol %q", buf[:4]) + } - // buffered so connReaderRun won't block - done := make(chan error, 1) + p := protoHeader{ + ProtoID: protoID(buf[4]), + Major: buf[5], + Minor: buf[6], + Revision: buf[7], + } - // this function will be executed by connReader - c.connReaderRun <- func() { - defer close(done) - _ = c.net.SetReadDeadline(time.Time{}) // clear timeout + if p.Major != 1 || p.Minor != 0 || p.Revision != 0 { + return protoHeader{}, fmt.Errorf("unexpected protocol version %d.%d.%d", p.Major, p.Minor, p.Revision) + } - // wrap existing net.Conn and perform TLS handshake - tlsConn := tls.Client(c.net, c.tlsConfig) - if c.connectTimeout != 0 { - _ = tlsConn.SetWriteDeadline(time.Now().Add(c.connectTimeout)) - } - done <- tlsConn.Handshake() - // TODO: return? + return p, nil +} - // swap net.Conn - c.net = tlsConn - c.tlsComplete = true - } +// startTLS wraps the conn with TLS and returns to Client.negotiateProto +func (c *Conn) startTLS() (stateFunc, error) { + c.initTLSConfig() - // set deadline to interrupt connReader - _ = c.net.SetReadDeadline(time.Time{}.Add(1)) + _ = c.net.SetReadDeadline(time.Time{}) // clear timeout - if err := <-done; err != nil { + // wrap existing net.Conn and perform TLS handshake + tlsConn := tls.Client(c.net, c.tlsConfig) + if err := tlsConn.Handshake(); err != nil { return nil, err } + // swap net.Conn + c.net = tlsConn + c.tlsComplete = true + // go to next protocol return c.negotiateProto, nil } // openAMQP round trips the AMQP open performative -func (c *conn) openAMQP() (stateFunc, error) { +func (c *Conn) openAMQP() (stateFunc, error) { // send open frame open := &frames.PerformOpen{ ContainerID: c.containerID, @@ -882,30 +862,31 @@ func (c *conn) openAMQP() (stateFunc, error) { IdleTimeout: c.idleTimeout / 2, // per spec, advertise half our idle timeout Properties: c.properties, } - log.Debug(1, "TX (openAMQP): %s", open) - err := c.writeFrame(frames.Frame{ - Type: frameTypeAMQP, + fr := frames.Frame{ + Type: frames.TypeAMQP, Body: open, Channel: 0, - }) + } + debug.Log(1, "TX (openAMQP): %s", fr) + err := c.writeFrame(fr) if err != nil { return nil, err } // get the response - fr, err := c.readFrame() + fr, err = c.readSingleFrame() if err != nil { return nil, err } + debug.Log(1, "RX (openAMQP): %s", fr) o, ok := fr.Body.(*frames.PerformOpen) if !ok { return nil, fmt.Errorf("openAMQP: unexpected frame type %T", fr.Body) } - log.Debug(1, "RX (openAMQP): %s", o) // update peer settings if o.MaxFrameSize > 0 { - c.PeerMaxFrameSize = o.MaxFrameSize + c.peerMaxFrameSize = o.MaxFrameSize } if o.IdleTimeout > 0 { // TODO: reject very small idle timeouts @@ -921,17 +902,17 @@ func (c *conn) openAMQP() (stateFunc, error) { // negotiateSASL returns the SASL handler for the first matched // mechanism specified by the server -func (c *conn) negotiateSASL() (stateFunc, error) { +func (c *Conn) negotiateSASL() (stateFunc, error) { // read mechanisms frame - fr, err := c.readFrame() + fr, err := c.readSingleFrame() if err != nil { return nil, err } + debug.Log(1, "RX (negotiateSASL): %s", fr) sm, ok := fr.Body.(*frames.SASLMechanisms) if !ok { return nil, fmt.Errorf("negotiateSASL: unexpected frame type %T", fr.Body) } - log.Debug(1, "RX (negotiateSASL): %s", sm) // return first match in c.saslHandlers based on order received for _, mech := range sm.Mechanisms { @@ -950,17 +931,17 @@ func (c *conn) negotiateSASL() (stateFunc, error) { // SASL handlers return this stateFunc when the mechanism specific negotiation // has completed. // used externally by SASL only. -func (c *conn) saslOutcome() (stateFunc, error) { +func (c *Conn) saslOutcome() (stateFunc, error) { // read outcome frame - fr, err := c.readFrame() + fr, err := c.readSingleFrame() if err != nil { return nil, err } + debug.Log(1, "RX (saslOutcome): %s", fr) so, ok := fr.Body.(*frames.SASLOutcome) if !ok { return nil, fmt.Errorf("saslOutcome: unexpected frame type %T", fr.Body) } - log.Debug(1, "RX (saslOutcome): %s", so) // check if auth succeeded if so.Code != encoding.CodeSASLOK { @@ -972,27 +953,16 @@ func (c *conn) saslOutcome() (stateFunc, error) { return c.negotiateProto, nil } -// readFrame is used during connection establishment to read a single frame. +// readSingleFrame is used during connection establishment to read a single frame. // -// After setup, conn.mux handles incoming frames. -// used externally by SASL only. -func (c *conn) readFrame() (frames.Frame, error) { - var deadline <-chan time.Time - if c.connectTimeout != 0 { - deadline = time.After(c.connectTimeout) +// After setup, conn.connReader handles incoming frames. +func (c *Conn) readSingleFrame() (frames.Frame, error) { + fr, err := c.readFrame() + if err != nil { + return frames.Frame{}, err } - var fr frames.Frame - select { - case fr = <-c.rxFrame: - return fr, nil - case err := <-c.connErr: - return fr, err - case p := <-c.rxProto: - return fr, fmt.Errorf("unexpected protocol header %#v", p) - case <-deadline: - return fr, errors.New("amqp: timeout waiting for response") - } + return fr, nil } type protoHeader struct { @@ -1001,61 +971,3 @@ type protoHeader struct { Minor uint8 Revision uint8 } - -// parseProtoHeader reads the proto header from r and returns the results -// -// An error is returned if the protocol is not "AMQP" or if the version is not 1.0.0. -func parseProtoHeader(r *buffer.Buffer) (protoHeader, error) { - const protoHeaderSize = 8 - buf, ok := r.Next(protoHeaderSize) - if !ok { - return protoHeader{}, errors.New("invalid protoHeader") - } - _ = buf[7] - - if !bytes.Equal(buf[:4], []byte{'A', 'M', 'Q', 'P'}) { - return protoHeader{}, fmt.Errorf("unexpected protocol %q", buf[:4]) - } - - p := protoHeader{ - ProtoID: protoID(buf[4]), - Major: buf[5], - Minor: buf[6], - Revision: buf[7], - } - - if p.Major != 1 || p.Minor != 0 || p.Revision != 0 { - return p, fmt.Errorf("unexpected protocol version %d.%d.%d", p.Major, p.Minor, p.Revision) - } - return p, nil -} - -// writesFrame encodes fr into buf. -// split out from conn.WriteFrame for testing purposes. -func writeFrame(buf *buffer.Buffer, fr frames.Frame) error { - // write header - buf.Append([]byte{ - 0, 0, 0, 0, // size, overwrite later - 2, // doff, see frameHeader.DataOffset comment - fr.Type, // frame type - }) - buf.AppendUint16(fr.Channel) // channel - - // write AMQP frame body - err := encoding.Marshal(buf, fr.Body) - if err != nil { - return err - } - - // validate size - if uint(buf.Len()) > math.MaxUint32 { - return errors.New("frame too large") - } - - // retrieve raw bytes - bufBytes := buf.Bytes() - - // write correct size - binary.BigEndian.PutUint32(bufBytes, uint32(len(bufBytes))) - return nil -} diff --git a/sdk/messaging/azeventhubs/internal/go-amqp/const.go b/sdk/messaging/azeventhubs/internal/go-amqp/const.go index d192eaeb0eda..fee0b5041525 100644 --- a/sdk/messaging/azeventhubs/internal/go-amqp/const.go +++ b/sdk/messaging/azeventhubs/internal/go-amqp/const.go @@ -1,5 +1,6 @@ // Copyright (C) 2017 Kale Blankenship // Portions Copyright (c) Microsoft Corporation + package amqp import "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/go-amqp/internal/encoding" @@ -7,13 +8,13 @@ import "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/go- // Sender Settlement Modes const ( // Sender will send all deliveries initially unsettled to the receiver. - ModeUnsettled = encoding.ModeUnsettled + SenderSettleModeUnsettled SenderSettleMode = encoding.SenderSettleModeUnsettled // Sender will send all deliveries settled to the receiver. - ModeSettled = encoding.ModeSettled + SenderSettleModeSettled SenderSettleMode = encoding.SenderSettleModeSettled // Sender MAY send a mixture of settled and unsettled deliveries to the receiver. - ModeMixed = encoding.ModeMixed + SenderSettleModeMixed SenderSettleMode = encoding.SenderSettleModeMixed ) // SenderSettleMode specifies how the sender will settle messages. @@ -21,20 +22,23 @@ type SenderSettleMode = encoding.SenderSettleMode func senderSettleModeValue(m *SenderSettleMode) SenderSettleMode { if m == nil { - return ModeMixed + return SenderSettleModeMixed } return *m } // Receiver Settlement Modes const ( - // Receiver will spontaneously settle all incoming transfers. - ModeFirst = encoding.ModeFirst - - // Receiver will only settle after sending the disposition to the - // sender and receiving a disposition indicating settlement of - // the delivery from the sender. - ModeSecond = encoding.ModeSecond + // Receiver is the first to consider the message as settled. + // Once the corresponding disposition frame is sent, the message + // is considered to be settled. + ReceiverSettleModeFirst ReceiverSettleMode = encoding.ReceiverSettleModeFirst + + // Receiver is the second to consider the message as settled. + // Once the corresponding disposition frame is sent, the settlement + // is considered in-flight and the message will not be considered as + // settled until the sender replies acknowledging the settlement. + ReceiverSettleModeSecond ReceiverSettleMode = encoding.ReceiverSettleModeSecond ) // ReceiverSettleMode specifies how the receiver will settle messages. @@ -42,7 +46,7 @@ type ReceiverSettleMode = encoding.ReceiverSettleMode func receiverSettleModeValue(m *ReceiverSettleMode) ReceiverSettleMode { if m == nil { - return ModeFirst + return ReceiverSettleModeFirst } return *m } @@ -50,16 +54,16 @@ func receiverSettleModeValue(m *ReceiverSettleMode) ReceiverSettleMode { // Durability Policies const ( // No terminus state is retained durably. - DurabilityNone = encoding.DurabilityNone + DurabilityNone Durability = encoding.DurabilityNone // Only the existence and configuration of the terminus is // retained durably. - DurabilityConfiguration = encoding.DurabilityConfiguration + DurabilityConfiguration Durability = encoding.DurabilityConfiguration // In addition to the existence and configuration of the // terminus, the unsettled state for durable messages is // retained durably. - DurabilityUnsettledState = encoding.DurabilityUnsettledState + DurabilityUnsettledState Durability = encoding.DurabilityUnsettledState ) // Durability specifies the durability of a link. @@ -68,18 +72,18 @@ type Durability = encoding.Durability // Expiry Policies const ( // The expiry timer starts when terminus is detached. - ExpiryLinkDetach = encoding.ExpiryLinkDetach + ExpiryPolicyLinkDetach ExpiryPolicy = encoding.ExpiryLinkDetach // The expiry timer starts when the most recently // associated session is ended. - ExpirySessionEnd = encoding.ExpirySessionEnd + ExpiryPolicySessionEnd ExpiryPolicy = encoding.ExpirySessionEnd // The expiry timer starts when most recently associated // connection is closed. - ExpiryConnectionClose = encoding.ExpiryConnectionClose + ExpiryPolicyConnectionClose ExpiryPolicy = encoding.ExpiryConnectionClose // The terminus never expires. - ExpiryNever = encoding.ExpiryNever + ExpiryPolicyNever ExpiryPolicy = encoding.ExpiryNever ) // ExpiryPolicy specifies when the expiry timer of a terminus diff --git a/sdk/messaging/azeventhubs/internal/go-amqp/manualCreditor.go b/sdk/messaging/azeventhubs/internal/go-amqp/creditor.go similarity index 75% rename from sdk/messaging/azeventhubs/internal/go-amqp/manualCreditor.go rename to sdk/messaging/azeventhubs/internal/go-amqp/creditor.go index 6b50d5c1abae..184702bca7d2 100644 --- a/sdk/messaging/azeventhubs/internal/go-amqp/manualCreditor.go +++ b/sdk/messaging/azeventhubs/internal/go-amqp/creditor.go @@ -1,4 +1,5 @@ // Copyright (c) Microsoft Corporation + package amqp import ( @@ -7,7 +8,7 @@ import ( "sync" ) -type manualCreditor struct { +type creditor struct { mu sync.Mutex // future values for the next flow frame. @@ -19,11 +20,13 @@ type manualCreditor struct { drained chan struct{} } -var errLinkDraining = errors.New("link is currently draining, no credits can be added") -var errAlreadyDraining = errors.New("drain already in process") +var ( + errLinkDraining = errors.New("link is currently draining, no credits can be added") + errAlreadyDraining = errors.New("drain already in process") +) // EndDrain ends the current drain, unblocking any active Drain calls. -func (mc *manualCreditor) EndDrain() { +func (mc *creditor) EndDrain() { mc.mu.Lock() defer mc.mu.Unlock() @@ -40,7 +43,7 @@ func (mc *manualCreditor) EndDrain() { // (drain: true, credits: 0) if a flow is needed (drain) // (drain: false, credits > 0) if a flow is needed (issue credit) // (drain: false, credits == 0) if no flow needed. -func (mc *manualCreditor) FlowBits(currentCredits uint32) (bool, uint32) { +func (mc *creditor) FlowBits(currentCredits uint32) (bool, uint32) { mc.mu.Lock() defer mc.mu.Unlock() @@ -67,7 +70,9 @@ func (mc *manualCreditor) FlowBits(currentCredits uint32) (bool, uint32) { } // Drain initiates a drain and blocks until EndDrain is called. -func (mc *manualCreditor) Drain(ctx context.Context, l *link) error { +// If the context's deadline expires or is cancelled before the operation +// completes, the drain might not have happened. +func (mc *creditor) Drain(ctx context.Context, r *Receiver) error { mc.mu.Lock() if mc.drained != nil { @@ -82,12 +87,18 @@ func (mc *manualCreditor) Drain(ctx context.Context, l *link) error { mc.mu.Unlock() + // cause mux() to check our flow conditions. + select { + case r.receiverReady <- struct{}{}: + default: + } + // send drain, wait for responding flow frame select { case <-drained: return nil - case <-l.Detached: - return &DetachError{RemoteError: l.detachError} + case <-r.l.done: + return r.l.doneErr case <-ctx.Done(): return ctx.Err() } @@ -95,7 +106,7 @@ func (mc *manualCreditor) Drain(ctx context.Context, l *link) error { // IssueCredit queues up additional credits to be requested at the next // call of FlowBits() -func (mc *manualCreditor) IssueCredit(credits uint32) error { +func (mc *creditor) IssueCredit(credits uint32) error { mc.mu.Lock() defer mc.mu.Unlock() diff --git a/sdk/messaging/azeventhubs/internal/go-amqp/doc.go b/sdk/messaging/azeventhubs/internal/go-amqp/doc.go deleted file mode 100644 index 91945b7cd405..000000000000 --- a/sdk/messaging/azeventhubs/internal/go-amqp/doc.go +++ /dev/null @@ -1,13 +0,0 @@ -/* -Package amqp provides an AMQP 1.0 client implementation. - -AMQP 1.0 is not compatible with AMQP 0-9-1 or 0-10, which are -the most common AMQP protocols in use today. - -The example below shows how to use this package to connect -to a Microsoft Azure Service Bus queue. -*/ - -// Copyright (C) 2017 Kale Blankenship -// Portions Copyright (c) Microsoft Corporation -package amqp diff --git a/sdk/messaging/azeventhubs/internal/go-amqp/errors.go b/sdk/messaging/azeventhubs/internal/go-amqp/errors.go index 298b83a95e91..515a7c36bca3 100644 --- a/sdk/messaging/azeventhubs/internal/go-amqp/errors.go +++ b/sdk/messaging/azeventhubs/internal/go-amqp/errors.go @@ -1,85 +1,107 @@ // Copyright (C) 2017 Kale Blankenship // Portions Copyright (c) Microsoft Corporation + package amqp import ( - "errors" - "fmt" - "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/go-amqp/internal/encoding" ) +// ErrCond is an AMQP defined error condition. +// See http://docs.oasis-open.org/amqp/core/v1.0/os/amqp-core-transport-v1.0-os.html#type-amqp-error for info on their meaning. +type ErrCond = encoding.ErrCond + // Error Conditions const ( // AMQP Errors - ErrorInternalError ErrorCondition = "amqp:internal-error" - ErrorNotFound ErrorCondition = "amqp:not-found" - ErrorUnauthorizedAccess ErrorCondition = "amqp:unauthorized-access" - ErrorDecodeError ErrorCondition = "amqp:decode-error" - ErrorResourceLimitExceeded ErrorCondition = "amqp:resource-limit-exceeded" - ErrorNotAllowed ErrorCondition = "amqp:not-allowed" - ErrorInvalidField ErrorCondition = "amqp:invalid-field" - ErrorNotImplemented ErrorCondition = "amqp:not-implemented" - ErrorResourceLocked ErrorCondition = "amqp:resource-locked" - ErrorPreconditionFailed ErrorCondition = "amqp:precondition-failed" - ErrorResourceDeleted ErrorCondition = "amqp:resource-deleted" - ErrorIllegalState ErrorCondition = "amqp:illegal-state" - ErrorFrameSizeTooSmall ErrorCondition = "amqp:frame-size-too-small" + ErrCondDecodeError ErrCond = "amqp:decode-error" + ErrCondFrameSizeTooSmall ErrCond = "amqp:frame-size-too-small" + ErrCondIllegalState ErrCond = "amqp:illegal-state" + ErrCondInternalError ErrCond = "amqp:internal-error" + ErrCondInvalidField ErrCond = "amqp:invalid-field" + ErrCondNotAllowed ErrCond = "amqp:not-allowed" + ErrCondNotFound ErrCond = "amqp:not-found" + ErrCondNotImplemented ErrCond = "amqp:not-implemented" + ErrCondPreconditionFailed ErrCond = "amqp:precondition-failed" + ErrCondResourceDeleted ErrCond = "amqp:resource-deleted" + ErrCondResourceLimitExceeded ErrCond = "amqp:resource-limit-exceeded" + ErrCondResourceLocked ErrCond = "amqp:resource-locked" + ErrCondUnauthorizedAccess ErrCond = "amqp:unauthorized-access" // Connection Errors - ErrorConnectionForced ErrorCondition = "amqp:connection:forced" - ErrorFramingError ErrorCondition = "amqp:connection:framing-error" - ErrorConnectionRedirect ErrorCondition = "amqp:connection:redirect" + ErrCondConnectionForced ErrCond = "amqp:connection:forced" + ErrCondConnectionRedirect ErrCond = "amqp:connection:redirect" + ErrCondFramingError ErrCond = "amqp:connection:framing-error" // Session Errors - ErrorWindowViolation ErrorCondition = "amqp:session:window-violation" - ErrorErrantLink ErrorCondition = "amqp:session:errant-link" - ErrorHandleInUse ErrorCondition = "amqp:session:handle-in-use" - ErrorUnattachedHandle ErrorCondition = "amqp:session:unattached-handle" + ErrCondErrantLink ErrCond = "amqp:session:errant-link" + ErrCondHandleInUse ErrCond = "amqp:session:handle-in-use" + ErrCondUnattachedHandle ErrCond = "amqp:session:unattached-handle" + ErrCondWindowViolation ErrCond = "amqp:session:window-violation" // Link Errors - ErrorDetachForced ErrorCondition = "amqp:link:detach-forced" - ErrorTransferLimitExceeded ErrorCondition = "amqp:link:transfer-limit-exceeded" - ErrorMessageSizeExceeded ErrorCondition = "amqp:link:message-size-exceeded" - ErrorLinkRedirect ErrorCondition = "amqp:link:redirect" - ErrorStolen ErrorCondition = "amqp:link:stolen" + ErrCondDetachForced ErrCond = "amqp:link:detach-forced" + ErrCondLinkRedirect ErrCond = "amqp:link:redirect" + ErrCondMessageSizeExceeded ErrCond = "amqp:link:message-size-exceeded" + ErrCondStolen ErrCond = "amqp:link:stolen" + ErrCondTransferLimitExceeded ErrCond = "amqp:link:transfer-limit-exceeded" ) +// Error is an AMQP error. type Error = encoding.Error -type ErrorCondition = encoding.ErrorCondition +// LinkError is returned by methods on Sender/Receiver when the link has closed. +type LinkError struct { + // RemoteErr contains any error information provided by the peer if the peer detached the link. + RemoteErr *Error -// DetachError is returned by a link (Receiver/Sender) when a detach frame is received. -// -// RemoteError will be nil if the link was detached gracefully. -type DetachError struct { - RemoteError *Error + inner error } -func (e *DetachError) Error() string { - return fmt.Sprintf("link detached, reason: %+v", e.RemoteError) +// Error implements the error interface for LinkError. +func (e *LinkError) Error() string { + if e.RemoteErr == nil && e.inner == nil { + return "amqp: link closed" + } else if e.RemoteErr != nil { + return e.RemoteErr.Error() + } + return e.inner.Error() } -// Errors -var ( - // ErrSessionClosed is propagated to Sender/Receivers - // when Session.Close() is called. - ErrSessionClosed = errors.New("amqp: session closed") - - // ErrLinkClosed is returned by send and receive operations when - // Sender.Close() or Receiver.Close() are called. - ErrLinkClosed = errors.New("amqp: link closed") -) +// ConnError is returned by methods on Conn and propagated to Session and Senders/Receivers +// when the connection has been closed. +type ConnError struct { + // RemoteErr contains any error information provided by the peer if the peer closed the AMQP connection. + RemoteErr *Error -// ConnectionError is propagated to Session and Senders/Receivers -// when the connection has been closed or is no longer functional. -type ConnectionError struct { inner error } -func (c *ConnectionError) Error() string { - if c.inner == nil { +// Error implements the error interface for ConnectionError. +func (e *ConnError) Error() string { + if e.RemoteErr == nil && e.inner == nil { return "amqp: connection closed" + } else if e.RemoteErr != nil { + return e.RemoteErr.Error() + } + return e.inner.Error() +} + +// SessionError is returned by methods on Session and propagated to Senders/Receivers +// when the session has been closed. +type SessionError struct { + // RemoteErr contains any error information provided by the peer if the peer closed the session. + RemoteErr *Error + + inner error +} + +// Error implements the error interface for SessionError. +func (e *SessionError) Error() string { + if e.RemoteErr == nil && e.inner == nil { + return "amqp: session closed" + } else if e.RemoteErr != nil { + return e.RemoteErr.Error() } - return c.inner.Error() + return e.inner.Error() } diff --git a/sdk/messaging/azeventhubs/internal/go-amqp/internal/bitmap/bitmap.go b/sdk/messaging/azeventhubs/internal/go-amqp/internal/bitmap/bitmap.go index 09198e23c343..d4d682e9199e 100644 --- a/sdk/messaging/azeventhubs/internal/go-amqp/internal/bitmap/bitmap.go +++ b/sdk/messaging/azeventhubs/internal/go-amqp/internal/bitmap/bitmap.go @@ -1,5 +1,6 @@ // Copyright (C) 2017 Kale Blankenship // Portions Copyright (c) Microsoft Corporation + package bitmap import ( diff --git a/sdk/messaging/azeventhubs/internal/go-amqp/internal/buffer/buffer.go b/sdk/messaging/azeventhubs/internal/go-amqp/internal/buffer/buffer.go index 98f606391307..b82e5fab76a6 100644 --- a/sdk/messaging/azeventhubs/internal/go-amqp/internal/buffer/buffer.go +++ b/sdk/messaging/azeventhubs/internal/go-amqp/internal/buffer/buffer.go @@ -1,5 +1,6 @@ // Copyright (C) 2017 Kale Blankenship // Portions Copyright (c) Microsoft Corporation + package buffer import ( diff --git a/sdk/messaging/azeventhubs/internal/go-amqp/internal/debug/debug.go b/sdk/messaging/azeventhubs/internal/go-amqp/internal/debug/debug.go new file mode 100644 index 000000000000..3e6821e1f723 --- /dev/null +++ b/sdk/messaging/azeventhubs/internal/go-amqp/internal/debug/debug.go @@ -0,0 +1,20 @@ +// Copyright (C) 2017 Kale Blankenship +// Portions Copyright (c) Microsoft Corporation + +//go:build !debug +// +build !debug + +package debug + +// dummy functions used when debugging is not enabled + +// Log writes the formatted string to stderr. +// Level indicates the verbosity of the messages to log. +// The greater the value, the more verbose messages will be logged. +func Log(_ int, _ string, _ ...any) {} + +// Assert panics if the specified condition is false. +func Assert(bool) {} + +// Assert panics with the provided message if the specified condition is false. +func Assertf(bool, string, ...any) {} diff --git a/sdk/messaging/azeventhubs/internal/go-amqp/internal/debug/debug_debug.go b/sdk/messaging/azeventhubs/internal/go-amqp/internal/debug/debug_debug.go new file mode 100644 index 000000000000..96d53768a5c9 --- /dev/null +++ b/sdk/messaging/azeventhubs/internal/go-amqp/internal/debug/debug_debug.go @@ -0,0 +1,51 @@ +// Copyright (C) 2017 Kale Blankenship +// Portions Copyright (c) Microsoft Corporation + +//go:build debug +// +build debug + +package debug + +import ( + "fmt" + "log" + "os" + "strconv" +) + +var ( + debugLevel = 1 + logger = log.New(os.Stderr, "", log.Lmicroseconds) +) + +func init() { + level, err := strconv.Atoi(os.Getenv("DEBUG_LEVEL")) + if err != nil { + return + } + + debugLevel = level +} + +// Log writes the formatted string to stderr. +// Level indicates the verbosity of the messages to log. +// The greater the value, the more verbose messages will be logged. +func Log(level int, format string, v ...any) { + if level <= debugLevel { + logger.Printf(format, v...) + } +} + +// Assert panics if the specified condition is false. +func Assert(condition bool) { + if !condition { + panic("assertion failed!") + } +} + +// Assert panics with the provided message if the specified condition is false. +func Assertf(condition bool, msg string, v ...any) { + if !condition { + panic(fmt.Sprintf(msg, v...)) + } +} diff --git a/sdk/messaging/azeventhubs/internal/go-amqp/internal/encoding/decode.go b/sdk/messaging/azeventhubs/internal/go-amqp/internal/encoding/decode.go index 8fefb2b78874..1de2be5f70a9 100644 --- a/sdk/messaging/azeventhubs/internal/go-amqp/internal/encoding/decode.go +++ b/sdk/messaging/azeventhubs/internal/go-amqp/internal/encoding/decode.go @@ -1,5 +1,6 @@ // Copyright (C) 2017 Kale Blankenship // Portions Copyright (c) Microsoft Corporation + package encoding import ( @@ -30,10 +31,10 @@ type unmarshaler interface { // If i is a pointer to a pointer (**Type), it will be dereferenced and a new instance // of (*Type) is allocated via reflection. // -// Common map types (map[string]string, map[Symbol]interface{}, and -// map[interface{}]interface{}), will be decoded via conversion to the mapStringAny, +// Common map types (map[string]string, map[Symbol]any, and +// map[any]any), will be decoded via conversion to the mapStringAny, // mapSymbolAny, and mapAnyAny types. -func Unmarshal(r *buffer.Buffer, i interface{}) error { +func Unmarshal(r *buffer.Buffer, i any) error { if tryReadNull(r) { return nil } @@ -171,13 +172,13 @@ func Unmarshal(r *buffer.Buffer, i interface{}) error { return (*arrayTimestamp)(t).Unmarshal(r) case *[]UUID: return (*arrayUUID)(t).Unmarshal(r) - case *[]interface{}: + case *[]any: return (*list)(t).Unmarshal(r) - case *map[interface{}]interface{}: + case *map[any]any: return (*mapAnyAny)(t).Unmarshal(r) - case *map[string]interface{}: + case *map[string]any: return (*mapStringAny)(t).Unmarshal(r) - case *map[Symbol]interface{}: + case *map[Symbol]any: return (*mapSymbolAny)(t).Unmarshal(r) case *DeliveryState: type_, _, err := PeekMessageType(r.Bytes()) @@ -201,7 +202,7 @@ func Unmarshal(r *buffer.Buffer, i interface{}) error { } return Unmarshal(r, *t) - case *interface{}: + case *any: v, err := ReadAny(r) if err != nil { return err @@ -288,7 +289,7 @@ func UnmarshalComposite(r *buffer.Buffer, type_ AMQPType, fields ...UnmarshalFie // An optional nullHandler can be set. If the composite field being unmarshaled // is null and handleNull is not nil, nullHandler will be called. type UnmarshalField struct { - Field interface{} + Field any HandleNull NullHandler } @@ -479,7 +480,7 @@ func readBinary(r *buffer.Buffer) ([]byte, error) { return append([]byte(nil), buf...), nil } -func ReadAny(r *buffer.Buffer) (interface{}, error) { +func ReadAny(r *buffer.Buffer) (any, error) { if tryReadNull(r) { return nil, nil } @@ -580,8 +581,8 @@ func ReadAny(r *buffer.Buffer) (interface{}, error) { } } -func readAnyMap(r *buffer.Buffer) (interface{}, error) { - var m map[interface{}]interface{} +func readAnyMap(r *buffer.Buffer) (any, error) { + var m map[any]any err := (*mapAnyAny)(&m).Unmarshal(r) if err != nil { return nil, err @@ -604,7 +605,7 @@ Loop: } if stringKeys { - mm := make(map[string]interface{}, len(m)) + mm := make(map[string]any, len(m)) for key, value := range m { switch key := key.(type) { case string: @@ -619,13 +620,13 @@ Loop: return m, nil } -func readAnyList(r *buffer.Buffer) (interface{}, error) { - var a []interface{} +func readAnyList(r *buffer.Buffer) (any, error) { + var a []any err := (*list)(&a).Unmarshal(r) return a, err } -func readAnyArray(r *buffer.Buffer) (interface{}, error) { +func readAnyArray(r *buffer.Buffer) (any, error) { // get the array type buf := r.Bytes() if len(buf) < 1 { @@ -715,7 +716,7 @@ func readAnyArray(r *buffer.Buffer) (interface{}, error) { } } -func readComposite(r *buffer.Buffer) (interface{}, error) { +func readComposite(r *buffer.Buffer) (any, error) { buf := r.Bytes() if len(buf) < 2 { diff --git a/sdk/messaging/azeventhubs/internal/go-amqp/internal/encoding/encode.go b/sdk/messaging/azeventhubs/internal/go-amqp/internal/encoding/encode.go index 26319094c2d2..1103c84f2b26 100644 --- a/sdk/messaging/azeventhubs/internal/go-amqp/internal/encoding/encode.go +++ b/sdk/messaging/azeventhubs/internal/go-amqp/internal/encoding/encode.go @@ -1,5 +1,6 @@ // Copyright (C) 2017 Kale Blankenship // Portions Copyright (c) Microsoft Corporation + package encoding import ( @@ -17,7 +18,7 @@ type marshaler interface { Marshal(*buffer.Buffer) error } -func Marshal(wr *buffer.Buffer, i interface{}) error { +func Marshal(wr *buffer.Buffer, i any) error { switch t := i.(type) { case nil: wr.AppendByte(byte(TypeCodeNull)) @@ -105,17 +106,17 @@ func Marshal(wr *buffer.Buffer, i interface{}) error { return WriteBinary(wr, t) case *[]byte: return WriteBinary(wr, *t) - case map[interface{}]interface{}: + case map[any]any: return writeMap(wr, t) - case *map[interface{}]interface{}: + case *map[any]any: return writeMap(wr, *t) - case map[string]interface{}: + case map[string]any: return writeMap(wr, t) - case *map[string]interface{}: + case *map[string]any: return writeMap(wr, *t) - case map[Symbol]interface{}: + case map[Symbol]any: return writeMap(wr, t) - case *map[Symbol]interface{}: + case *map[Symbol]any: return writeMap(wr, *t) case Unsettled: return writeMap(wr, t) @@ -185,9 +186,9 @@ func Marshal(wr *buffer.Buffer, i interface{}) error { return arrayUUID(t).Marshal(wr) case *[]UUID: return arrayUUID(*t).Marshal(wr) - case []interface{}: + case []any: return list(t).Marshal(wr) - case *[]interface{}: + case *[]any: return list(*t).Marshal(wr) case marshaler: return t.Marshal(wr) @@ -277,8 +278,8 @@ func writeTimestamp(wr *buffer.Buffer, t time.Time) { // marshalField is a field to be marshaled type MarshalField struct { - Value interface{} // value to be marshaled, use pointers to avoid interface conversion overhead - Omit bool // indicates that this field should be omitted (set to null) + Value any // value to be marshaled, use pointers to avoid interface conversion overhead + Omit bool // indicates that this field should be omitted (set to null) } // marshalComposite is a helper for us in a composite's marshal() function. @@ -406,7 +407,7 @@ func WriteBinary(wr *buffer.Buffer, bin []byte) error { } } -func writeMap(wr *buffer.Buffer, m interface{}) error { +func writeMap(wr *buffer.Buffer, m any) error { startIdx := wr.Len() wr.Append([]byte{ byte(TypeCodeMap32), // type @@ -416,7 +417,7 @@ func writeMap(wr *buffer.Buffer, m interface{}) error { var pairs int switch m := m.(type) { - case map[interface{}]interface{}: + case map[any]any: pairs = len(m) * 2 for key, val := range m { err := Marshal(wr, key) @@ -428,7 +429,7 @@ func writeMap(wr *buffer.Buffer, m interface{}) error { return err } } - case map[string]interface{}: + case map[string]any: pairs = len(m) * 2 for key, val := range m { err := writeString(wr, key) @@ -440,7 +441,7 @@ func writeMap(wr *buffer.Buffer, m interface{}) error { return err } } - case map[Symbol]interface{}: + case map[Symbol]any: pairs = len(m) * 2 for key, val := range m { err := key.Marshal(wr) diff --git a/sdk/messaging/azeventhubs/internal/go-amqp/internal/encoding/types.go b/sdk/messaging/azeventhubs/internal/go-amqp/internal/encoding/types.go index 64278d5ea39e..5196d49b4d4c 100644 --- a/sdk/messaging/azeventhubs/internal/go-amqp/internal/encoding/types.go +++ b/sdk/messaging/azeventhubs/internal/go-amqp/internal/encoding/types.go @@ -1,5 +1,6 @@ // Copyright (C) 2017 Kale Blankenship // Portions Copyright (c) Microsoft Corporation + package encoding import ( @@ -219,13 +220,13 @@ func (e *ExpiryPolicy) String() string { // Sender Settlement Modes const ( // Sender will send all deliveries initially unsettled to the receiver. - ModeUnsettled SenderSettleMode = 0 + SenderSettleModeUnsettled SenderSettleMode = 0 // Sender will send all deliveries settled to the receiver. - ModeSettled SenderSettleMode = 1 + SenderSettleModeSettled SenderSettleMode = 1 // Sender MAY send a mixture of settled and unsettled deliveries to the receiver. - ModeMixed SenderSettleMode = 2 + SenderSettleModeMixed SenderSettleMode = 2 ) // SenderSettleMode specifies how the sender will settle messages. @@ -241,13 +242,13 @@ func (m *SenderSettleMode) String() string { } switch *m { - case ModeUnsettled: + case SenderSettleModeUnsettled: return "unsettled" - case ModeSettled: + case SenderSettleModeSettled: return "settled" - case ModeMixed: + case SenderSettleModeMixed: return "mixed" default: @@ -268,12 +269,12 @@ func (m *SenderSettleMode) Unmarshal(r *buffer.Buffer) error { // Receiver Settlement Modes const ( // Receiver will spontaneously settle all incoming transfers. - ModeFirst ReceiverSettleMode = 0 + ReceiverSettleModeFirst ReceiverSettleMode = 0 // Receiver will only settle after sending the disposition to the // sender and receiving a disposition indicating settlement of // the delivery from the sender. - ModeSecond ReceiverSettleMode = 1 + ReceiverSettleModeSecond ReceiverSettleMode = 1 ) // ReceiverSettleMode specifies how the receiver will settle messages. @@ -289,10 +290,10 @@ func (m *ReceiverSettleMode) String() string { } switch *m { - case ModeFirst: + case ReceiverSettleModeFirst: return "first" - case ModeSecond: + case ReceiverSettleModeSecond: return "second" default: @@ -465,7 +466,7 @@ func tryReadNull(r *buffer.Buffer) bool { // Annotations keys must be of type string, int, or int64. // // String keys are encoded as AMQP Symbols. -type Annotations map[interface{}]interface{} +type Annotations map[any]any func (a Annotations) Marshal(wr *buffer.Buffer) error { return writeMap(wr, a) @@ -493,16 +494,16 @@ func (a *Annotations) Unmarshal(r *buffer.Buffer) error { return nil } -// ErrorCondition is one of the error conditions defined in the AMQP spec. -type ErrorCondition string +// ErrCond is one of the error conditions defined in the AMQP spec. +type ErrCond string -func (ec ErrorCondition) Marshal(wr *buffer.Buffer) error { +func (ec ErrCond) Marshal(wr *buffer.Buffer) error { return (Symbol)(ec).Marshal(wr) } -func (ec *ErrorCondition) Unmarshal(r *buffer.Buffer) error { +func (ec *ErrCond) Unmarshal(r *buffer.Buffer) error { s, err := ReadString(r) - *ec = ErrorCondition(s) + *ec = ErrCond(s) return err } @@ -518,7 +519,7 @@ func (ec *ErrorCondition) Unmarshal(r *buffer.Buffer) error { // Error is an AMQP error. type Error struct { // A symbolic value indicating the error condition. - Condition ErrorCondition + Condition ErrCond // descriptive text about the error condition // @@ -527,7 +528,7 @@ type Error struct { Description string // map carrying information about the error condition - Info map[string]interface{} + Info map[string]any } func (e *Error) Marshal(wr *buffer.Buffer) error { @@ -777,10 +778,10 @@ func (m *Milliseconds) Unmarshal(r *buffer.Buffer) error { // mapAnyAny is used to decode AMQP maps who's keys are undefined or // inconsistently typed. -type mapAnyAny map[interface{}]interface{} +type mapAnyAny map[any]any func (m mapAnyAny) Marshal(wr *buffer.Buffer) error { - return writeMap(wr, map[interface{}]interface{}(m)) + return writeMap(wr, map[any]any(m)) } func (m *mapAnyAny) Unmarshal(r *buffer.Buffer) error { @@ -816,10 +817,10 @@ func (m *mapAnyAny) Unmarshal(r *buffer.Buffer) error { } // mapStringAny is used to decode AMQP maps that have string keys -type mapStringAny map[string]interface{} +type mapStringAny map[string]any func (m mapStringAny) Marshal(wr *buffer.Buffer) error { - return writeMap(wr, map[string]interface{}(m)) + return writeMap(wr, map[string]any(m)) } func (m *mapStringAny) Unmarshal(r *buffer.Buffer) error { @@ -846,10 +847,10 @@ func (m *mapStringAny) Unmarshal(r *buffer.Buffer) error { } // mapStringAny is used to decode AMQP maps that have Symbol keys -type mapSymbolAny map[Symbol]interface{} +type mapSymbolAny map[Symbol]any func (m mapSymbolAny) Marshal(wr *buffer.Buffer) error { - return writeMap(wr, map[Symbol]interface{}(m)) + return writeMap(wr, map[Symbol]any(m)) } func (m *mapSymbolAny) Unmarshal(r *buffer.Buffer) error { @@ -936,8 +937,8 @@ func (p *LifetimePolicy) Unmarshal(r *buffer.Buffer) error { } type DescribedType struct { - Descriptor interface{} - Value interface{} + Descriptor any + Value any } func (t DescribedType) Marshal(wr *buffer.Buffer) error { @@ -2066,7 +2067,7 @@ func (a *arrayUUID) Unmarshal(r *buffer.Buffer) error { // LIST -type list []interface{} +type list []any func (l list) Marshal(wr *buffer.Buffer) error { length := len(l) @@ -2111,7 +2112,7 @@ func (l *list) Unmarshal(r *buffer.Buffer) error { ll := *l if int64(cap(ll)) < length { - ll = make([]interface{}, length) + ll = make([]any, length) } else { ll = ll[:length] } diff --git a/sdk/messaging/azeventhubs/internal/go-amqp/internal/frames/frames.go b/sdk/messaging/azeventhubs/internal/go-amqp/internal/frames/frames.go index ae21d97f7f4c..370209e7bfc1 100644 --- a/sdk/messaging/azeventhubs/internal/go-amqp/internal/frames/frames.go +++ b/sdk/messaging/azeventhubs/internal/go-amqp/internal/frames/frames.go @@ -1,5 +1,6 @@ // Copyright (C) 2017 Kale Blankenship // Portions Copyright (c) Microsoft Corporation + package frames import ( @@ -12,6 +13,22 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/go-amqp/internal/encoding" ) +// Type contains the values for a frame's type. +type Type uint8 + +const ( + TypeAMQP Type = 0x0 + TypeSASL Type = 0x1 +) + +// String implements the fmt.Stringer interface for type Type. +func (t Type) String() string { + if t == 0 { + return "AMQP" + } + return "SASL" +} + /* @@ -107,7 +124,7 @@ type Source struct { // distribution-modes. That is, the value MUST be of the same type as // would be valid in a field defined with the following attributes: // type="symbol" multiple="true" requires="distribution-mode" - DynamicNodeProperties map[encoding.Symbol]interface{} // TODO: implement custom type with validation + DynamicNodeProperties map[encoding.Symbol]any // TODO: implement custom type with validation // the distribution mode of the link // @@ -129,7 +146,7 @@ type Source struct { // Indicates the outcome to be used for transfers that have not reached a terminal // state at the receiver when the transfer is settled, including when the source // is destroyed. The value MUST be a valid outcome (e.g., released or rejected). - DefaultOutcome interface{} + DefaultOutcome any // descriptors for the outcomes that can be chosen on this link // @@ -182,7 +199,7 @@ func (s *Source) Unmarshal(r *buffer.Buffer) error { func (s Source) String() string { return fmt.Sprintf("source{Address: %s, Durable: %d, ExpiryPolicy: %s, Timeout: %d, "+ - "Dynamic: %t, DynamicNodeProperties: %v, DistributionMode: %s, Filter: %v, DefaultOutcome: %v"+ + "Dynamic: %t, DynamicNodeProperties: %v, DistributionMode: %s, Filter: %v, DefaultOutcome: %v "+ "Outcomes: %v, Capabilities: %v}", s.Address, s.Durable, @@ -289,7 +306,7 @@ type Target struct { // distribution-modes. That is, the value MUST be of the same type as // would be valid in a field defined with the following attributes: // type="symbol" multiple="true" requires="distribution-mode" - DynamicNodeProperties map[encoding.Symbol]interface{} // TODO: implement custom type with validation + DynamicNodeProperties map[encoding.Symbol]any // TODO: implement custom type with validation // the extension capabilities the sender supports/desires // @@ -336,7 +353,7 @@ func (t Target) String() string { // frame is the decoded representation of a frame type Frame struct { - Type uint8 // AMQP/SASL + Type Type // AMQP/SASL Channel uint16 // channel this frame is for Body FrameBody // body of the frame @@ -344,6 +361,11 @@ type Frame struct { Done chan encoding.DeliveryState } +// String implements the fmt.Stringer interface for type Frame. +func (f Frame) String() string { + return fmt.Sprintf("Frame{Type: %s, Channel: %d, Body: %s}", f.Type, f.Channel, f.Body) +} + // frameBody adds some type safety to frame encoding type FrameBody interface { frameBody() @@ -375,7 +397,7 @@ type PerformOpen struct { IncomingLocales encoding.MultiSymbol OfferedCapabilities encoding.MultiSymbol DesiredCapabilities encoding.MultiSymbol - Properties map[encoding.Symbol]interface{} + Properties map[encoding.Symbol]any } func (o *PerformOpen) frameBody() {} @@ -479,7 +501,7 @@ type PerformBegin struct { // session properties // http://www.amqp.org/specification/1.0/session-properties - Properties map[encoding.Symbol]interface{} + Properties map[encoding.Symbol]any } func (b *PerformBegin) frameBody() {} @@ -681,7 +703,7 @@ type PerformAttach struct { // link properties // http://www.amqp.org/specification/1.0/link-properties - Properties map[encoding.Symbol]interface{} + Properties map[encoding.Symbol]any } func (a *PerformAttach) frameBody() {} @@ -859,7 +881,7 @@ type PerformFlow struct { // link state properties // http://www.amqp.org/specification/1.0/link-state-properties - Properties map[encoding.Symbol]interface{} + Properties map[encoding.Symbol]any } func (f *PerformFlow) frameBody() {} @@ -1087,7 +1109,7 @@ func (t *PerformTransfer) frameBody() {} func (t PerformTransfer) String() string { deliveryTag := "" if t.DeliveryTag != nil { - deliveryTag = fmt.Sprintf("%q", t.DeliveryTag) + deliveryTag = fmt.Sprintf("%X", t.DeliveryTag) } return fmt.Sprintf("Transfer{Handle: %d, DeliveryID: %s, DeliveryTag: %s, MessageFormat: %s, "+ @@ -1207,7 +1229,7 @@ type PerformDisposition struct { func (d *PerformDisposition) frameBody() {} func (d PerformDisposition) String() string { - return fmt.Sprintf("Disposition{Role: %s, First: %d, Last: %s, Settled: %t, State: %s, Batchable: %t}", + return fmt.Sprintf("Disposition{Role: %s, First: %d, Last: %s, Settled: %t, State: %v, Batchable: %t}", d.Role, d.First, formatUint32Ptr(d.Last), @@ -1307,6 +1329,10 @@ type PerformEnd struct { func (e *PerformEnd) frameBody() {} +func (d PerformEnd) String() string { + return fmt.Sprintf("End{Error: %v}", d.Error) +} + func (e *PerformEnd) Marshal(wr *buffer.Buffer) error { return encoding.MarshalComposite(wr, encoding.TypeCodeEnd, []encoding.MarshalField{ {Value: e.Error, Omit: e.Error == nil}, diff --git a/sdk/messaging/azeventhubs/internal/go-amqp/internal/frames/parsing.go b/sdk/messaging/azeventhubs/internal/go-amqp/internal/frames/parsing.go index eef0fbf12f66..0e03a52a23e3 100644 --- a/sdk/messaging/azeventhubs/internal/go-amqp/internal/frames/parsing.go +++ b/sdk/messaging/azeventhubs/internal/go-amqp/internal/frames/parsing.go @@ -1,11 +1,13 @@ // Copyright (C) 2017 Kale Blankenship // Portions Copyright (c) Microsoft Corporation + package frames import ( "encoding/binary" "errors" "fmt" + "math" "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/go-amqp/internal/buffer" "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/go-amqp/internal/encoding" @@ -128,3 +130,33 @@ func ParseBody(r *buffer.Buffer) (FrameBody, error) { return nil, fmt.Errorf("unknown performative type %02x", pType) } } + +// Write encodes fr into buf. +// split out from conn.WriteFrame for testing purposes. +func Write(buf *buffer.Buffer, fr Frame) error { + // write header + buf.Append([]byte{ + 0, 0, 0, 0, // size, overwrite later + 2, // doff, see frameHeader.DataOffset comment + uint8(fr.Type), // frame type + }) + buf.AppendUint16(fr.Channel) // channel + + // write AMQP frame body + err := encoding.Marshal(buf, fr.Body) + if err != nil { + return err + } + + // validate size + if uint(buf.Len()) > math.MaxUint32 { + return errors.New("frame too large") + } + + // retrieve raw bytes + bufBytes := buf.Bytes() + + // write correct size + binary.BigEndian.PutUint32(bufBytes, uint32(len(bufBytes))) + return nil +} diff --git a/sdk/messaging/azeventhubs/internal/go-amqp/internal/log/log.go b/sdk/messaging/azeventhubs/internal/go-amqp/internal/log/log.go deleted file mode 100644 index 9d28348a5ace..000000000000 --- a/sdk/messaging/azeventhubs/internal/go-amqp/internal/log/log.go +++ /dev/null @@ -1,11 +0,0 @@ -// Copyright (C) 2017 Kale Blankenship -// Portions Copyright (c) Microsoft Corporation - -//go:build !debug -// +build !debug - -package log - -// dummy functions used when debugging is not enabled - -func Debug(_ int, _ string, _ ...interface{}) {} diff --git a/sdk/messaging/azeventhubs/internal/go-amqp/internal/log/log_debug.go b/sdk/messaging/azeventhubs/internal/go-amqp/internal/log/log_debug.go deleted file mode 100644 index 73e79f9fe075..000000000000 --- a/sdk/messaging/azeventhubs/internal/go-amqp/internal/log/log_debug.go +++ /dev/null @@ -1,31 +0,0 @@ -// Copyright (C) 2017 Kale Blankenship -// Portions Copyright (c) Microsoft Corporation - -//go:build debug -// +build debug - -package log - -import "log" -import "os" -import "strconv" - -var ( - debugLevel = 1 - logger = log.New(os.Stderr, "", log.Lmicroseconds) -) - -func init() { - level, err := strconv.Atoi(os.Getenv("DEBUG_LEVEL")) - if err != nil { - return - } - - debugLevel = level -} - -func Debug(level int, format string, v ...interface{}) { - if level <= debugLevel { - logger.Printf(format, v...) - } -} diff --git a/sdk/messaging/azeventhubs/internal/go-amqp/internal/queue/queue.go b/sdk/messaging/azeventhubs/internal/go-amqp/internal/queue/queue.go new file mode 100644 index 000000000000..45d6f5af9daf --- /dev/null +++ b/sdk/messaging/azeventhubs/internal/go-amqp/internal/queue/queue.go @@ -0,0 +1,164 @@ +// Copyright (c) Microsoft Corporation + +package queue + +import ( + "container/ring" +) + +// Holder provides synchronized access to a *Queue[T]. +type Holder[T any] struct { + // these channels work in tandem to provide exclusive access to the underlying *Queue[T]. + // each channel is created with a buffer size of one. + // empty behaves like a mutex when there's one or more messages in the queue. + // populated is like a semaphore when the queue is empty. + // the *Queue[T] is only ever in one channel. which channel depends on if it contains any items. + // the initial state is for empty to contain an empty queue. + empty chan *Queue[T] + populated chan *Queue[T] +} + +// NewHolder creates a new Holder[T] that contains the provided *Queue[T]. +func NewHolder[T any](q *Queue[T]) *Holder[T] { + h := &Holder[T]{ + empty: make(chan *Queue[T], 1), + populated: make(chan *Queue[T], 1), + } + h.Release(q) + return h +} + +// Acquire attempts to acquire the *Queue[T]. If the *Queue[T] has already been acquired the call blocks. +// When the *Queue[T] is no longer required, you MUST call Release() to relinquish acquisition. +func (h *Holder[T]) Acquire() *Queue[T] { + // the queue will be in only one of the channels, it doesn't matter which one + var q *Queue[T] + select { + case q = <-h.empty: + // empty queue + case q = <-h.populated: + // populated queue + } + return q +} + +// Wait returns a channel that's signaled when the *Queue[T] contains at least one item. +// When the *Queue[T] is no longer required, you MUST call Release() to relinquish acquisition. +func (h *Holder[T]) Wait() <-chan *Queue[T] { + return h.populated +} + +// Release returns the *Queue[T] back to the Holder[T]. +// Once the *Queue[T] has been released, it is no longer safe to call its methods. +func (h *Holder[T]) Release(q *Queue[T]) { + if q.Len() == 0 { + h.empty <- q + } else { + h.populated <- q + } +} + +// Len returns the length of the *Queue[T]. +func (h *Holder[T]) Len() int { + msgLen := 0 + select { + case q := <-h.empty: + h.empty <- q + case q := <-h.populated: + msgLen = q.Len() + h.populated <- q + } + return msgLen +} + +// Queue[T] is a segmented FIFO queue of Ts. +type Queue[T any] struct { + head *ring.Ring + tail *ring.Ring + size int +} + +// New creates a new instance of Queue[T]. +// - size is the size of each Queue segment +func New[T any](size int) *Queue[T] { + r := &ring.Ring{ + Value: &segment[T]{ + items: make([]*T, size), + }, + } + return &Queue[T]{ + head: r, + tail: r, + } +} + +// Enqueue adds the specified item to the end of the queue. +// If the current segment is full, a new segment is created. +func (q *Queue[T]) Enqueue(item T) { + for { + r := q.tail + seg := r.Value.(*segment[T]) + + if seg.tail < len(seg.items) { + seg.items[seg.tail] = &item + seg.tail++ + q.size++ + return + } + + // segment is full, can we advance? + if next := r.Next(); next != q.head { + q.tail = next + continue + } + + // no, add a new ring + r.Link(&ring.Ring{ + Value: &segment[T]{ + items: make([]*T, len(seg.items)), + }, + }) + + q.tail = r.Next() + } +} + +// Dequeue removes and returns the item from the front of the queue. +func (q *Queue[T]) Dequeue() *T { + r := q.head + seg := r.Value.(*segment[T]) + + if seg.tail == 0 { + // queue is empty + return nil + } + + // remove first item + item := seg.items[seg.head] + seg.items[seg.head] = nil + seg.head++ + q.size-- + + if seg.head == seg.tail { + // segment is now empty, reset indices + seg.head, seg.tail = 0, 0 + + // if we're not at the last ring, advance head to the next one + if q.head != q.tail { + q.head = r.Next() + } + } + + return item +} + +// Len returns the total count of enqueued items. +func (q *Queue[T]) Len() int { + return q.size +} + +type segment[T any] struct { + items []*T + head int + tail int +} diff --git a/sdk/messaging/azeventhubs/internal/go-amqp/internal/shared/shared.go b/sdk/messaging/azeventhubs/internal/go-amqp/internal/shared/shared.go new file mode 100644 index 000000000000..867c1e932bf5 --- /dev/null +++ b/sdk/messaging/azeventhubs/internal/go-amqp/internal/shared/shared.go @@ -0,0 +1,36 @@ +// Copyright (c) Microsoft Corporation + +package shared + +import ( + "encoding/base64" + "math/rand" + "sync" + "time" +) + +// lockedRand provides a rand source that is safe for concurrent use. +type lockedRand struct { + mu sync.Mutex + src *rand.Rand +} + +func (r *lockedRand) Read(p []byte) (int, error) { + r.mu.Lock() + defer r.mu.Unlock() + return r.src.Read(p) +} + +// package scoped rand source to avoid any issues with seeding +// of the global source. +var pkgRand = &lockedRand{ + src: rand.New(rand.NewSource(time.Now().UnixNano())), +} + +// RandString returns a base64 encoded string of n bytes. +func RandString(n int) string { + b := make([]byte, n) + // from math/rand, cannot fail + _, _ = pkgRand.Read(b) + return base64.RawURLEncoding.EncodeToString(b) +} diff --git a/sdk/messaging/azeventhubs/internal/go-amqp/link.go b/sdk/messaging/azeventhubs/internal/go-amqp/link.go index ce56dbb89c80..bb1824e79514 100644 --- a/sdk/messaging/azeventhubs/internal/go-amqp/link.go +++ b/sdk/messaging/azeventhubs/internal/go-amqp/link.go @@ -1,53 +1,56 @@ // Copyright (C) 2017 Kale Blankenship // Portions Copyright (c) Microsoft Corporation + package amqp import ( - "bytes" "context" "errors" "fmt" "sync" - "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/go-amqp/internal/buffer" + "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/go-amqp/internal/debug" "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/go-amqp/internal/encoding" "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/go-amqp/internal/frames" - "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/go-amqp/internal/log" + "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/go-amqp/internal/queue" + "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/go-amqp/internal/shared" ) -// link is a unidirectional route. +// linkKey uniquely identifies a link on a connection by name and direction. +// +// A link can be identified uniquely by the ordered tuple +// +// (source-container-id, target-container-id, name) // -// May be used for sending or receiving. +// On a single connection the container ID pairs can be abbreviated +// to a boolean flag indicating the direction of the link. +type linkKey struct { + name string + role encoding.Role // Local role: sender/receiver +} + +// link contains the common state and methods for sending and receiving links type link struct { - Key linkKey // Name and direction - Handle uint32 // our handle - RemoteHandle uint32 // remote's handle - dynamicAddr bool // request a dynamic link address from the server - RX chan frames.FrameBody // sessions sends frames for this link on this channel - Transfers chan frames.PerformTransfer // sender uses to send transfer frames - closeOnce sync.Once // closeOnce protects close from being closed multiple times + key linkKey // Name and direction + handle uint32 // our handle + remoteHandle uint32 // remote's handle + dynamicAddr bool // request a dynamic link address from the server - // close signals the mux to shutdown. This indicates that `Close()` was called on this link. - // NOTE: observers outside of link.go *must only* use the Detached channel to check if the link is unavailable. - // including the close channel will lead to a race condition. - close chan struct{} + // frames destined for this link are added to this queue by Session.muxFrameToLink + rxQ *queue.Holder[frames.FrameBody] - // detached is closed by mux/muxDetach when the link is fully detached. - // This will be initiated if the service sends back an error or requests the link detach. - Detached chan struct{} + // used for gracefully closing link + close chan struct{} // signals a link's mux to shut down; DO NOT use this to check if a link has terminated, use done instead + forceClose chan struct{} // used for forcibly terminate a link if Close() times out/is cancelled + closeOnce *sync.Once // closeOnce protects close from being closed multiple times - detachErrorMu sync.Mutex // protects detachError - detachError *Error // error to send to remote on detach, set by closeWithError - Session *Session // parent session - receiver *Receiver // allows link options to modify Receiver - Source *frames.Source // used for Receiver links - Target *frames.Target // used for Sender links - properties map[encoding.Symbol]interface{} // additional properties sent upon link attach - // Indicates whether we should allow detaches on disposition errors or not. - // Some AMQP servers (like Event Hubs) benefit from keeping the link open on disposition errors - // (for instance, if you're doing many parallel sends over the same link and you get back a - // throttling error, which is not fatal) - detachOnDispositionError bool + done chan struct{} // closed when the link has terminated (mux exited); DO NOT wait on this from within a link's mux() as it will never trigger! + doneErr error // contains the error state returned from Close(); DO NOT TOUCH outside of link.go until done has been closed! + + session *Session // parent session + source *frames.Source // used for Receiver links + target *frames.Target // used for Sender links + properties map[encoding.Symbol]any // additional properties sent upon link attach // "The delivery-count is initialized by the sender when a link endpoint is created, // and is incremented whenever a message is sent. Only the sender MAY independently @@ -55,259 +58,97 @@ type link struct { // value from the sender and any subsequent messages received on the link. Note that, // despite its name, the delivery-count is not a count but a sequence number // initialized at an arbitrary point by the sender." - deliveryCount uint32 - linkCredit uint32 // maximum number of messages allowed between flow updates - SenderSettleMode *SenderSettleMode - ReceiverSettleMode *ReceiverSettleMode - MaxMessageSize uint64 - detachReceived bool - err error // err returned on Close() + deliveryCount uint32 - // message receiving - ReceiverReady chan struct{} // receiver sends on this when mux is paused to indicate it can handle more messages - Messages chan Message // used to send completed messages to receiver - unsettledMessages map[string]struct{} // used to keep track of messages being handled downstream - unsettledMessagesLock sync.RWMutex // lock to protect concurrent access to unsettledMessages - buf buffer.Buffer // buffered bytes for current message - more bool // if true, buf contains a partial message - msg Message // current message being decoded -} + // The current maximum number of messages that can be handled at the receiver endpoint of the link. Only the receiver endpoint + // can independently set this value. The sender endpoint sets this to the last known value seen from the receiver. + linkCredit uint32 -// newSendingLink creates a new sending link and attaches it to the session -func newSendingLink(target string, s *Session, opts *SenderOptions) (*link, error) { - l := &link{ - Key: linkKey{randString(40), encoding.RoleSender}, - Session: s, - close: make(chan struct{}), - Detached: make(chan struct{}), - ReceiverReady: make(chan struct{}, 1), - detachOnDispositionError: true, - Target: &frames.Target{Address: target}, - Source: new(frames.Source), - } + senderSettleMode *SenderSettleMode + receiverSettleMode *ReceiverSettleMode + maxMessageSize uint64 - if opts == nil { - return l, nil - } - - for _, v := range opts.Capabilities { - l.Source.Capabilities = append(l.Source.Capabilities, encoding.Symbol(v)) - } - if opts.Durability > DurabilityUnsettledState { - return nil, fmt.Errorf("invalid Durability %d", opts.Durability) - } - l.Source.Durable = opts.Durability - if opts.DynamicAddress { - l.Target.Address = "" - l.dynamicAddr = opts.DynamicAddress - } - if opts.ExpiryPolicy != "" { - if err := encoding.ValidateExpiryPolicy(opts.ExpiryPolicy); err != nil { - return nil, err - } - l.Source.ExpiryPolicy = opts.ExpiryPolicy - } - l.Source.Timeout = opts.ExpiryTimeout - l.detachOnDispositionError = !opts.IgnoreDispositionErrors - if opts.Name != "" { - l.Key.name = opts.Name - } - if opts.Properties != nil { - l.properties = make(map[encoding.Symbol]interface{}) - for k, v := range opts.Properties { - if k == "" { - return nil, errors.New("link property key must not be empty") - } - l.properties[encoding.Symbol(k)] = v - } - } - if opts.RequestedReceiverSettleMode != nil { - if rsm := *opts.RequestedReceiverSettleMode; rsm > ModeSecond { - return nil, fmt.Errorf("invalid RequestedReceiverSettleMode %d", rsm) - } - l.ReceiverSettleMode = opts.RequestedReceiverSettleMode - } - if opts.SettlementMode != nil { - if ssm := *opts.SettlementMode; ssm > ModeMixed { - return nil, fmt.Errorf("invalid SettlementMode %d", ssm) - } - l.SenderSettleMode = opts.SettlementMode - } - l.Source.Address = opts.SourceAddress - return l, nil -} - -func newReceivingLink(source string, s *Session, r *Receiver, opts *ReceiverOptions) (*link, error) { - l := &link{ - Key: linkKey{randString(40), encoding.RoleReceiver}, - Session: s, - receiver: r, - close: make(chan struct{}), - Detached: make(chan struct{}), - ReceiverReady: make(chan struct{}, 1), - Source: &frames.Source{Address: source}, - Target: new(frames.Target), - } - - if opts == nil { - return l, nil - } - - l.receiver.batching = opts.Batching - if opts.BatchMaxAge > 0 { - l.receiver.batchMaxAge = opts.BatchMaxAge - } - for _, v := range opts.Capabilities { - l.Target.Capabilities = append(l.Target.Capabilities, encoding.Symbol(v)) - } - if opts.Credit > 0 { - l.receiver.maxCredit = opts.Credit - } - if opts.Durability > DurabilityUnsettledState { - return nil, fmt.Errorf("invalid Durability %d", opts.Durability) - } - l.Target.Durable = opts.Durability - if opts.DynamicAddress { - l.Source.Address = "" - l.dynamicAddr = opts.DynamicAddress - } - if opts.ExpiryPolicy != "" { - if err := encoding.ValidateExpiryPolicy(opts.ExpiryPolicy); err != nil { - return nil, err - } - l.Target.ExpiryPolicy = opts.ExpiryPolicy - } - l.Target.Timeout = opts.ExpiryTimeout - if opts.Filters != nil { - l.Source.Filter = make(encoding.Filter) - for _, f := range opts.Filters { - f(l.Source.Filter) - } - } - if opts.ManualCredits { - l.receiver.manualCreditor = &manualCreditor{} - } - if opts.MaxMessageSize > 0 { - l.MaxMessageSize = opts.MaxMessageSize - } - if opts.Name != "" { - l.Key.name = opts.Name - } - if opts.Properties != nil { - l.properties = make(map[encoding.Symbol]interface{}) - for k, v := range opts.Properties { - if k == "" { - return nil, errors.New("link property key must not be empty") - } - l.properties[encoding.Symbol(k)] = v - } - } - if opts.RequestedSenderSettleMode != nil { - if rsm := *opts.RequestedSenderSettleMode; rsm > ModeMixed { - return nil, fmt.Errorf("invalid RequestedSenderSettleMode %d", rsm) - } - l.SenderSettleMode = opts.RequestedSenderSettleMode - } - if opts.SettlementMode != nil { - if rsm := *opts.SettlementMode; rsm > ModeSecond { - return nil, fmt.Errorf("invalid SettlementMode %d", rsm) - } - l.ReceiverSettleMode = opts.SettlementMode - } - l.Target.Address = opts.TargetAddress - return l, nil + closeInProgress bool // indicates that the detach performative has been sent } -// attach sends the Attach performative to establish the link with its parent session. -// this is automatically called by the new*Link constructors. -func (l *link) attach(ctx context.Context, s *Session) error { - // sending unsettled messages when the receiver is in mode-second is currently - // broken and causes a hang after sending, so just disallow it for now. - if l.receiver == nil && senderSettleModeValue(l.SenderSettleMode) != ModeSettled && receiverSettleModeValue(l.ReceiverSettleMode) == ModeSecond { - return errors.New("sender does not support exactly-once guarantee") +func newLink(s *Session, r encoding.Role) link { + l := link{ + key: linkKey{shared.RandString(40), r}, + session: s, + close: make(chan struct{}), + forceClose: make(chan struct{}), + closeOnce: &sync.Once{}, + done: make(chan struct{}), } - isReceiver := l.receiver != nil - - // buffer rx to linkCredit so that conn.mux won't block - // attempting to send to a slow reader - if isReceiver { - if l.receiver.manualCreditor != nil { - l.RX = make(chan frames.FrameBody, l.receiver.maxCredit) - } else { - l.RX = make(chan frames.FrameBody, l.linkCredit) - } + // set the segment size relative to respective window + var segmentSize int + if r == encoding.RoleReceiver { + segmentSize = int(s.incomingWindow) } else { - l.RX = make(chan frames.FrameBody, 1) + segmentSize = int(s.outgoingWindow) } - // request handle from Session.mux - select { - case <-ctx.Done(): - return ctx.Err() - case <-s.done: - return s.err - case s.allocateHandle <- l: - } + l.rxQ = queue.NewHolder(queue.New[frames.FrameBody](segmentSize)) + return l +} - // wait for handle allocation +// waitForFrame waits for an incoming frame to be queued. +// it returns the next frame from the queue, or an error. +// the error is either from the context or session.doneErr. +// not meant for consumption outside of link.go. +func (l *link) waitForFrame(ctx context.Context) (frames.FrameBody, error) { select { case <-ctx.Done(): - // TODO: this _might_ leak l's handle - return ctx.Err() - case <-s.done: - return s.err - case <-l.RX: + return nil, ctx.Err() + case <-l.session.done: + // session has terminated, no need to deallocate in this case + return nil, l.session.doneErr + case q := <-l.rxQ.Wait(): + // frame received + fr := q.Dequeue() + l.rxQ.Release(q) + return *fr, nil } +} - // check for link request error - if l.err != nil { - return l.err +// attach sends the Attach performative to establish the link with its parent session. +// this is automatically called by the new*Link constructors. +func (l *link) attach(ctx context.Context, beforeAttach func(*frames.PerformAttach), afterAttach func(*frames.PerformAttach)) error { + if err := l.session.allocateHandle(l); err != nil { + return err } attach := &frames.PerformAttach{ - Name: l.Key.name, - Handle: l.Handle, - ReceiverSettleMode: l.ReceiverSettleMode, - SenderSettleMode: l.SenderSettleMode, - MaxMessageSize: l.MaxMessageSize, - Source: l.Source, - Target: l.Target, + Name: l.key.name, + Handle: l.handle, + ReceiverSettleMode: l.receiverSettleMode, + SenderSettleMode: l.senderSettleMode, + MaxMessageSize: l.maxMessageSize, + Source: l.source, + Target: l.target, Properties: l.properties, } - if isReceiver { - attach.Role = encoding.RoleReceiver - if attach.Source == nil { - attach.Source = new(frames.Source) - } - attach.Source.Dynamic = l.dynamicAddr - } else { - attach.Role = encoding.RoleSender - if attach.Target == nil { - attach.Target = new(frames.Target) - } - attach.Target.Dynamic = l.dynamicAddr - } + // link-specific configuration of the attach frame + beforeAttach(attach) - // send Attach frame - log.Debug(1, "TX (attachLink): %s", attach) - _ = s.txFrame(attach, nil) + _ = l.session.txFrame(attach, nil) // wait for response - var fr frames.FrameBody - select { - case <-ctx.Done(): - // TODO: this leaks l's handle - return ctx.Err() - case <-s.done: - return s.err - case fr = <-l.RX: + fr, err := l.waitForFrame(ctx) + if err != nil { + l.session.deallocateHandle(l) + return err } - log.Debug(3, "RX (attachLink): %s", fr) + resp, ok := fr.(*frames.PerformAttach) if !ok { - return fmt.Errorf("unexpected attach response: %#v", fr) + debug.Log(1, "RX (link): unexpected attach response frame %T", fr) + if err := l.session.conn.Close(); err != nil { + return err + } + return &ConnError{inner: fmt.Errorf("unexpected attach response: %#v", fr)} } // If the remote encounters an error during the attach it returns an Attach @@ -321,27 +162,26 @@ func (l *link) attach(ctx context.Context, s *Session) error { // http://docs.oasis-open.org/amqp/core/v1.0/csprd01/amqp-core-transport-v1.0-csprd01.html#doc-idp386144 if resp.Source == nil && resp.Target == nil { // wait for detach - select { - case <-ctx.Done(): - // TODO: this leaks l's handle - return ctx.Err() - case <-s.done: - return s.err - case fr = <-l.RX: + fr, err := l.waitForFrame(ctx) + if err != nil { + l.session.deallocateHandle(l) + return err } detach, ok := fr.(*frames.PerformDetach) if !ok { - return fmt.Errorf("unexpected frame while waiting for detach: %#v", fr) + if err := l.session.conn.Close(); err != nil { + return err + } + return &ConnError{inner: fmt.Errorf("unexpected frame while waiting for detach: %#v", fr)} } // send return detach fr = &frames.PerformDetach{ - Handle: l.Handle, + Handle: l.handle, Closed: true, } - log.Debug(1, "TX (attachLink): %s", fr) - _ = s.txFrame(fr, nil) + _ = l.session.txFrame(fr, nil) if detach.Error == nil { return fmt.Errorf("received detach with no error specified") @@ -349,687 +189,133 @@ func (l *link) attach(ctx context.Context, s *Session) error { return detach.Error } - if l.MaxMessageSize == 0 || resp.MaxMessageSize < l.MaxMessageSize { - l.MaxMessageSize = resp.MaxMessageSize + if l.maxMessageSize == 0 || resp.MaxMessageSize < l.maxMessageSize { + l.maxMessageSize = resp.MaxMessageSize } - if isReceiver { - if l.Source == nil { - l.Source = new(frames.Source) - } - // if dynamic address requested, copy assigned name to address - if l.dynamicAddr && resp.Source != nil { - l.Source.Address = resp.Source.Address - } - // deliveryCount is a sequence number, must initialize to sender's initial sequence number - l.deliveryCount = resp.InitialDeliveryCount - // buffer receiver so that link.mux doesn't block - l.Messages = make(chan Message, l.receiver.maxCredit) - l.unsettledMessages = map[string]struct{}{} - // copy the received filter values - if resp.Source != nil { - l.Source.Filter = resp.Source.Filter - } - } else { - if l.Target == nil { - l.Target = new(frames.Target) - } - // if dynamic address requested, copy assigned name to address - if l.dynamicAddr && resp.Target != nil { - l.Target.Address = resp.Target.Address - } - l.Transfers = make(chan frames.PerformTransfer) - } + // link-specific configuration post attach + afterAttach(resp) if err := l.setSettleModes(resp); err != nil { - l.muxDetach() + // close the link as there's a mismatch on requested/supported settlement modes + dr := &frames.PerformDetach{ + Handle: l.handle, + Closed: true, + } + _ = l.session.txFrame(dr, nil) return err } - go l.mux() - return nil } -func (l *link) addUnsettled(msg *Message) { - l.unsettledMessagesLock.Lock() - l.unsettledMessages[string(msg.DeliveryTag)] = struct{}{} - l.unsettledMessagesLock.Unlock() -} - -// DeleteUnsettled removes the message from the map of unsettled messages. -func (l *link) DeleteUnsettled(msg *Message) { - l.unsettledMessagesLock.Lock() - delete(l.unsettledMessages, string(msg.DeliveryTag)) - l.unsettledMessagesLock.Unlock() -} - -func (l *link) countUnsettled() int { - l.unsettledMessagesLock.RLock() - count := len(l.unsettledMessages) - l.unsettledMessagesLock.RUnlock() - return count -} - // setSettleModes sets the settlement modes based on the resp frames.PerformAttach. // // If a settlement mode has been explicitly set locally and it was not honored by the // server an error is returned. func (l *link) setSettleModes(resp *frames.PerformAttach) error { var ( - localRecvSettle = receiverSettleModeValue(l.ReceiverSettleMode) + localRecvSettle = receiverSettleModeValue(l.receiverSettleMode) respRecvSettle = receiverSettleModeValue(resp.ReceiverSettleMode) ) - if l.ReceiverSettleMode != nil && localRecvSettle != respRecvSettle { - return fmt.Errorf("amqp: receiver settlement mode %q requested, received %q from server", l.ReceiverSettleMode, &respRecvSettle) + if l.receiverSettleMode != nil && localRecvSettle != respRecvSettle { + return fmt.Errorf("amqp: receiver settlement mode %q requested, received %q from server", l.receiverSettleMode, &respRecvSettle) } - l.ReceiverSettleMode = &respRecvSettle + l.receiverSettleMode = &respRecvSettle var ( - localSendSettle = senderSettleModeValue(l.SenderSettleMode) + localSendSettle = senderSettleModeValue(l.senderSettleMode) respSendSettle = senderSettleModeValue(resp.SenderSettleMode) ) - if l.SenderSettleMode != nil && localSendSettle != respSendSettle { - return fmt.Errorf("amqp: sender settlement mode %q requested, received %q from server", l.SenderSettleMode, &respSendSettle) - } - l.SenderSettleMode = &respSendSettle - - return nil -} - -// doFlow handles the logical 'flow' event for a link. -// For receivers it will send (if needed) an AMQP flow frame, via `muxFlow`. If a fatal error -// occurs it will be set in `l.err` and 'ok' will be false. -// For senders it will indicate if we should try to send any outgoing transfers (the logical -// equivalent of a flow for a sender) by returning true for 'enableOutgoingTransfers'. -func (l *link) doFlow() (ok bool, enableOutgoingTransfers bool) { - var ( - isReceiver = l.receiver != nil - isSender = !isReceiver - ) - - switch { - // enable outgoing transfers case if sender and credits are available - case isSender && l.linkCredit > 0: - log.Debug(1, "Link Mux isSender: credit: %d, deliveryCount: %d, messages: %d, unsettled: %d", l.linkCredit, l.deliveryCount, len(l.Messages), l.countUnsettled()) - return true, true - - case isReceiver && l.receiver.manualCreditor != nil: - drain, credits := l.receiver.manualCreditor.FlowBits(l.linkCredit) - - if drain || credits > 0 { - log.Debug(1, "FLOW Link Mux (manual): source: %s, inflight: %d, credit: %d, creditsToAdd: %d, drain: %v, deliveryCount: %d, messages: %d, unsettled: %d, maxCredit : %d, settleMode: %s", - l.Source.Address, l.receiver.inFlight.len(), l.linkCredit, credits, drain, l.deliveryCount, len(l.Messages), l.countUnsettled(), l.receiver.maxCredit, l.ReceiverSettleMode.String()) - - // send a flow frame. - l.err = l.muxFlow(credits, drain) - } - - // if receiver && half maxCredits have been processed, send more credits - case isReceiver && l.linkCredit+uint32(l.countUnsettled()) <= l.receiver.maxCredit/2: - log.Debug(1, "FLOW Link Mux half: source: %s, inflight: %d, credit: %d, deliveryCount: %d, messages: %d, unsettled: %d, maxCredit : %d, settleMode: %s", l.Source.Address, l.receiver.inFlight.len(), l.linkCredit, l.deliveryCount, len(l.Messages), l.countUnsettled(), l.receiver.maxCredit, l.ReceiverSettleMode.String()) - - linkCredit := l.receiver.maxCredit - uint32(l.countUnsettled()) - l.err = l.muxFlow(linkCredit, false) - - if l.err != nil { - return false, false - } - - case isReceiver && l.linkCredit == 0: - log.Debug(1, "PAUSE Link Mux pause: inflight: %d, credit: %d, deliveryCount: %d, messages: %d, unsettled: %d, maxCredit : %d, settleMode: %s", l.receiver.inFlight.len(), l.linkCredit, l.deliveryCount, len(l.Messages), l.countUnsettled(), l.receiver.maxCredit, l.ReceiverSettleMode.String()) - } - - return true, false -} - -func (l *link) mux() { - defer l.muxDetach() - -Loop: - for { - var outgoingTransfers chan frames.PerformTransfer - - ok, enableOutgoingTransfers := l.doFlow() - - if !ok { - return - } - - if enableOutgoingTransfers { - outgoingTransfers = l.Transfers - } - - select { - // received frame - case fr := <-l.RX: - l.err = l.muxHandleFrame(fr) - if l.err != nil { - return - } - - // send data - case tr := <-outgoingTransfers: - log.Debug(3, "TX(link): %s", tr) - - // Ensure the session mux is not blocked - for { - select { - case l.Session.txTransfer <- &tr: - // decrement link-credit after entire message transferred - if !tr.More { - l.deliveryCount++ - l.linkCredit-- - // we are the sender and we keep track of the peer's link credit - log.Debug(3, "TX(link): key:%s, decremented linkCredit: %d", l.Key.name, l.linkCredit) - } - continue Loop - case fr := <-l.RX: - l.err = l.muxHandleFrame(fr) - if l.err != nil { - return - } - case <-l.close: - l.err = ErrLinkClosed - return - case <-l.Session.done: - l.err = l.Session.err - return - } - } - - case <-l.ReceiverReady: - continue - case <-l.close: - l.err = ErrLinkClosed - return - case <-l.Session.done: - l.err = l.Session.err - return - } - } -} - -// muxFlow sends tr to the session mux. -// l.linkCredit will also be updated to `linkCredit` -func (l *link) muxFlow(linkCredit uint32, drain bool) error { - var ( - deliveryCount = l.deliveryCount - ) - - log.Debug(3, "link.muxFlow(): len(l.Messages):%d - linkCredit: %d - deliveryCount: %d, inFlight: %d", len(l.Messages), linkCredit, deliveryCount, l.receiver.inFlight.len()) - - fr := &frames.PerformFlow{ - Handle: &l.Handle, - DeliveryCount: &deliveryCount, - LinkCredit: &linkCredit, // max number of messages, - Drain: drain, - } - log.Debug(3, "TX (muxFlow): %s", fr) - - // Update credit. This must happen before entering loop below - // because incoming messages handled while waiting to transmit - // flow increment deliveryCount. This causes the credit to become - // out of sync with the server. - - if !drain { - // if we're draining we don't want to touch our internal credit - we're not changing it so any issued credits - // are still valid until drain completes, at which point they will be naturally zeroed. - l.linkCredit = linkCredit - } - - // Ensure the session mux is not blocked - for { - select { - case l.Session.tx <- fr: - return nil - case fr := <-l.RX: - err := l.muxHandleFrame(fr) - if err != nil { - return err - } - case <-l.close: - return ErrLinkClosed - case <-l.Session.done: - return l.Session.err - } - } -} - -func (l *link) muxReceive(fr frames.PerformTransfer) error { - if !l.more { - // this is the first transfer of a message, - // record the delivery ID, message format, - // and delivery Tag - if fr.DeliveryID != nil { - l.msg.deliveryID = *fr.DeliveryID - } - if fr.MessageFormat != nil { - l.msg.Format = *fr.MessageFormat - } - l.msg.DeliveryTag = fr.DeliveryTag - - // these fields are required on first transfer of a message - if fr.DeliveryID == nil { - return l.closeWithError(&Error{ - Condition: ErrorNotAllowed, - Description: "received message without a delivery-id", - }) - } - if fr.MessageFormat == nil { - return l.closeWithError(&Error{ - Condition: ErrorNotAllowed, - Description: "received message without a message-format", - }) - } - if fr.DeliveryTag == nil { - return l.closeWithError(&Error{ - Condition: ErrorNotAllowed, - Description: "received message without a delivery-tag", - }) - } - } else { - // this is a continuation of a multipart message - // some fields may be omitted on continuation transfers, - // but if they are included they must be consistent - // with the first. - - if fr.DeliveryID != nil && *fr.DeliveryID != l.msg.deliveryID { - msg := fmt.Sprintf( - "received continuation transfer with inconsistent delivery-id: %d != %d", - *fr.DeliveryID, l.msg.deliveryID, - ) - return l.closeWithError(&Error{ - Condition: ErrorNotAllowed, - Description: msg, - }) - } - if fr.MessageFormat != nil && *fr.MessageFormat != l.msg.Format { - msg := fmt.Sprintf( - "received continuation transfer with inconsistent message-format: %d != %d", - *fr.MessageFormat, l.msg.Format, - ) - return l.closeWithError(&Error{ - Condition: ErrorNotAllowed, - Description: msg, - }) - } - if fr.DeliveryTag != nil && !bytes.Equal(fr.DeliveryTag, l.msg.DeliveryTag) { - msg := fmt.Sprintf( - "received continuation transfer with inconsistent delivery-tag: %q != %q", - fr.DeliveryTag, l.msg.DeliveryTag, - ) - return l.closeWithError(&Error{ - Condition: ErrorNotAllowed, - Description: msg, - }) - } - } - - // discard message if it's been aborted - if fr.Aborted { - l.buf.Reset() - l.msg = Message{} - l.more = false - return nil - } - - // ensure maxMessageSize will not be exceeded - if l.MaxMessageSize != 0 && uint64(l.buf.Len())+uint64(len(fr.Payload)) > l.MaxMessageSize { - return l.closeWithError(&Error{ - Condition: ErrorMessageSizeExceeded, - Description: fmt.Sprintf("received message larger than max size of %d", l.MaxMessageSize), - }) - } - - // add the payload the the buffer - l.buf.Append(fr.Payload) - - // mark as settled if at least one frame is settled - l.msg.settled = l.msg.settled || fr.Settled - - // save in-progress status - l.more = fr.More - - if fr.More { - return nil + if l.senderSettleMode != nil && localSendSettle != respSendSettle { + return fmt.Errorf("amqp: sender settlement mode %q requested, received %q from server", l.senderSettleMode, &respSendSettle) } + l.senderSettleMode = &respSendSettle - // last frame in message - err := l.msg.Unmarshal(&l.buf) - if err != nil { - return err - } - log.Debug(1, "deliveryID %d before push to receiver - deliveryCount : %d - linkCredit: %d, len(messages): %d, len(inflight): %d", l.msg.deliveryID, l.deliveryCount, l.linkCredit, len(l.Messages), l.receiver.inFlight.len()) - // send to receiver - if receiverSettleModeValue(l.ReceiverSettleMode) == ModeSecond { - l.addUnsettled(&l.msg) - } - select { - case l.Messages <- l.msg: - // message received - case <-l.Detached: - // link has been detached - return l.err - } - - log.Debug(1, "deliveryID %d after push to receiver - deliveryCount : %d - linkCredit: %d, len(messages): %d, len(inflight): %d", l.msg.deliveryID, l.deliveryCount, l.linkCredit, len(l.Messages), l.receiver.inFlight.len()) - - // reset progress - l.buf.Reset() - l.msg = Message{} - - // decrement link-credit after entire message received - l.deliveryCount++ - l.linkCredit-- - log.Debug(1, "deliveryID %d before exit - deliveryCount : %d - linkCredit: %d, len(messages): %d", l.msg.deliveryID, l.deliveryCount, l.linkCredit, len(l.Messages)) return nil } -// DrainCredit will cause a flow frame with 'drain' set to true when -// the next flow frame is sent in 'mux()'. -// Applicable only when manual credit management has been enabled. -func (l *link) DrainCredit(ctx context.Context) error { - if l.receiver == nil || l.receiver.manualCreditor == nil { - return errors.New("drain can only be used with receiver links using manual credit management") - } - - // cause mux() to check our flow conditions. - select { - case l.ReceiverReady <- struct{}{}: - default: - } - - return l.receiver.manualCreditor.Drain(ctx, l) -} - -// IssueCredit requests additional credits be issued for this link. -// Applicable only when manual credit management has been enabled. -func (l *link) IssueCredit(credit uint32) error { - if l.receiver == nil || l.receiver.manualCreditor == nil { - return errors.New("issueCredit can only be used with receiver links using manual credit management") - } - - if err := l.receiver.manualCreditor.IssueCredit(credit); err != nil { - return err - } - - // cause mux() to check our flow conditions. - select { - case l.ReceiverReady <- struct{}{}: - default: - } - - return nil -} - -func (l *link) detachOnRejectDisp() bool { - // only detach on rejection when no RSM was requested or in ModeFirst. - // if the receiver is in ModeSecond, it will send an explicit rejection disposition - // that we'll have to ack. so in that case, we don't treat it as a link error. - if l.detachOnDispositionError && (l.receiver == nil && (l.ReceiverSettleMode == nil || *l.ReceiverSettleMode == ModeFirst)) { - return true - } - return false -} - // muxHandleFrame processes fr based on type. func (l *link) muxHandleFrame(fr frames.FrameBody) error { - var ( - isSender = l.receiver == nil - ) - switch fr := fr.(type) { - // message frame - case *frames.PerformTransfer: - log.Debug(3, "RX (muxHandleFrame): %s", fr) - if isSender { - // Senders should never receive transfer frames, but handle it just in case. - return l.closeWithError(&Error{ - Condition: ErrorNotAllowed, - Description: "sender cannot process transfer frame", - }) - } - - return l.muxReceive(*fr) - - // flow control frame - case *frames.PerformFlow: - log.Debug(3, "RX (muxHandleFrame): %s", fr) - if isSender { - linkCredit := *fr.LinkCredit - l.deliveryCount - if fr.DeliveryCount != nil { - // DeliveryCount can be nil if the receiver hasn't processed - // the attach. That shouldn't be the case here, but it's - // what ActiveMQ does. - linkCredit += *fr.DeliveryCount - } - l.linkCredit = linkCredit - } - - if !fr.Echo { - // if the 'drain' flag has been set in the frame sent to the _receiver_ then - // we signal whomever is waiting (the service has seen and acknowledged our drain) - if fr.Drain && l.receiver != nil && l.receiver.manualCreditor != nil { - l.linkCredit = 0 // we have no active credits at this point. - l.receiver.manualCreditor.EndDrain() - } - return nil - } - - var ( - // copy because sent by pointer below; prevent race - linkCredit = l.linkCredit - deliveryCount = l.deliveryCount - ) - - // send flow - // TODO: missing Available and session info - resp := &frames.PerformFlow{ - Handle: &l.Handle, - DeliveryCount: &deliveryCount, - LinkCredit: &linkCredit, // max number of messages - } - log.Debug(1, "TX (muxHandleFrame): %s", resp) - _ = l.Session.txFrame(resp, nil) - - // remote side is closing links case *frames.PerformDetach: - log.Debug(1, "RX (muxHandleFrame): %s", fr) - // don't currently support link detach and reattach if !fr.Closed { - return fmt.Errorf("non-closing detach not supported: %+v", fr) + l.closeWithError(ErrCondNotImplemented, fmt.Sprintf("non-closing detach not supported: %+v", fr)) + return nil } - // set detach received and close link - l.detachReceived = true - - return &DetachError{fr.Error} - - case *frames.PerformDisposition: - log.Debug(3, "RX (muxHandleFrame): %s", fr) + // there are two possibilities: + // - this is the ack to a client-side Close() + // - the peer is closing the link so we must ack - // Unblock receivers waiting for message disposition - if l.receiver != nil { - // bubble disposition error up to the receiver - var dispositionError error - if state, ok := fr.State.(*encoding.StateRejected); ok { - // state.Error isn't required to be filled out. For instance if you dead letter a message - // you will get a rejected response that doesn't contain an error. - if state.Error != nil { - dispositionError = state.Error - } + if l.closeInProgress { + // if the client-side close was initiated due to an error (l.closeWithError) + // then l.doneErr will already be set. in this case, return that error instead + // of an empty LinkError which indicates a clean client-side close. + if l.doneErr != nil { + return l.doneErr } - // removal from the in-flight map will also remove the message from the unsettled map - l.receiver.inFlight.remove(fr.First, fr.Last, dispositionError) - } - - // If sending async and a message is rejected, cause a link error. - // - // This isn't ideal, but there isn't a clear better way to handle it. - if fr, ok := fr.State.(*encoding.StateRejected); ok && l.detachOnRejectDisp() { - return &DetachError{fr.Error} - } - - if fr.Settled { - return nil + return &LinkError{} } - resp := &frames.PerformDisposition{ - Role: encoding.RoleSender, - First: fr.First, - Last: fr.Last, - Settled: true, + dr := &frames.PerformDetach{ + Handle: l.handle, + Closed: true, } - log.Debug(1, "TX (muxHandleFrame): %s", resp) - _ = l.Session.txFrame(resp, nil) + _ = l.session.txFrame(dr, nil) + return &LinkError{RemoteErr: fr.Error} default: - // TODO: evaluate - log.Debug(1, "muxHandleFrame: unexpected frame: %s\n", fr) - } - - return nil -} - -// close closes and requests deletion of the link. -// -// No operations on link are valid after close. -// -// If ctx expires while waiting for servers response, ctx.Err() will be returned. -// The session will continue to wait for the response until the Session or Client -// is closed. -func (l *link) Close(ctx context.Context) error { - l.closeOnce.Do(func() { close(l.close) }) - select { - case <-l.Detached: - // mux exited - case <-ctx.Done(): - return ctx.Err() - } - if l.err == ErrLinkClosed { + debug.Log(1, "RX (link): unexpected frame: %s", fr) + l.closeWithError(ErrCondInternalError, fmt.Sprintf("link received unexpected frame %T", fr)) return nil } - return l.err } -// returns the error passed in -func (l *link) closeWithError(de *Error) error { +// Close closes the Sender and AMQP link. +func (l *link) closeLink(ctx context.Context) error { + var ctxErr error l.closeOnce.Do(func() { - l.detachErrorMu.Lock() - l.detachError = de - l.detachErrorMu.Unlock() close(l.close) - }) - return de -} - -func (l *link) muxDetach() { - defer func() { - // final cleanup and signaling - - // deallocate handle - Loop: - for { - select { - case <-l.RX: - // at this point we shouldn't be receiving any more frames for - // this link. however, if we do, we need to keep the session mux - // unblocked else we deadlock. so just read and discard them. - case l.Session.deallocateHandle <- l: - break Loop - case <-l.Session.done: - if l.err == nil { - l.err = l.Session.err - } - break Loop - } - } - - // unblock any in flight message dispositions - if l.receiver != nil { - l.receiver.inFlight.clear(l.err) - } - - // unblock any pending drain requests - if l.receiver != nil && l.receiver.manualCreditor != nil { - l.receiver.manualCreditor.EndDrain() + select { + case <-l.done: + // mux exited + case <-ctx.Done(): + close(l.forceClose) + ctxErr = ctx.Err() } + }) - // signal that the link mux has exited - close(l.Detached) - }() - - // "A peer closes a link by sending the detach frame with the - // handle for the specified link, and the closed flag set to - // true. The partner will destroy the corresponding link - // endpoint, and reply with its own detach frame with the - // closed flag set to true. - // - // Note that one peer MAY send a closing detach while its - // partner is sending a non-closing detach. In this case, - // the partner MUST signal that it has closed the link by - // reattaching and then sending a closing detach." - - l.detachErrorMu.Lock() - detachError := l.detachError - l.detachErrorMu.Unlock() - - fr := &frames.PerformDetach{ - Handle: l.Handle, - Closed: true, - Error: detachError, + if ctxErr != nil { + return ctxErr } -Loop: - for { - select { - case l.Session.tx <- fr: - // after sending the detach frame, break the read loop - break Loop - case fr := <-l.RX: - // read from link to avoid blocking session.mux - switch fr := fr.(type) { - case *frames.PerformDetach: - if fr.Closed { - l.detachReceived = true - } - case *frames.PerformTransfer: - _ = l.muxReceive(*fr) - } - case <-l.Session.done: - if l.err == nil { - l.err = l.Session.err - } - return - } + var linkErr *LinkError + if errors.As(l.doneErr, &linkErr) && linkErr.inner == nil { + // an empty LinkError means the link was closed by the caller + return nil } + return l.doneErr +} - // don't wait for remote to detach when already - // received or closing due to error - if l.detachReceived || detachError != nil { +// closeWithError initiates closing the link with the specified AMQP error. +// the mux must continue to run until the ack'ing detach is received. +// l.doneErr is populated with a &LinkError{} containing an inner error constructed from the specified values +// - cnd is the AMQP error condition +// - desc is the error description +func (l *link) closeWithError(cnd ErrCond, desc string) { + amqpErr := &Error{Condition: cnd, Description: desc} + if l.closeInProgress { + debug.Log(3, "TX (link) close error already pending, discarding %v", amqpErr) return } - for { - select { - // read from link until detach with Close == true is received - case fr := <-l.RX: - switch fr := fr.(type) { - case *frames.PerformDetach: - if fr.Closed { - return - } - case *frames.PerformTransfer: - _ = l.muxReceive(*fr) - } - - // connection has ended - case <-l.Session.done: - if l.err == nil { - l.err = l.Session.err - } - return - } + dr := &frames.PerformDetach{ + Handle: l.handle, + Closed: true, + Error: amqpErr, } + l.closeInProgress = true + l.doneErr = &LinkError{inner: fmt.Errorf("%s: %s", cnd, desc)} + _ = l.session.txFrame(dr, nil) } diff --git a/sdk/messaging/azeventhubs/internal/go-amqp/link_options.go b/sdk/messaging/azeventhubs/internal/go-amqp/link_options.go index 465dffd1c705..c4ba797007db 100644 --- a/sdk/messaging/azeventhubs/internal/go-amqp/link_options.go +++ b/sdk/messaging/azeventhubs/internal/go-amqp/link_options.go @@ -1,15 +1,14 @@ // Copyright (C) 2017 Kale Blankenship // Portions Copyright (c) Microsoft Corporation + package amqp import ( - "time" - "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/go-amqp/internal/encoding" ) type SenderOptions struct { - // Capabilities is the list of extension capabilities the sender supports/desires. + // Capabilities is the list of extension capabilities the sender supports. Capabilities []string // Durability indicates what state of the sender will be retained durably. @@ -35,11 +34,6 @@ type SenderOptions struct { // Default: 0. ExpiryTimeout uint32 - // IgnoreDispositionErrors controls automatic detach on disposition errors. - // - // Default: false. - IgnoreDispositionErrors bool - // Name sets the name of the link. // // Link names must be unique per-connection and direction. @@ -48,7 +42,7 @@ type SenderOptions struct { Name string // Properties sets an entry in the link properties map sent to the server. - Properties map[string]interface{} + Properties map[string]any // RequestedReceiverSettleMode sets the requested receiver settlement mode. // @@ -65,34 +59,46 @@ type SenderOptions struct { // SourceAddress specifies the source address for this sender. SourceAddress string -} -type ReceiverOptions struct { - // LinkBatching toggles batching of message disposition. - // - // When enabled, accepting a message does not send the disposition - // to the server until the batch is equal to link credit or the - // batch max age expires. + // TargetCapabilities is the list of extension capabilities the sender desires. + TargetCapabilities []string + + // TargetDurability indicates what state of the peer will be retained durably. // - // Default: false. - Batching bool + // Default: DurabilityNone. + TargetDurability Durability - // BatchMaxAge sets the maximum time between the start - // of a disposition batch and sending the batch to the server. + // TargetExpiryPolicy determines when the expiry timer of the peer starts counting + // down from the timeout value. If the link is subsequently re-attached before + // the timeout is reached, the count down is aborted. // - // Has no effect when Batching is false. + // Default: ExpirySessionEnd. + TargetExpiryPolicy ExpiryPolicy + + // TargetExpiryTimeout is the duration in seconds that the peer will be retained. // - // Default: 5 seconds. - BatchMaxAge time.Duration + // Default: 0. + TargetExpiryTimeout uint32 +} - // Capabilities is the list of extension capabilities the receiver supports/desires. +type ReceiverOptions struct { + // Capabilities is the list of extension capabilities the receiver supports. Capabilities []string // Credit specifies the maximum number of unacknowledged messages - // the sender can transmit. + // the sender can transmit. Once this limit is reached, no more messages + // will arrive until messages are acknowledged and settled. + // + // As messages are settled, any available credit will automatically be issued. + // + // Setting this to -1 requires manual management of link credit. + // Credits can be added with IssueCredit(), and links can also be + // drained with DrainCredit(). + // This should only be enabled when complete control of the link's + // flow control is required. // // Default: 1. - Credit uint32 + Credit int32 // Durability indicates what state of the receiver will be retained durably. // @@ -121,11 +127,6 @@ type ReceiverOptions struct { // If the peer cannot fulfill the filters the link will be detached. Filters []LinkFilter - // ManualCredits enables manual credit management for this link. - // Credits can be added with IssueCredit(), and links can also be - // drained with DrainCredit(). - ManualCredits bool - // MaxMessageSize sets the maximum message size that can // be received on the link. // @@ -142,7 +143,7 @@ type ReceiverOptions struct { Name string // Properties sets an entry in the link properties map sent to the server. - Properties map[string]interface{} + Properties map[string]any // RequestedSenderSettleMode sets the requested sender settlement mode. // @@ -159,6 +160,26 @@ type ReceiverOptions struct { // TargetAddress specifies the target address for this receiver. TargetAddress string + + // SourceCapabilities is the list of extension capabilities the receiver desires. + SourceCapabilities []string + + // SourceDurability indicates what state of the peer will be retained durably. + // + // Default: DurabilityNone. + SourceDurability Durability + + // SourceExpiryPolicy determines when the expiry timer of the peer starts counting + // down from the timeout value. If the link is subsequently re-attached before + // the timeout is reached, the count down is aborted. + // + // Default: ExpirySessionEnd. + SourceExpiryPolicy ExpiryPolicy + + // SourceExpiryTimeout is the duration in seconds that the peer will be retained. + // + // Default: 0. + SourceExpiryTimeout uint32 } // LinkFilter is an advanced API for setting non-standard source filters. @@ -191,10 +212,11 @@ type ReceiverOptions struct { // http://docs.oasis-open.org/amqp/core/v1.0/os/amqp-core-types-v1.0-os.html#section-descriptor-values type LinkFilter func(encoding.Filter) -// LinkFilterSource creates or updates the named filter for this LinkFilter. -func LinkFilterSource(name string, code uint64, value interface{}) LinkFilter { +// NewLinkFilter creates a new LinkFilter with the specified values. +// Any preexisting link filter with the same name will be updated with the new code and value. +func NewLinkFilter(name string, code uint64, value any) LinkFilter { return func(f encoding.Filter) { - var descriptor interface{} + var descriptor any if code != 0 { descriptor = code } else { @@ -207,9 +229,10 @@ func LinkFilterSource(name string, code uint64, value interface{}) LinkFilter { } } -// LinkFilterSelector creates or updates the selector filter (apache.org:selector-filter:string) for this LinkFilter. -func LinkFilterSelector(filter string) LinkFilter { - return LinkFilterSource(selectorFilter, selectorFilterCode, filter) +// NewSelectorFilter creates a new selector filter (apache.org:selector-filter:string) with the specified filter value. +// Any preexisting selector filter will be updated with the new filter value. +func NewSelectorFilter(filter string) LinkFilter { + return NewLinkFilter(selectorFilter, selectorFilterCode, filter) } const ( diff --git a/sdk/messaging/azeventhubs/internal/go-amqp/message.go b/sdk/messaging/azeventhubs/internal/go-amqp/message.go index eed1b45cbbb5..a19892fbdfc5 100644 --- a/sdk/messaging/azeventhubs/internal/go-amqp/message.go +++ b/sdk/messaging/azeventhubs/internal/go-amqp/message.go @@ -1,5 +1,6 @@ // Copyright (C) 2017 Kale Blankenship // Portions Copyright (c) Microsoft Corporation + package amqp import ( @@ -34,7 +35,7 @@ type Message struct { // The delivery-annotations section is used for delivery-specific non-standard // properties at the head of the message. Delivery annotations convey information // from the sending peer to the receiving peer. - DeliveryAnnotations encoding.Annotations + DeliveryAnnotations Annotations // If the recipient does not understand the annotation it cannot be acted upon // and its effects (such as any implied propagation) cannot be acted upon. // Annotations might be specific to one implementation, or common to multiple @@ -50,7 +51,7 @@ type Message struct { // The message-annotations section is used for properties of the message which // are aimed at the infrastructure. - Annotations encoding.Annotations + Annotations Annotations // The message-annotations section is used for properties of the message which // are aimed at the infrastructure and SHOULD be propagated across every // delivery step. Message annotations convey information about the message. @@ -78,7 +79,7 @@ type Message struct { // The application-properties section is a part of the bare message used for // structured application data. Intermediaries can use the data within this // structure for the purposes of filtering or routing. - ApplicationProperties map[string]interface{} + ApplicationProperties map[string]any // The keys of this map are restricted to be of type string (which excludes // the possibility of a null key) and the values are restricted to be of // simple types only, that is, excluding map, list, and array types. @@ -89,26 +90,21 @@ type Message struct { // Value payload. // An amqp-value section contains a single AMQP value. - Value interface{} + Value any // Sequence will contain AMQP sequence sections from the body of the message. // An amqp-sequence section contains an AMQP sequence. - Sequence [][]interface{} + Sequence [][]any // The footer section is used for details about the message or delivery which // can only be calculated or evaluated once the whole bare message has been // constructed or seen (for example message hashes, HMACs, signatures and // encryption details). - Footer encoding.Annotations - - // Mark the message as settled when LinkSenderSettle is ModeMixed. - // - // This field is ignored when LinkSenderSettle is not ModeMixed. - SendSettled bool + Footer Annotations - link *link // the receiving link - deliveryID uint32 // used when sending disposition - settled bool // whether transfer was settled by sender + rcvr *Receiver // the receiving link + deliveryID uint32 // used when sending disposition + settled bool // whether transfer was settled by sender } // NewMessage returns a *Message with data as the payload. @@ -133,8 +129,8 @@ func (m *Message) GetData() []byte { // LinkName returns the receiving link name or the empty string. func (m *Message) LinkName() string { - if m.link != nil { - return m.link.Key.name + if m.rcvr != nil { + return m.rcvr.l.key.name } return "" } @@ -146,10 +142,6 @@ func (m *Message) MarshalBinary() ([]byte, error) { return buf.Detach(), err } -func (m *Message) shouldSendDisposition() bool { - return !m.settled -} - func (m *Message) Marshal(wr *buffer.Buffer) error { if m.Header != nil { err := m.Header.Marshal(wr) @@ -244,7 +236,7 @@ func (m *Message) Unmarshal(r *buffer.Buffer) error { } var ( - section interface{} + section any // section header is read from r before // unmarshaling section is set to true discardHeader = true @@ -283,7 +275,7 @@ func (m *Message) Unmarshal(r *buffer.Buffer) error { case encoding.TypeCodeAMQPSequence: r.Skip(int(headerLength)) - var data []interface{} + var data []any err = encoding.Unmarshal(r, &data) if err != nil { return err @@ -367,7 +359,7 @@ type ( // - amqp.UUID: // - []byte: // - string: - MessageID = interface{} + MessageID = any // AMQPSymbol corresponds to the 'symbol' type in the AMQP spec. // @@ -526,4 +518,5 @@ func (p *MessageProperties) Unmarshal(r *buffer.Buffer) error { // String keys are encoded as AMQP Symbols. type Annotations = encoding.Annotations +// UUID is a 128 bit identifier as defined in RFC 4122. type UUID = encoding.UUID diff --git a/sdk/messaging/azeventhubs/internal/go-amqp/receiver.go b/sdk/messaging/azeventhubs/internal/go-amqp/receiver.go index bf9515579b3d..b6256e3d9a9a 100644 --- a/sdk/messaging/azeventhubs/internal/go-amqp/receiver.go +++ b/sdk/messaging/azeventhubs/internal/go-amqp/receiver.go @@ -1,43 +1,68 @@ // Copyright (C) 2017 Kale Blankenship // Portions Copyright (c) Microsoft Corporation + package amqp import ( + "bytes" "context" + "errors" + "fmt" "sync" - "time" + "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/go-amqp/internal/buffer" + "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/go-amqp/internal/debug" "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/go-amqp/internal/encoding" "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/go-amqp/internal/frames" - "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/go-amqp/internal/log" + "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/go-amqp/internal/queue" ) -type messageDisposition struct { - id uint32 - state encoding.DeliveryState -} +// Default link options +const ( + defaultLinkCredit = 1 +) // Receiver receives messages on a single AMQP link. type Receiver struct { - link *link // underlying link - batching bool // enable batching of message dispositions - batchMaxAge time.Duration // maximum time between the start n batch and sending the batch to the server - dispositions chan messageDisposition // message dispositions are sent on this channel when batching is enabled - maxCredit uint32 // maximum allowed inflight messages - inFlight inFlight // used to track message disposition when rcv-settle-mode == second - manualCreditor *manualCreditor // allows for credits to be managed manually (via calls to IssueCredit/DrainCredit) + l link + + // message receiving + receiverReady chan struct{} // receiver sends on this when mux is paused to indicate it can handle more messages + messagesQ *queue.Holder[Message] // used to send completed messages to receiver + + unsettledMessages map[string]struct{} // used to keep track of messages being handled downstream + unsettledMessagesLock sync.RWMutex // lock to protect concurrent access to unsettledMessages + msgBuf buffer.Buffer // buffered bytes for current message + more bool // if true, buf contains a partial message + msg Message // current message being decoded + + settlementCount uint32 // the count of settled messages + settlementCountMu sync.Mutex // must be held when accessing settlementCount + + autoSendFlow bool // automatically send flow frames as credit becomes available + inFlight inFlight // used to track message disposition when rcv-settle-mode == second + creditor creditor // manages credits via calls to IssueCredit/DrainCredit } -// IssueCredit adds credits to be requested in the next flow -// request. +// IssueCredit adds credits to be requested in the next flow request. +// Attempting to issue more credit than the receiver's max credit as +// specified in ReceiverOptions.MaxCredit will result in an error. func (r *Receiver) IssueCredit(credit uint32) error { - return r.link.IssueCredit(credit) -} + if r.autoSendFlow { + return errors.New("issueCredit can only be used with receiver links using manual credit management") + } + + if err := r.creditor.IssueCredit(credit); err != nil { + return err + } -// DrainCredit sets the drain flag on the next flow frame and -// waits for the drain to be acknowledged. -func (r *Receiver) DrainCredit(ctx context.Context) error { - return r.link.DrainCredit(ctx) + // cause mux() to check our flow conditions. + select { + case r.receiverReady <- struct{}{}: + default: + } + + return nil } // Prefetched returns the next message that is stored in the Receiver's @@ -45,85 +70,116 @@ func (r *Receiver) DrainCredit(ctx context.Context) error { // and returns immediately if the prefetch cache is empty. To receive from the // prefetch and wait for messages from the remote Sender use `Receive`. // -// When using ModeSecond, you *must* take an action on the message by calling +// Once a message is received, and if the sender is configured in any mode other +// than SenderSettleModeSettled, you *must* take an action on the message by calling // one of the following: AcceptMessage, RejectMessage, ReleaseMessage, ModifyMessage. -// When using ModeFirst, the message is spontaneously Accepted at reception. func (r *Receiver) Prefetched() *Message { select { - case r.link.ReceiverReady <- struct{}{}: + case r.receiverReady <- struct{}{}: default: } // non-blocking receive to ensure buffered messages are // delivered regardless of whether the link has been closed. - select { - case msg := <-r.link.Messages: - log.Debug(3, "Receive() non blocking %d", msg.deliveryID) - msg.link = r.link - return &msg - default: - // done draining messages + q := r.messagesQ.Acquire() + msg := q.Dequeue() + r.messagesQ.Release(q) + + if msg == nil { return nil } + + debug.Log(3, "RX (Receiver): prefetched delivery ID %d", msg.deliveryID) + msg.rcvr = r + + if msg.settled { + r.onSettlement(1) + } + + return msg +} + +// ReceiveOptions contains any optional values for the Receiver.Receive method. +type ReceiveOptions struct { + // for future expansion } // Receive returns the next message from the sender. -// // Blocks until a message is received, ctx completes, or an error occurs. -// When using ModeSecond, you *must* take an action on the message by calling +// +// Once a message is received, and if the sender is configured in any mode other +// than SenderSettleModeSettled, you *must* take an action on the message by calling // one of the following: AcceptMessage, RejectMessage, ReleaseMessage, ModifyMessage. -// When using ModeFirst, the message is spontaneously Accepted at reception. -func (r *Receiver) Receive(ctx context.Context) (*Message, error) { +func (r *Receiver) Receive(ctx context.Context, opts *ReceiveOptions) (*Message, error) { if msg := r.Prefetched(); msg != nil { return msg, nil } // wait for the next message select { - case msg := <-r.link.Messages: - log.Debug(3, "Receive() blocking %d", msg.deliveryID) - msg.link = r.link - return &msg, nil - case <-r.link.Detached: - return nil, r.link.err + case q := <-r.messagesQ.Wait(): + msg := q.Dequeue() + debug.Assert(msg != nil) + debug.Log(3, "RX (Receiver): received delivery ID %d", msg.deliveryID) + msg.rcvr = r + r.messagesQ.Release(q) + if msg.settled { + r.onSettlement(1) + } + return msg, nil + case <-r.l.done: + // if the link receives messages and is then closed between the above call to r.Prefetched() + // and this select statement, the order of selecting r.messages and r.l.done is undefined. + // however, once r.l.done is closed the link cannot receive any more messages. so be sure to + // drain any that might have trickled in within this window. + if msg := r.Prefetched(); msg != nil { + return msg, nil + } + return nil, r.l.doneErr case <-ctx.Done(): return nil, ctx.Err() } } -// Accept notifies the server that the message has been -// accepted and does not require redelivery. +// Accept notifies the server that the message has been accepted and does not require redelivery. +// - ctx controls waiting for the peer to acknowledge the disposition +// - msg is the message to accept +// +// If the context's deadline expires or is cancelled before the operation +// completes, the message's disposition is in an unknown state. func (r *Receiver) AcceptMessage(ctx context.Context, msg *Message) error { - if !msg.shouldSendDisposition() { - return nil - } return r.messageDisposition(ctx, msg, &encoding.StateAccepted{}) } // Reject notifies the server that the message is invalid. +// - ctx controls waiting for the peer to acknowledge the disposition +// - msg is the message to reject +// - e is an optional rejection error // -// Rejection error is optional. +// If the context's deadline expires or is cancelled before the operation +// completes, the message's disposition is in an unknown state. func (r *Receiver) RejectMessage(ctx context.Context, msg *Message, e *Error) error { - if !msg.shouldSendDisposition() { - return nil - } return r.messageDisposition(ctx, msg, &encoding.StateRejected{Error: e}) } -// Release releases the message back to the server. The message -// may be redelivered to this or another consumer. +// Release releases the message back to the server. The message may be redelivered to this or another consumer. +// - ctx controls waiting for the peer to acknowledge the disposition +// - msg is the message to release +// +// If the context's deadline expires or is cancelled before the operation +// completes, the message's disposition is in an unknown state. func (r *Receiver) ReleaseMessage(ctx context.Context, msg *Message) error { - if !msg.shouldSendDisposition() { - return nil - } return r.messageDisposition(ctx, msg, &encoding.StateReleased{}) } // Modify notifies the server that the message was not acted upon and should be modifed. +// - ctx controls waiting for the peer to acknowledge the disposition +// - msg is the message to modify +// - options contains the optional settings to modify +// +// If the context's deadline expires or is cancelled before the operation +// completes, the message's disposition is in an unknown state. func (r *Receiver) ModifyMessage(ctx context.Context, msg *Message, options *ModifyMessageOptions) error { - if !msg.shouldSendDisposition() { - return nil - } if options == nil { options = &ModifyMessageOptions{} } @@ -153,23 +209,23 @@ type ModifyMessageOptions struct { // Address returns the link's address. func (r *Receiver) Address() string { - if r.link.Source == nil { + if r.l.source == nil { return "" } - return r.link.Source.Address + return r.l.source.Address } // LinkName returns associated link name or an empty string if link is not defined. func (r *Receiver) LinkName() string { - return r.link.Key.name + return r.l.key.name } // LinkSourceFilterValue retrieves the specified link source filter value or nil if it doesn't exist. -func (r *Receiver) LinkSourceFilterValue(name string) interface{} { - if r.link.Source == nil { +func (r *Receiver) LinkSourceFilterValue(name string) any { + if r.l.source == nil { return nil } - filter, ok := r.link.Source.Filter[encoding.Symbol(name)] + filter, ok := r.l.source.Filter[encoding.Symbol(name)] if !ok { return nil } @@ -177,96 +233,13 @@ func (r *Receiver) LinkSourceFilterValue(name string) interface{} { } // Close closes the Receiver and AMQP link. +// - ctx controls waiting for the peer to acknowledge the close // -// If ctx expires while waiting for servers response, ctx.Err() will be returned. -// The session will continue to wait for the response until the Session or Client -// is closed. +// If the context's deadline expires or is cancelled before the operation +// completes, the application can be left in an unknown state, potentially +// resulting in connection errors. func (r *Receiver) Close(ctx context.Context) error { - return r.link.Close(ctx) -} - -func (r *Receiver) dispositionBatcher() { - // batch operations: - // Keep track of the first and last delivery ID, incrementing as - // Accept() is called. After last-first == batchSize, send disposition. - // If Reject()/Release() is called, send one disposition for previously - // accepted, and one for the rejected/released message. If messages are - // accepted out of order, send any existing batch and the current message. - var ( - batchSize = r.maxCredit - batchStarted bool - first uint32 - last uint32 - ) - - // create an unstarted timer - batchTimer := time.NewTimer(1 * time.Minute) - batchTimer.Stop() - defer batchTimer.Stop() - - for { - select { - case msgDis := <-r.dispositions: - - // not accepted or batch out of order - _, isAccept := msgDis.state.(*encoding.StateAccepted) - if !isAccept || (batchStarted && last+1 != msgDis.id) { - // send the current batch, if any - if batchStarted { - lastCopy := last - err := r.sendDisposition(first, &lastCopy, &encoding.StateAccepted{}) - if err != nil { - r.inFlight.remove(first, &lastCopy, err) - } - batchStarted = false - } - - // send the current message - err := r.sendDisposition(msgDis.id, nil, msgDis.state) - if err != nil { - r.inFlight.remove(msgDis.id, nil, err) - } - continue - } - - if batchStarted { - // increment last - last++ - } else { - // start new batch - batchStarted = true - first = msgDis.id - last = msgDis.id - batchTimer.Reset(r.batchMaxAge) - } - - // send batch if current size == batchSize - if last-first+1 >= batchSize { - lastCopy := last - err := r.sendDisposition(first, &lastCopy, &encoding.StateAccepted{}) - if err != nil { - r.inFlight.remove(first, &lastCopy, err) - } - batchStarted = false - if !batchTimer.Stop() { - <-batchTimer.C // batch timer must be drained if stop returns false - } - } - - // maxBatchAge elapsed, send batch - case <-batchTimer.C: - lastCopy := last - err := r.sendDisposition(first, &lastCopy, &encoding.StateAccepted{}) - if err != nil { - r.inFlight.remove(first, &lastCopy, err) - } - batchStarted = false - batchTimer.Stop() - - case <-r.link.Detached: - return - } - } + return r.l.closeLink(ctx) } // sendDisposition sends a disposition frame to the peer @@ -275,50 +248,536 @@ func (r *Receiver) sendDisposition(first uint32, last *uint32, state encoding.De Role: encoding.RoleReceiver, First: first, Last: last, - Settled: r.link.ReceiverSettleMode == nil || *r.link.ReceiverSettleMode == ModeFirst, + Settled: r.l.receiverSettleMode == nil || *r.l.receiverSettleMode == ReceiverSettleModeFirst, State: state, } select { - case <-r.link.Detached: - return r.link.err + case <-r.l.done: + return r.l.doneErr default: - log.Debug(1, "TX (sendDisposition): %s", fr) - return r.link.Session.txFrame(fr, nil) + // TODO: this is racy + return r.l.session.txFrame(fr, nil) } } func (r *Receiver) messageDisposition(ctx context.Context, msg *Message, state encoding.DeliveryState) error { + if msg.settled { + return nil + } + var wait chan error - if r.link.ReceiverSettleMode != nil && *r.link.ReceiverSettleMode == ModeSecond { - log.Debug(3, "RX (messageDisposition): add %d to inflight", msg.deliveryID) + if r.l.receiverSettleMode != nil && *r.l.receiverSettleMode == ReceiverSettleModeSecond { + debug.Log(3, "TX (Receiver): delivery ID %d is in flight", msg.deliveryID) wait = r.inFlight.add(msg.deliveryID) } - if r.batching { - r.dispositions <- messageDisposition{id: msg.deliveryID, state: state} - } else { - err := r.sendDisposition(msg.deliveryID, nil, state) - if err != nil { - return err - } + if err := r.sendDisposition(msg.deliveryID, nil, state); err != nil { + return err } if wait == nil { + // mode first, there will be no settlement ack + r.deleteUnsettled(msg) + r.onSettlement(1) return nil } select { case err := <-wait: + debug.Log(3, "RX (Receiver): delivery ID %d has been settled", msg.deliveryID) // we've received confirmation of disposition - r.link.DeleteUnsettled(msg) + r.deleteUnsettled(msg) + r.onSettlement(1) msg.settled = true return err case <-ctx.Done(): + // didn't receive the ack in the time allotted, leave message as unsettled return ctx.Err() } } +// onSettlement is to be called after message settlement. +// - count is the number of messages that were settled +func (r *Receiver) onSettlement(count uint32) { + if !r.autoSendFlow { + return + } + + r.settlementCountMu.Lock() + r.settlementCount += count + r.settlementCountMu.Unlock() + + select { + case r.receiverReady <- struct{}{}: + // woke up + default: + // wake pending + } +} + +func (r *Receiver) addUnsettled(msg *Message) { + r.unsettledMessagesLock.Lock() + r.unsettledMessages[string(msg.DeliveryTag)] = struct{}{} + r.unsettledMessagesLock.Unlock() +} + +func (r *Receiver) deleteUnsettled(msg *Message) { + r.unsettledMessagesLock.Lock() + delete(r.unsettledMessages, string(msg.DeliveryTag)) + r.unsettledMessagesLock.Unlock() +} + +func (r *Receiver) countUnsettled() int { + r.unsettledMessagesLock.RLock() + count := len(r.unsettledMessages) + r.unsettledMessagesLock.RUnlock() + return count +} + +func newReceiver(source string, session *Session, opts *ReceiverOptions) (*Receiver, error) { + l := newLink(session, encoding.RoleReceiver) + l.source = &frames.Source{Address: source} + l.target = new(frames.Target) + l.linkCredit = defaultLinkCredit + r := &Receiver{ + l: l, + autoSendFlow: true, + receiverReady: make(chan struct{}, 1), + } + + r.messagesQ = queue.NewHolder(queue.New[Message](int(session.incomingWindow))) + + if opts == nil { + return r, nil + } + + for _, v := range opts.Capabilities { + r.l.target.Capabilities = append(r.l.target.Capabilities, encoding.Symbol(v)) + } + if opts.Credit > 0 { + r.l.linkCredit = uint32(opts.Credit) + } else if opts.Credit < 0 { + r.l.linkCredit = 0 + r.autoSendFlow = false + } + if opts.Durability > DurabilityUnsettledState { + return nil, fmt.Errorf("invalid Durability %d", opts.Durability) + } + r.l.target.Durable = opts.Durability + if opts.DynamicAddress { + r.l.source.Address = "" + r.l.dynamicAddr = opts.DynamicAddress + } + if opts.ExpiryPolicy != "" { + if err := encoding.ValidateExpiryPolicy(opts.ExpiryPolicy); err != nil { + return nil, err + } + r.l.target.ExpiryPolicy = opts.ExpiryPolicy + } + r.l.target.Timeout = opts.ExpiryTimeout + if opts.Filters != nil { + r.l.source.Filter = make(encoding.Filter) + for _, f := range opts.Filters { + f(r.l.source.Filter) + } + } + if opts.MaxMessageSize > 0 { + r.l.maxMessageSize = opts.MaxMessageSize + } + if opts.Name != "" { + r.l.key.name = opts.Name + } + if opts.Properties != nil { + r.l.properties = make(map[encoding.Symbol]any) + for k, v := range opts.Properties { + if k == "" { + return nil, errors.New("link property key must not be empty") + } + r.l.properties[encoding.Symbol(k)] = v + } + } + if opts.RequestedSenderSettleMode != nil { + if rsm := *opts.RequestedSenderSettleMode; rsm > SenderSettleModeMixed { + return nil, fmt.Errorf("invalid RequestedSenderSettleMode %d", rsm) + } + r.l.senderSettleMode = opts.RequestedSenderSettleMode + } + if opts.SettlementMode != nil { + if rsm := *opts.SettlementMode; rsm > ReceiverSettleModeSecond { + return nil, fmt.Errorf("invalid SettlementMode %d", rsm) + } + r.l.receiverSettleMode = opts.SettlementMode + } + r.l.target.Address = opts.TargetAddress + for _, v := range opts.SourceCapabilities { + r.l.source.Capabilities = append(r.l.source.Capabilities, encoding.Symbol(v)) + } + if opts.SourceDurability != DurabilityNone { + r.l.source.Durable = opts.SourceDurability + } + if opts.SourceExpiryPolicy != ExpiryPolicySessionEnd { + r.l.source.ExpiryPolicy = opts.SourceExpiryPolicy + } + if opts.SourceExpiryTimeout != 0 { + r.l.source.Timeout = opts.SourceExpiryTimeout + } + return r, nil +} + +// attach sends the Attach performative to establish the link with its parent session. +// this is automatically called by the new*Link constructors. +func (r *Receiver) attach(ctx context.Context) error { + if err := r.l.attach(ctx, func(pa *frames.PerformAttach) { + pa.Role = encoding.RoleReceiver + if pa.Source == nil { + pa.Source = new(frames.Source) + } + pa.Source.Dynamic = r.l.dynamicAddr + }, func(pa *frames.PerformAttach) { + if r.l.source == nil { + r.l.source = new(frames.Source) + } + // if dynamic address requested, copy assigned name to address + if r.l.dynamicAddr && pa.Source != nil { + r.l.source.Address = pa.Source.Address + } + // deliveryCount is a sequence number, must initialize to sender's initial sequence number + r.l.deliveryCount = pa.InitialDeliveryCount + r.unsettledMessages = map[string]struct{}{} + // copy the received filter values + if pa.Source != nil { + r.l.source.Filter = pa.Source.Filter + } + }); err != nil { + return err + } + + go r.mux() + + return nil +} + +func (r *Receiver) mux() { + defer func() { + // unblock any in flight message dispositions + r.inFlight.clear(r.l.doneErr) + + if !r.autoSendFlow { + // unblock any pending drain requests + r.creditor.EndDrain() + } + + r.l.session.deallocateHandle(&r.l) + close(r.l.done) + }() + + if r.autoSendFlow { + r.l.doneErr = r.muxFlow(r.l.linkCredit, false) + } + + for { + msgLen := r.messagesQ.Len() + + r.settlementCountMu.Lock() + // counter that accumulates the settled delivery count. + // once the threshold has been reached, the counter is + // reset and a flow frame is sent. + previousSettlementCount := r.settlementCount + if previousSettlementCount >= r.l.linkCredit { + r.settlementCount = 0 + } + r.settlementCountMu.Unlock() + + // once we have pending credit equal to or greater than our available credit, reclaim it. + // we do this instead of settlementCount > 0 to prevent flow frames from being too chatty. + // NOTE: we compare the settlementCount against the current link credit instead of some + // fixed threshold to ensure credit is reclaimed in cases where the number of unsettled + // messages remains high for whatever reason. + if r.autoSendFlow && previousSettlementCount > 0 && previousSettlementCount >= r.l.linkCredit { + debug.Log(1, "RX (Receiver) (auto): source: %q, inflight: %d, linkCredit: %d, deliveryCount: %d, messages: %d, unsettled: %d, settlementCount: %d, settleMode: %s", r.l.source.Address, r.inFlight.len(), r.l.linkCredit, r.l.deliveryCount, msgLen, r.countUnsettled(), previousSettlementCount, r.l.receiverSettleMode.String()) + r.l.doneErr = r.creditor.IssueCredit(previousSettlementCount) + } else if r.l.linkCredit == 0 { + debug.Log(1, "RX (Receiver) (pause): source: %q, inflight: %d, linkCredit: %d, deliveryCount: %d, messages: %d, unsettled: %d, settlementCount: %d, settleMode: %s", r.l.source.Address, r.inFlight.len(), r.l.linkCredit, r.l.deliveryCount, msgLen, r.countUnsettled(), previousSettlementCount, r.l.receiverSettleMode.String()) + } + + if r.l.doneErr != nil { + return + } + + drain, credits := r.creditor.FlowBits(r.l.linkCredit) + if drain || credits > 0 { + debug.Log(1, "RX (Receiver) (flow): source: %q, inflight: %d, curLinkCredit: %d, newLinkCredit: %d, drain: %v, deliveryCount: %d, messages: %d, unsettled: %d, settlementCount: %d, settleMode: %s", + r.l.source.Address, r.inFlight.len(), r.l.linkCredit, credits, drain, r.l.deliveryCount, msgLen, r.countUnsettled(), previousSettlementCount, r.l.receiverSettleMode.String()) + + // send a flow frame. + r.l.doneErr = r.muxFlow(credits, drain) + } + + if r.l.doneErr != nil { + return + } + + closed := r.l.close + if r.l.closeInProgress { + // swap out channel so it no longer triggers + closed = nil + } + + select { + case <-r.l.forceClose: + // the call to r.Close() timed out waiting for the ack + r.l.doneErr = &LinkError{inner: errors.New("the receiver was forcibly closed")} + return + + case q := <-r.l.rxQ.Wait(): + // populated queue + fr := *q.Dequeue() + r.l.rxQ.Release(q) + + // if muxHandleFrame returns an error it means the mux must terminate. + // note that in the case of a client-side close due to an error, nil + // is returned in order to keep the mux running to ack the detach frame. + if err := r.muxHandleFrame(fr); err != nil { + r.l.doneErr = err + return + } + + case <-r.receiverReady: + continue + + case <-closed: + if r.l.closeInProgress { + // a client-side close due to protocol error is in progress + continue + } + // receiver is being closed by the client + r.l.closeInProgress = true + fr := &frames.PerformDetach{ + Handle: r.l.handle, + Closed: true, + } + _ = r.l.session.txFrame(fr, nil) + + case <-r.l.session.done: + r.l.doneErr = r.l.session.doneErr + return + } + } +} + +// muxFlow sends tr to the session mux. +// l.linkCredit will also be updated to `linkCredit` +func (r *Receiver) muxFlow(linkCredit uint32, drain bool) error { + var ( + deliveryCount = r.l.deliveryCount + ) + + fr := &frames.PerformFlow{ + Handle: &r.l.handle, + DeliveryCount: &deliveryCount, + LinkCredit: &linkCredit, // max number of messages, + Drain: drain, + } + + // Update credit. This must happen before entering loop below + // because incoming messages handled while waiting to transmit + // flow increment deliveryCount. This causes the credit to become + // out of sync with the server. + + if !drain { + // if we're draining we don't want to touch our internal credit - we're not changing it so any issued credits + // are still valid until drain completes, at which point they will be naturally zeroed. + r.l.linkCredit = linkCredit + } + + select { + case r.l.session.tx <- fr: + debug.Log(2, "TX (Receiver): %s", fr) + return nil + case <-r.l.close: + return nil + case <-r.l.session.done: + return r.l.session.doneErr + } +} + +// muxHandleFrame processes fr based on type. +func (r *Receiver) muxHandleFrame(fr frames.FrameBody) error { + debug.Log(2, "RX (Receiver): %s", fr) + switch fr := fr.(type) { + // message frame + case *frames.PerformTransfer: + r.muxReceive(*fr) + + // flow control frame + case *frames.PerformFlow: + if !fr.Echo { + // if the 'drain' flag has been set in the frame sent to the _receiver_ then + // we signal whomever is waiting (the service has seen and acknowledged our drain) + if fr.Drain && !r.autoSendFlow { + r.l.linkCredit = 0 // we have no active credits at this point. + r.creditor.EndDrain() + } + return nil + } + + var ( + // copy because sent by pointer below; prevent race + linkCredit = r.l.linkCredit + deliveryCount = r.l.deliveryCount + ) + + // send flow + resp := &frames.PerformFlow{ + Handle: &r.l.handle, + DeliveryCount: &deliveryCount, + LinkCredit: &linkCredit, // max number of messages + } + + select { + case r.l.session.tx <- resp: + debug.Log(2, "TX (Sender): %s", resp) + case <-r.l.close: + return nil + case <-r.l.session.done: + return r.l.session.doneErr + } + + case *frames.PerformDisposition: + // Unblock receivers waiting for message disposition + // bubble disposition error up to the receiver + var dispositionError error + if state, ok := fr.State.(*encoding.StateRejected); ok { + // state.Error isn't required to be filled out. For instance if you dead letter a message + // you will get a rejected response that doesn't contain an error. + if state.Error != nil { + dispositionError = state.Error + } + } + // removal from the in-flight map will also remove the message from the unsettled map + r.inFlight.remove(fr.First, fr.Last, dispositionError) + + default: + return r.l.muxHandleFrame(fr) + } + + return nil +} + +func (r *Receiver) muxReceive(fr frames.PerformTransfer) { + if !r.more { + // this is the first transfer of a message, + // record the delivery ID, message format, + // and delivery Tag + if fr.DeliveryID != nil { + r.msg.deliveryID = *fr.DeliveryID + } + if fr.MessageFormat != nil { + r.msg.Format = *fr.MessageFormat + } + r.msg.DeliveryTag = fr.DeliveryTag + + // these fields are required on first transfer of a message + if fr.DeliveryID == nil { + r.l.closeWithError(ErrCondNotAllowed, "received message without a delivery-id") + return + } + if fr.MessageFormat == nil { + r.l.closeWithError(ErrCondNotAllowed, "received message without a message-format") + return + } + if fr.DeliveryTag == nil { + r.l.closeWithError(ErrCondNotAllowed, "received message without a delivery-tag") + return + } + } else { + // this is a continuation of a multipart message + // some fields may be omitted on continuation transfers, + // but if they are included they must be consistent + // with the first. + + if fr.DeliveryID != nil && *fr.DeliveryID != r.msg.deliveryID { + msg := fmt.Sprintf( + "received continuation transfer with inconsistent delivery-id: %d != %d", + *fr.DeliveryID, r.msg.deliveryID, + ) + r.l.closeWithError(ErrCondNotAllowed, msg) + return + } + if fr.MessageFormat != nil && *fr.MessageFormat != r.msg.Format { + msg := fmt.Sprintf( + "received continuation transfer with inconsistent message-format: %d != %d", + *fr.MessageFormat, r.msg.Format, + ) + r.l.closeWithError(ErrCondNotAllowed, msg) + return + } + if fr.DeliveryTag != nil && !bytes.Equal(fr.DeliveryTag, r.msg.DeliveryTag) { + msg := fmt.Sprintf( + "received continuation transfer with inconsistent delivery-tag: %q != %q", + fr.DeliveryTag, r.msg.DeliveryTag, + ) + r.l.closeWithError(ErrCondNotAllowed, msg) + return + } + } + + // discard message if it's been aborted + if fr.Aborted { + r.msgBuf.Reset() + r.msg = Message{} + r.more = false + return + } + + // ensure maxMessageSize will not be exceeded + if r.l.maxMessageSize != 0 && uint64(r.msgBuf.Len())+uint64(len(fr.Payload)) > r.l.maxMessageSize { + r.l.closeWithError(ErrCondMessageSizeExceeded, fmt.Sprintf("received message larger than max size of %d", r.l.maxMessageSize)) + return + } + + // add the payload the the buffer + r.msgBuf.Append(fr.Payload) + + // mark as settled if at least one frame is settled + r.msg.settled = r.msg.settled || fr.Settled + + // save in-progress status + r.more = fr.More + + if fr.More { + return + } + + // last frame in message + err := r.msg.Unmarshal(&r.msgBuf) + if err != nil { + r.l.closeWithError(ErrCondInternalError, err.Error()) + return + } + + // send to receiver + if !r.msg.settled { + r.addUnsettled(&r.msg) + debug.Log(3, "RX (Receiver): add unsettled delivery ID %d", r.msg.deliveryID) + } + + q := r.messagesQ.Acquire() + q.Enqueue(r.msg) + msgLen := q.Len() + r.messagesQ.Release(q) + + // reset progress + r.msgBuf.Reset() + r.msg = Message{} + + // decrement link-credit after entire message received + r.l.deliveryCount++ + r.l.linkCredit-- + debug.Log(3, "RX (Receiver) link %s - deliveryCount: %d, linkCredit: %d, len(messages): %d", r.l.key.name, r.l.deliveryCount, r.l.linkCredit, msgLen) +} + // inFlight tracks in-flight message dispositions allowing receivers // to block waiting for the server to respond when an appropriate // settlement mode is configured. diff --git a/sdk/messaging/azeventhubs/internal/go-amqp/sasl.go b/sdk/messaging/azeventhubs/internal/go-amqp/sasl.go index 4896af839e14..5dbaaae0a69e 100644 --- a/sdk/messaging/azeventhubs/internal/go-amqp/sasl.go +++ b/sdk/messaging/azeventhubs/internal/go-amqp/sasl.go @@ -1,13 +1,14 @@ // Copyright (C) 2017 Kale Blankenship // Portions Copyright (c) Microsoft Corporation + package amqp import ( "fmt" + "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/go-amqp/internal/debug" "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/go-amqp/internal/encoding" "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/go-amqp/internal/frames" - "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/go-amqp/internal/log" ) // SASL Mechanisms @@ -18,13 +19,8 @@ const ( saslMechanismXOAUTH2 encoding.Symbol = "XOAUTH2" ) -const ( - frameTypeAMQP = 0x0 - frameTypeSASL = 0x1 -) - // SASLType represents a SASL configuration to use during authentication. -type SASLType func(c *conn) error +type SASLType func(c *Conn) error // ConnSASLPlain enables SASL PLAIN authentication for the connection. // @@ -32,7 +28,7 @@ type SASLType func(c *conn) error // on TLS/SSL enabled connection. func SASLTypePlain(username, password string) SASLType { // TODO: how widely used is hostname? should it be supported - return func(c *conn) error { + return func(c *Conn) error { // make handlers map if no other mechanism has if c.saslHandlers == nil { c.saslHandlers = make(map[encoding.Symbol]stateFunc) @@ -46,11 +42,12 @@ func SASLTypePlain(username, password string) SASLType { InitialResponse: []byte("\x00" + username + "\x00" + password), Hostname: "", } - log.Debug(1, "TX (ConnSASLPlain): %s", init) - err := c.writeFrame(frames.Frame{ - Type: frameTypeSASL, + fr := frames.Frame{ + Type: frames.TypeSASL, Body: init, - }) + } + debug.Log(1, "TX (ConnSASLPlain): %s", fr) + err := c.writeFrame(fr) if err != nil { return nil, err } @@ -64,7 +61,7 @@ func SASLTypePlain(username, password string) SASLType { // ConnSASLAnonymous enables SASL ANONYMOUS authentication for the connection. func SASLTypeAnonymous() SASLType { - return func(c *conn) error { + return func(c *Conn) error { // make handlers map if no other mechanism has if c.saslHandlers == nil { c.saslHandlers = make(map[encoding.Symbol]stateFunc) @@ -76,11 +73,12 @@ func SASLTypeAnonymous() SASLType { Mechanism: saslMechanismANONYMOUS, InitialResponse: []byte("anonymous"), } - log.Debug(1, "TX (ConnSASLAnonymous): %s", init) - err := c.writeFrame(frames.Frame{ - Type: frameTypeSASL, + fr := frames.Frame{ + Type: frames.TypeSASL, Body: init, - }) + } + debug.Log(1, "TX (ConnSASLAnonymous): %s", fr) + err := c.writeFrame(fr) if err != nil { return nil, err } @@ -96,7 +94,7 @@ func SASLTypeAnonymous() SASLType { // The value for resp is dependent on the type of authentication (empty string is common for TLS). // See https://datatracker.ietf.org/doc/html/rfc4422#appendix-A for additional info. func SASLTypeExternal(resp string) SASLType { - return func(c *conn) error { + return func(c *Conn) error { // make handlers map if no other mechanism has if c.saslHandlers == nil { c.saslHandlers = make(map[encoding.Symbol]stateFunc) @@ -108,11 +106,12 @@ func SASLTypeExternal(resp string) SASLType { Mechanism: saslMechanismEXTERNAL, InitialResponse: []byte(resp), } - log.Debug(1, "TX (ConnSASLExternal): %s", init) - err := c.writeFrame(frames.Frame{ - Type: frameTypeSASL, + fr := frames.Frame{ + Type: frames.TypeSASL, Body: init, - }) + } + debug.Log(1, "TX (ConnSASLExternal): %s", fr) + err := c.writeFrame(fr) if err != nil { return nil, err } @@ -135,7 +134,7 @@ func SASLTypeExternal(resp string) SASLType { // SASL XOAUTH2 transmits the bearer in plain text and should only be used // on TLS/SSL enabled connection. func SASLTypeXOAUTH2(username, bearer string, saslMaxFrameSizeOverride uint32) SASLType { - return func(c *conn) error { + return func(c *Conn) error { // make handlers map if no other mechanism has if c.saslHandlers == nil { c.saslHandlers = make(map[encoding.Symbol]stateFunc) @@ -158,25 +157,25 @@ func SASLTypeXOAUTH2(username, bearer string, saslMaxFrameSizeOverride uint32) S } type saslXOAUTH2Handler struct { - conn *conn + conn *Conn maxFrameSizeOverride uint32 response []byte errorResponse []byte // https://developers.google.com/gmail/imap/xoauth2-protocol#error_response } func (s saslXOAUTH2Handler) init() (stateFunc, error) { - originalPeerMaxFrameSize := s.conn.PeerMaxFrameSize - if s.maxFrameSizeOverride > s.conn.PeerMaxFrameSize { - s.conn.PeerMaxFrameSize = s.maxFrameSizeOverride + originalPeerMaxFrameSize := s.conn.peerMaxFrameSize + if s.maxFrameSizeOverride > s.conn.peerMaxFrameSize { + s.conn.peerMaxFrameSize = s.maxFrameSizeOverride } err := s.conn.writeFrame(frames.Frame{ - Type: frameTypeSASL, + Type: frames.TypeSASL, Body: &frames.SASLInit{ Mechanism: saslMechanismXOAUTH2, InitialResponse: s.response, }, }) - s.conn.PeerMaxFrameSize = originalPeerMaxFrameSize + s.conn.peerMaxFrameSize = originalPeerMaxFrameSize if err != nil { return nil, err } @@ -208,7 +207,7 @@ func (s saslXOAUTH2Handler) step() (stateFunc, error) { // The SASL protocol requires clients to send an empty response to this challenge. err := s.conn.writeFrame(frames.Frame{ - Type: frameTypeSASL, + Type: frames.TypeSASL, Body: &frames.SASLResponse{ Response: []byte{}, }, diff --git a/sdk/messaging/azeventhubs/internal/go-amqp/sender.go b/sdk/messaging/azeventhubs/internal/go-amqp/sender.go index a1d44c4a20aa..aaf7cea09f02 100644 --- a/sdk/messaging/azeventhubs/internal/go-amqp/sender.go +++ b/sdk/messaging/azeventhubs/internal/go-amqp/sender.go @@ -1,22 +1,25 @@ // Copyright (C) 2017 Kale Blankenship // Portions Copyright (c) Microsoft Corporation + package amqp import ( "context" "encoding/binary" + "errors" "fmt" "sync" - "sync/atomic" "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/go-amqp/internal/buffer" + "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/go-amqp/internal/debug" "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/go-amqp/internal/encoding" "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/go-amqp/internal/frames" ) // Sender sends messages on a single AMQP link. type Sender struct { - link *link + l link + transfers chan frames.PerformTransfer // sender uses to send transfer frames mu sync.Mutex // protects buf and nextDeliveryTag buf buffer.Buffer @@ -25,33 +28,49 @@ type Sender struct { // LinkName() is the name of the link used for this Sender. func (s *Sender) LinkName() string { - return s.link.Key.name + return s.l.key.name } // MaxMessageSize is the maximum size of a single message. func (s *Sender) MaxMessageSize() uint64 { - return s.link.MaxMessageSize + return s.l.maxMessageSize +} + +// SendOptions contains any optional values for the Sender.Send method. +type SendOptions struct { + // Indicates the message is to be sent as settled when settlement mode is SenderSettleModeMixed. + // If the settlement mode is SenderSettleModeUnsettled and Settled is true, an error is returned. + Settled bool } // Send sends a Message. // -// Blocks until the message is sent, ctx completes, or an error occurs. +// Blocks until the message is sent or an error occurs. If the peer is +// configured for receiver settlement mode second, the call also blocks +// until the peer confirms message settlement. +// +// - ctx controls waiting for the message to be sent and possibly confirmed +// - msg is the message to send +// - opts contains optional values, pass nil to accept the defaults +// +// If the context's deadline expires or is cancelled before the operation +// completes, the message is in an unknown state of transmission. // // Send is safe for concurrent use. Since only a single message can be // sent on a link at a time, this is most useful when settlement confirmation -// has been requested (receiver settle mode is "Second"). In this case, +// has been requested (receiver settle mode is second). In this case, // additional messages can be sent while the current goroutine is waiting // for the confirmation. -func (s *Sender) Send(ctx context.Context, msg *Message) error { +func (s *Sender) Send(ctx context.Context, msg *Message, opts *SendOptions) error { // check if the link is dead. while it's safe to call s.send // in this case, this will avoid some allocations etc. select { - case <-s.link.Detached: - return s.link.err + case <-s.l.done: + return s.l.doneErr default: // link is still active } - done, err := s.send(ctx, msg) + done, err := s.send(ctx, msg, opts) if err != nil { return err } @@ -60,23 +79,27 @@ func (s *Sender) Send(ctx context.Context, msg *Message) error { select { case state := <-done: if state, ok := state.(*encoding.StateRejected); ok { - if s.link.detachOnRejectDisp() { - return &DetachError{state.Error} + if state.Error != nil { + return state.Error } - return state.Error + return errors.New("the peer rejected the message without specifying an error") } return nil - case <-s.link.Detached: - return s.link.err + case <-s.l.done: + return s.l.doneErr case <-ctx.Done(): + // TODO: if the message is not settled and we never received a disposition, how can we consider the message as sent? return ctx.Err() } } // send is separated from Send so that the mutex unlock can be deferred without // locking the transfer confirmation that happens in Send. -func (s *Sender) send(ctx context.Context, msg *Message) (chan encoding.DeliveryState, error) { - const maxDeliveryTagLength = 32 +func (s *Sender) send(ctx context.Context, msg *Message, opts *SendOptions) (chan encoding.DeliveryState, error) { + const ( + maxDeliveryTagLength = 32 + maxTransferFrameHeader = 66 // determined by calcMaxTransferFrameHeader + ) if len(msg.DeliveryTag) > maxDeliveryTagLength { return nil, fmt.Errorf("delivery tag is over the allowed %v bytes, len: %v", maxDeliveryTagLength, len(msg.DeliveryTag)) } @@ -90,15 +113,21 @@ func (s *Sender) send(ctx context.Context, msg *Message) (chan encoding.Delivery return nil, err } - if s.link.MaxMessageSize != 0 && uint64(s.buf.Len()) > s.link.MaxMessageSize { - return nil, fmt.Errorf("encoded message size exceeds max of %d", s.link.MaxMessageSize) + if s.l.maxMessageSize != 0 && uint64(s.buf.Len()) > s.l.maxMessageSize { + return nil, fmt.Errorf("encoded message size exceeds max of %d", s.l.maxMessageSize) + } + + senderSettled := senderSettleModeValue(s.l.senderSettleMode) == SenderSettleModeSettled + if opts != nil { + if opts.Settled && senderSettleModeValue(s.l.senderSettleMode) == SenderSettleModeUnsettled { + return nil, errors.New("can't send message as settled when sender settlement mode is unsettled") + } else if opts.Settled { + senderSettled = true + } } var ( - maxPayloadSize = int64(s.link.Session.conn.PeerMaxFrameSize) - maxTransferFrameHeader - sndSettleMode = s.link.SenderSettleMode - senderSettled = sndSettleMode != nil && (*sndSettleMode == ModeSettled || (*sndSettleMode == ModeMixed && msg.SendSettled)) - deliveryID = atomic.AddUint32(&s.link.Session.nextDeliveryID, 1) + maxPayloadSize = int64(s.l.session.conn.peerMaxFrameSize) - maxTransferFrameHeader ) deliveryTag := msg.DeliveryTag @@ -110,8 +139,8 @@ func (s *Sender) send(ctx context.Context, msg *Message) (chan encoding.Delivery } fr := frames.PerformTransfer{ - Handle: s.link.Handle, - DeliveryID: &deliveryID, + Handle: s.l.handle, + DeliveryID: &needsDeliveryID, DeliveryTag: deliveryTag, MessageFormat: &msg.Format, More: s.buf.Len() > 0, @@ -135,9 +164,10 @@ func (s *Sender) send(ctx context.Context, msg *Message) (chan encoding.Delivery } select { - case s.link.Transfers <- fr: - case <-s.link.Detached: - return nil, s.link.err + case s.transfers <- fr: + // frame was sent to our mux + case <-s.l.done: + return nil, s.l.doneErr case <-ctx.Done(): return nil, ctx.Err() } @@ -153,13 +183,275 @@ func (s *Sender) send(ctx context.Context, msg *Message) (chan encoding.Delivery // Address returns the link's address. func (s *Sender) Address() string { - if s.link.Target == nil { + if s.l.target == nil { return "" } - return s.link.Target.Address + return s.l.target.Address } // Close closes the Sender and AMQP link. +// - ctx controls waiting for the peer to acknowledge the close +// +// If the context's deadline expires or is cancelled before the operation +// completes, the application can be left in an unknown state, potentially +// resulting in connection errors. func (s *Sender) Close(ctx context.Context) error { - return s.link.Close(ctx) + return s.l.closeLink(ctx) +} + +// newSendingLink creates a new sending link and attaches it to the session +func newSender(target string, session *Session, opts *SenderOptions) (*Sender, error) { + l := newLink(session, encoding.RoleSender) + l.target = &frames.Target{Address: target} + l.source = new(frames.Source) + s := &Sender{ + l: l, + } + + if opts == nil { + return s, nil + } + + for _, v := range opts.Capabilities { + s.l.source.Capabilities = append(s.l.source.Capabilities, encoding.Symbol(v)) + } + if opts.Durability > DurabilityUnsettledState { + return nil, fmt.Errorf("invalid Durability %d", opts.Durability) + } + s.l.source.Durable = opts.Durability + if opts.DynamicAddress { + s.l.target.Address = "" + s.l.dynamicAddr = opts.DynamicAddress + } + if opts.ExpiryPolicy != "" { + if err := encoding.ValidateExpiryPolicy(opts.ExpiryPolicy); err != nil { + return nil, err + } + s.l.source.ExpiryPolicy = opts.ExpiryPolicy + } + s.l.source.Timeout = opts.ExpiryTimeout + if opts.Name != "" { + s.l.key.name = opts.Name + } + if opts.Properties != nil { + s.l.properties = make(map[encoding.Symbol]any) + for k, v := range opts.Properties { + if k == "" { + return nil, errors.New("link property key must not be empty") + } + s.l.properties[encoding.Symbol(k)] = v + } + } + if opts.RequestedReceiverSettleMode != nil { + if rsm := *opts.RequestedReceiverSettleMode; rsm > ReceiverSettleModeSecond { + return nil, fmt.Errorf("invalid RequestedReceiverSettleMode %d", rsm) + } + s.l.receiverSettleMode = opts.RequestedReceiverSettleMode + } + if opts.SettlementMode != nil { + if ssm := *opts.SettlementMode; ssm > SenderSettleModeMixed { + return nil, fmt.Errorf("invalid SettlementMode %d", ssm) + } + s.l.senderSettleMode = opts.SettlementMode + } + s.l.source.Address = opts.SourceAddress + for _, v := range opts.TargetCapabilities { + s.l.target.Capabilities = append(s.l.target.Capabilities, encoding.Symbol(v)) + } + if opts.TargetDurability != DurabilityNone { + s.l.target.Durable = opts.TargetDurability + } + if opts.TargetExpiryPolicy != ExpiryPolicySessionEnd { + s.l.target.ExpiryPolicy = opts.TargetExpiryPolicy + } + if opts.TargetExpiryTimeout != 0 { + s.l.target.Timeout = opts.TargetExpiryTimeout + } + return s, nil +} + +func (s *Sender) attach(ctx context.Context) error { + if err := s.l.attach(ctx, func(pa *frames.PerformAttach) { + pa.Role = encoding.RoleSender + if pa.Target == nil { + pa.Target = new(frames.Target) + } + pa.Target.Dynamic = s.l.dynamicAddr + }, func(pa *frames.PerformAttach) { + if s.l.target == nil { + s.l.target = new(frames.Target) + } + + // if dynamic address requested, copy assigned name to address + if s.l.dynamicAddr && pa.Target != nil { + s.l.target.Address = pa.Target.Address + } + }); err != nil { + return err + } + + s.transfers = make(chan frames.PerformTransfer) + + go s.mux() + + return nil +} + +func (s *Sender) mux() { + defer func() { + s.l.session.deallocateHandle(&s.l) + close(s.l.done) + }() + +Loop: + for { + var outgoingTransfers chan frames.PerformTransfer + if s.l.linkCredit > 0 { + debug.Log(1, "TX (Sender) (enable): target: %q, link credit: %d, deliveryCount: %d", s.l.target.Address, s.l.linkCredit, s.l.deliveryCount) + outgoingTransfers = s.transfers + } else { + debug.Log(1, "TX (Sender) (pause): target: %q, link credit: %d, deliveryCount: %d", s.l.target.Address, s.l.linkCredit, s.l.deliveryCount) + } + + closed := s.l.close + if s.l.closeInProgress { + // swap out channel so it no longer triggers + closed = nil + } + + select { + case <-s.l.forceClose: + // the call to s.Close() timed out waiting for the ack + s.l.doneErr = &LinkError{inner: errors.New("the sender was forcibly closed")} + return + + // received frame + case q := <-s.l.rxQ.Wait(): + // populated queue + fr := *q.Dequeue() + s.l.rxQ.Release(q) + + // if muxHandleFrame returns an error it means the mux must terminate. + // note that in the case of a client-side close due to an error, nil + // is returned in order to keep the mux running to ack the detach frame. + if err := s.muxHandleFrame(fr); err != nil { + s.l.doneErr = err + return + } + + // send data + case tr := <-outgoingTransfers: + select { + case s.l.session.txTransfer <- &tr: + debug.Log(2, "TX (Sender): mux transfer to Session: %d, %s", s.l.session.channel, tr) + // decrement link-credit after entire message transferred + if !tr.More { + s.l.deliveryCount++ + s.l.linkCredit-- + // we are the sender and we keep track of the peer's link credit + debug.Log(3, "TX (Sender): link: %s, link credit: %d", s.l.key.name, s.l.linkCredit) + } + continue Loop + case <-s.l.close: + continue Loop + case <-s.l.session.done: + continue Loop + } + + case <-closed: + if s.l.closeInProgress { + // a client-side close due to protocol error is in progress + continue + } + // sender is being closed by the client + s.l.closeInProgress = true + fr := &frames.PerformDetach{ + Handle: s.l.handle, + Closed: true, + } + _ = s.l.session.txFrame(fr, nil) + + case <-s.l.session.done: + // TODO: per spec, if the session has terminated, we're not allowed to send frames + s.l.doneErr = s.l.session.doneErr + return + } + } +} + +// muxHandleFrame processes fr based on type. +// depending on the peer's RSM, it might return a disposition frame for sending +func (s *Sender) muxHandleFrame(fr frames.FrameBody) error { + debug.Log(2, "RX (Sender): %s", fr) + switch fr := fr.(type) { + // flow control frame + case *frames.PerformFlow: + // the sender's link-credit variable MUST be set according to this formula when flow information is given by the receiver: + // link-credit(snd) := delivery-count(rcv) + link-credit(rcv) - delivery-count(snd) + linkCredit := *fr.LinkCredit - s.l.deliveryCount + if fr.DeliveryCount != nil { + // DeliveryCount can be nil if the receiver hasn't processed + // the attach. That shouldn't be the case here, but it's + // what ActiveMQ does. + linkCredit += *fr.DeliveryCount + } + + s.l.linkCredit = linkCredit + + if !fr.Echo { + return nil + } + + var ( + // copy because sent by pointer below; prevent race + deliveryCount = s.l.deliveryCount + ) + + // send flow + resp := &frames.PerformFlow{ + Handle: &s.l.handle, + DeliveryCount: &deliveryCount, + LinkCredit: &linkCredit, // max number of messages + } + + select { + case s.l.session.tx <- resp: + debug.Log(2, "TX (Sender): %s", resp) + case <-s.l.close: + return nil + case <-s.l.session.done: + return s.l.session.doneErr + } + + case *frames.PerformDisposition: + if fr.Settled { + return nil + } + + // peer is in mode second, so we must send confirmation of disposition. + // NOTE: the ack must be sent through the session so it can close out + // the in-flight disposition. + dr := &frames.PerformDisposition{ + Role: encoding.RoleSender, + First: fr.First, + Last: fr.Last, + Settled: true, + } + + select { + case s.l.session.tx <- dr: + debug.Log(2, "TX (Sender): mux frame to Session: %d, %s", s.l.session.channel, dr) + case <-s.l.close: + return nil + case <-s.l.session.done: + return s.l.session.doneErr + } + + return nil + + default: + return s.l.muxHandleFrame(fr) + } + + return nil } diff --git a/sdk/messaging/azeventhubs/internal/go-amqp/session.go b/sdk/messaging/azeventhubs/internal/go-amqp/session.go index 98d87c34067b..9c806b7a50ea 100644 --- a/sdk/messaging/azeventhubs/internal/go-amqp/session.go +++ b/sdk/messaging/azeventhubs/internal/go-amqp/session.go @@ -1,5 +1,6 @@ // Copyright (C) 2017 Kale Blankenship // Portions Copyright (c) Microsoft Corporation + package amqp import ( @@ -8,12 +9,12 @@ import ( "fmt" "math" "sync" - "time" "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/go-amqp/internal/bitmap" + "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/go-amqp/internal/debug" "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/go-amqp/internal/encoding" "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/go-amqp/internal/frames" - "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/go-amqp/internal/log" + "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/go-amqp/internal/queue" ) // Default session options @@ -21,98 +22,183 @@ const ( defaultWindow = 5000 ) -// Default link options -const ( - defaultLinkCredit = 1 - defaultLinkBatching = false - defaultLinkBatchMaxAge = 5 * time.Second -) +// SessionOptions contains the optional settings for configuring an AMQP session. +type SessionOptions struct { + // MaxLinks sets the maximum number of links (Senders/Receivers) + // allowed on the session. + // + // Minimum: 1. + // Default: 4294967295. + MaxLinks uint32 +} // Session is an AMQP session. // // A session multiplexes Receivers. type Session struct { channel uint16 // session's local channel - remoteChannel uint16 // session's remote channel, owned by conn.mux - conn *conn // underlying conn - rx chan frames.Frame // frames destined for this session are sent on this chan by conn.mux + remoteChannel uint16 // session's remote channel, owned by conn.connReader + conn *Conn // underlying conn tx chan frames.FrameBody // non-transfer frames to be sent; session must track disposition txTransfer chan *frames.PerformTransfer // transfer frames to be sent; session must track disposition + // frames destined for this session are added to this queue by conn.connReader + rxQ *queue.Holder[frames.FrameBody] + // flow control incomingWindow uint32 outgoingWindow uint32 needFlowCount uint32 - handleMax uint32 - allocateHandle chan *link // link handles are allocated by sending a link on this channel, nil is sent on link.rx once allocated - deallocateHandle chan *link // link handles are deallocated by sending a link on this channel + handleMax uint32 + + // link management + linksMu sync.RWMutex // used to synchronize link handle allocation + linksByKey map[linkKey]*link // mapping of name+role link + handles *bitmap.Bitmap // allocated handles - nextDeliveryID uint32 // atomically accessed sequence for deliveryIDs + // used for gracefully closing session + close chan struct{} + forceClose chan struct{} + closeOnce sync.Once - // used for gracefully closing link - close chan struct{} - closeOnce sync.Once - done chan struct{} // part of internal public surface area - err error + // part of internal public surface area + done chan struct{} // closed when the session has terminated (mux exited); DO NOT wait on this from within Session.mux() as it will never trigger! + doneErr error // contains the error state returned from Close(); DO NOT TOUCH outside of session.go until done has been closed! } -func newSession(c *conn, channel uint16) *Session { - return &Session{ - conn: c, - channel: channel, - rx: make(chan frames.Frame), - tx: make(chan frames.FrameBody), - txTransfer: make(chan *frames.PerformTransfer), - incomingWindow: defaultWindow, - outgoingWindow: defaultWindow, - handleMax: math.MaxUint32, - allocateHandle: make(chan *link), - deallocateHandle: make(chan *link), - close: make(chan struct{}), - done: make(chan struct{}), +func newSession(c *Conn, channel uint16, opts *SessionOptions) *Session { + s := &Session{ + conn: c, + channel: channel, + tx: make(chan frames.FrameBody), + txTransfer: make(chan *frames.PerformTransfer), + incomingWindow: defaultWindow, + outgoingWindow: defaultWindow, + handleMax: math.MaxUint32, + linksMu: sync.RWMutex{}, + linksByKey: make(map[linkKey]*link), + close: make(chan struct{}), + forceClose: make(chan struct{}), + done: make(chan struct{}), } -} -func (s *Session) init(opts *SessionOptions) { if opts != nil { - if opts.IncomingWindow != 0 { - s.incomingWindow = opts.IncomingWindow - } if opts.MaxLinks != 0 { // MaxLinks is the number of total links. // handleMax is the max handle ID which starts // at zero. so we decrement by one s.handleMax = opts.MaxLinks - 1 } - if opts.OutgoingWindow != 0 { - s.outgoingWindow = opts.OutgoingWindow + } + + // create handle map after options have been applied + s.handles = bitmap.New(s.handleMax) + + s.rxQ = queue.NewHolder(queue.New[frames.FrameBody](int(s.incomingWindow))) + + return s +} + +// waitForFrame waits for an incoming frame to be queued. +// it returns the next frame from the queue, or an error. +// the error is either from the context or conn.doneErr. +// not meant for consumption outside of session.go. +func (s *Session) waitForFrame(ctx context.Context) (frames.FrameBody, error) { + var q *queue.Queue[frames.FrameBody] + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-s.conn.done: + return nil, s.conn.doneErr + case q = <-s.rxQ.Wait(): + // populated queue + } + + fr := q.Dequeue() + s.rxQ.Release(q) + + return *fr, nil +} + +func (s *Session) begin(ctx context.Context) error { + // send Begin to server + begin := &frames.PerformBegin{ + NextOutgoingID: 0, + IncomingWindow: s.incomingWindow, + OutgoingWindow: s.outgoingWindow, + HandleMax: s.handleMax, + } + + _ = s.txFrame(begin, nil) + + // wait for response + fr, err := s.waitForFrame(ctx) + if err != nil { + // if we exit before receiving the ack, our caller will clean up the channel. + // however, it does mean that the peer will now have assigned an outgoing + // channel ID that's not in use. + return err + } + + begin, ok := fr.(*frames.PerformBegin) + if !ok { + // this codepath is hard to hit (impossible?). if the response isn't a PerformBegin and we've not + // yet seen the remote channel number, the default clause in conn.connReader will protect us from that. + // if we have seen the remote channel number then it's likely the session.mux for that channel will + // either swallow the frame or blow up in some other way, both causing this call to hang. + // deallocate session on error. we can't call + // s.Close() as the session mux hasn't started yet. + debug.Log(1, "RX (Session): unexpected begin response frame %T", fr) + s.conn.deleteSession(s) + if err := s.conn.Close(); err != nil { + return err } + return &ConnError{inner: fmt.Errorf("unexpected begin response: %#v", fr)} } + + // start Session multiplexor + go s.mux(begin) + + return nil } -// Close gracefully closes the session. +// Close closes the session. +// - ctx controls waiting for the peer to acknowledge the session is closed // -// If ctx expires while waiting for servers response, ctx.Err() will be returned. -// The session will continue to wait for the response until the Client is closed. +// If the context's deadline expires or is cancelled before the operation +// completes, the application can be left in an unknown state, potentially +// resulting in connection errors. func (s *Session) Close(ctx context.Context) error { - s.closeOnce.Do(func() { close(s.close) }) - select { - case <-s.done: - case <-ctx.Done(): - return ctx.Err() + var ctxErr error + s.closeOnce.Do(func() { + close(s.close) + select { + case <-s.done: + // mux has exited + case <-ctx.Done(): + close(s.forceClose) + ctxErr = ctx.Err() + } + }) + + if ctxErr != nil { + return ctxErr } - if s.err == ErrSessionClosed { + + var sessionErr *SessionError + if errors.As(s.doneErr, &sessionErr) && sessionErr.RemoteErr == nil && sessionErr.inner == nil { + // an empty SessionError means the session was closed by the caller return nil } - return s.err + return s.doneErr } // txFrame sends a frame to the connWriter. // it returns an error if the connection has been closed. func (s *Session) txFrame(p frames.FrameBody, done chan encoding.DeliveryState) error { - return s.conn.SendFrame(frames.Frame{ - Type: frameTypeAMQP, + return s.conn.sendFrame(frames.Frame{ + Type: frames.TypeAMQP, Channel: s.channel, Body: p, Done: done, @@ -120,161 +206,146 @@ func (s *Session) txFrame(p frames.FrameBody, done chan encoding.DeliveryState) } // NewReceiver opens a new receiver link on the session. -// opts: pass nil to accept the default values. +// - ctx controls waiting for the peer to create a sending terminus +// - source is the name of the peer's sending terminus +// - opts contains optional values, pass nil to accept the defaults +// +// If the context's deadline expires or is cancelled before the operation +// completes, the application can be left in an unknown state, potentially +// resulting in connection errors. func (s *Session) NewReceiver(ctx context.Context, source string, opts *ReceiverOptions) (*Receiver, error) { - r := &Receiver{ - batching: defaultLinkBatching, - batchMaxAge: defaultLinkBatchMaxAge, - maxCredit: defaultLinkCredit, - } - - l, err := newReceivingLink(source, s, r, opts) + r, err := newReceiver(source, s, opts) if err != nil { return nil, err } - if err = l.attach(ctx, s); err != nil { + if err = r.attach(ctx); err != nil { return nil, err } - r.link = l - - // batching is just extra overhead when maxCredits == 1 - if r.maxCredit == 1 { - r.batching = false - } - - // create dispositions channel and start dispositionBatcher if batching enabled - if r.batching { - // buffer dispositions chan to prevent disposition sends from blocking - r.dispositions = make(chan messageDisposition, r.maxCredit) - go r.dispositionBatcher() - } - return r, nil } // NewSender opens a new sender link on the session. -// opts: pass nil to accept the default values. +// - ctx controls waiting for the peer to create a receiver terminus +// - target is the name of the peer's receiver terminus +// - opts contains optional values, pass nil to accept the defaults +// +// If the context's deadline expires or is cancelled before the operation +// completes, the application can be left in an unknown state, potentially +// resulting in connection errors. func (s *Session) NewSender(ctx context.Context, target string, opts *SenderOptions) (*Sender, error) { - l, err := newSendingLink(target, s, opts) + l, err := newSender(target, s, opts) if err != nil { return nil, err } - if err = l.attach(ctx, s); err != nil { + if err = l.attach(ctx); err != nil { return nil, err } - return &Sender{link: l}, nil + return l, nil } func (s *Session) mux(remoteBegin *frames.PerformBegin) { defer func() { - // clean up session record in conn.mux() - select { - case <-s.rx: - // discard any incoming frames to keep conn mux unblocked - case s.conn.DelSession <- s: - // successfully deleted session - case <-s.conn.Done: - s.err = s.conn.Err() - } - if s.err == nil { - s.err = ErrSessionClosed + s.conn.deleteSession(s) + if s.doneErr == nil { + s.doneErr = &SessionError{} + } else if connErr := (&ConnError{}); !errors.As(s.doneErr, &connErr) { + // only wrap non-ConnError error types + var amqpErr *Error + if errors.As(s.doneErr, &amqpErr) { + s.doneErr = &SessionError{RemoteErr: amqpErr} + } else { + s.doneErr = &SessionError{inner: s.doneErr} + } } // Signal goroutines waiting on the session. close(s.done) }() var ( - links = make(map[uint32]*link) // mapping of remote handles to links - linksByKey = make(map[linkKey]*link) // mapping of name+role link - handles = bitmap.New(s.handleMax) // allocated handles - + links = make(map[uint32]*link) // mapping of remote handles to links handlesByDeliveryID = make(map[uint32]uint32) // mapping of deliveryIDs to handles deliveryIDByHandle = make(map[uint32]uint32) // mapping of handles to latest deliveryID handlesByRemoteDeliveryID = make(map[uint32]uint32) // mapping of remote deliveryID to handles settlementByDeliveryID = make(map[uint32]chan encoding.DeliveryState) + nextDeliveryID uint32 // tracks the next delivery ID for outgoing transfers + // flow control values nextOutgoingID uint32 nextIncomingID = remoteBegin.NextOutgoingID remoteIncomingWindow = remoteBegin.IncomingWindow remoteOutgoingWindow = remoteBegin.OutgoingWindow + + closeInProgress bool // indicates the end performative has been sent ) + closeWithError := func(e1 *Error, e2 error) { + if closeInProgress { + debug.Log(3, "TX (Session): close already pending, discarding %v", e1) + return + } + + closeInProgress = true + s.doneErr = e2 + _ = s.txFrame(&frames.PerformEnd{Error: e1}, nil) + } + for { txTransfer := s.txTransfer // disable txTransfer if flow control windows have been exceeded if remoteIncomingWindow == 0 || s.outgoingWindow == 0 { - log.Debug(1, "TX(Session): Disabling txTransfer - window exceeded. remoteIncomingWindow: %d outgoingWindow:%d", + debug.Log(1, "TX (Session): disabling txTransfer - window exceeded. remoteIncomingWindow: %d outgoingWindow: %d", remoteIncomingWindow, s.outgoingWindow) txTransfer = nil } + closed := s.close + if closeInProgress { + // swap out channel so it no longer triggers + closed = nil + } + + // notes on client-side closing session + // when session is closed, we must keep the mux running until the ack'ing end performative + // has been received. during this window, the session is allowed to receive frames but cannot + // send them. + // client-side close happens either by user calling Session.Close() or due to mux initiated + // close due to a violation of some invariant (see sending &Error{} to s.close). in the case + // that both code paths have been triggered, we must be careful to preserve the error that + // triggered the mux initiated close so it can be surfaced to the caller. + select { // conn has completed, exit - case <-s.conn.Done: - s.err = s.conn.Err() + case <-s.conn.done: + s.doneErr = s.conn.doneErr return - // session is being closed by user - case <-s.close: - _ = s.txFrame(&frames.PerformEnd{}, nil) - - // wait for the ack that the session is closed. - // we can't exit the mux, which deletes the session, - // until we receive it. - EndLoop: - for { - select { - case fr := <-s.rx: - _, ok := fr.Body.(*frames.PerformEnd) - if ok { - break EndLoop - } - case <-s.conn.Done: - s.err = s.conn.Err() - return - } - } + case <-s.forceClose: + // the call to s.Close() timed out waiting for the ack + s.doneErr = errors.New("the session was forcibly closed") return - // handle allocation request - case l := <-s.allocateHandle: - // Check if link name already exists, if so then an error should be returned - if linksByKey[l.Key] != nil { - l.err = fmt.Errorf("link with name '%v' already exists", l.Key.name) - l.RX <- nil + case <-closed: + if closeInProgress { + // a client-side close due to protocol error is in progress continue } - - next, ok := handles.Next() - if !ok { - // handle numbers are zero-based, report the actual count - l.err = fmt.Errorf("reached session handle max (%d)", s.handleMax+1) - l.RX <- nil - continue - } - - l.Handle = next // allocate handle to the link - linksByKey[l.Key] = l // add to mapping - l.RX <- nil // send nil on channel to indicate allocation complete - - // handle deallocation request - case l := <-s.deallocateHandle: - delete(links, l.RemoteHandle) - delete(deliveryIDByHandle, l.Handle) - delete(linksByKey, l.Key) - handles.Remove(l.Handle) - close(l.RX) // close channel to indicate deallocation - - // incoming frame for link - case fr := <-s.rx: - log.Debug(1, "RX(Session): %s", fr.Body) - - switch body := fr.Body.(type) { + // session is being closed by the client + closeInProgress = true + fr := frames.PerformEnd{} + _ = s.txFrame(&fr, nil) + + // incoming frame + case q := <-s.rxQ.Wait(): + fr := *q.Dequeue() + s.rxQ.Release(q) + debug.Log(2, "RX (Session): %s", fr) + + switch body := fr.(type) { // Disposition frames can reference transfers from more than one // link. Send this frame to all of them. case *frames.PerformDisposition: @@ -291,7 +362,7 @@ func (s *Session) mux(remoteBegin *frames.PerformBegin) { handle, ok := handles[deliveryID] if !ok { - log.Debug(2, "role %s: didn't find deliveryID %d in handles map", body.Role, deliveryID) + debug.Log(2, "RX (Session): role %s: didn't find deliveryID %d in handles map", body.Role, deliveryID) continue } delete(handles, deliveryID) @@ -311,10 +382,14 @@ func (s *Session) mux(remoteBegin *frames.PerformBegin) { link, ok := links[handle] if !ok { + closeWithError(&Error{ + Condition: ErrCondUnattachedHandle, + Description: "received disposition frame referencing a handle that's not in use", + }, fmt.Errorf("received disposition frame with unknown link handle %d", handle)) continue } - s.muxFrameToLink(link, fr.Body) + s.muxFrameToLink(link, fr) } continue case *frames.PerformFlow: @@ -322,14 +397,11 @@ func (s *Session) mux(remoteBegin *frames.PerformBegin) { // This is a protocol error: // "[...] MUST be set if the peer has received // the begin frame for the session" - _ = s.txFrame(&frames.PerformEnd{ - Error: &Error{ - Condition: ErrorNotAllowed, - Description: "next-incoming-id not set after session established", - }, - }, nil) - s.err = errors.New("protocol error: received flow without next-incoming-id after session established") - return + closeWithError(&Error{ + Condition: ErrCondNotAllowed, + Description: "next-incoming-id not set after session established", + }, errors.New("protocol error: received flow without next-incoming-id after session established")) + continue } // "When the endpoint receives a flow frame from its peer, @@ -349,16 +421,20 @@ func (s *Session) mux(remoteBegin *frames.PerformBegin) { // initial-outgoing-id(endpoint) + incoming-window(flow) - next-outgoing-id(endpoint)" remoteIncomingWindow = body.IncomingWindow - nextOutgoingID remoteIncomingWindow += *body.NextIncomingID - log.Debug(3, "RX(Session) Flow - remoteOutgoingWindow: %d remoteIncomingWindow: %d nextOutgoingID: %d", remoteOutgoingWindow, remoteIncomingWindow, nextOutgoingID) + debug.Log(3, "RX (Session): flow - remoteOutgoingWindow: %d remoteIncomingWindow: %d nextOutgoingID: %d", remoteOutgoingWindow, remoteIncomingWindow, nextOutgoingID) // Send to link if handle is set if body.Handle != nil { link, ok := links[*body.Handle] if !ok { + closeWithError(&Error{ + Condition: ErrCondUnattachedHandle, + Description: "received flow frame referencing a handle that's not in use", + }, fmt.Errorf("received flow frame with unknown link handle %d", body.Handle)) continue } - s.muxFrameToLink(link, fr.Body) + s.muxFrameToLink(link, fr) continue } @@ -370,7 +446,6 @@ func (s *Session) mux(remoteBegin *frames.PerformBegin) { NextOutgoingID: nextOutgoingID, OutgoingWindow: s.outgoingWindow, } - log.Debug(1, "TX (session.mux): %s", resp) _ = s.txFrame(resp, nil) } @@ -380,16 +455,21 @@ func (s *Session) mux(remoteBegin *frames.PerformBegin) { // attach frame. // // Note body.Role is the remote peer's role, we reverse for the local key. - link, linkOk := linksByKey[linkKey{name: body.Name, role: !body.Role}] + s.linksMu.RLock() + link, linkOk := s.linksByKey[linkKey{name: body.Name, role: !body.Role}] + s.linksMu.RUnlock() if !linkOk { - s.err = fmt.Errorf("protocol error: received mismatched attach frame %+v", body) - return + closeWithError(&Error{ + Condition: ErrCondNotAllowed, + Description: "received mismatched attach frame", + }, fmt.Errorf("protocol error: received mismatched attach frame %+v", body)) + continue } - link.RemoteHandle = body.Handle - links[link.RemoteHandle] = link + link.remoteHandle = body.Handle + links[link.remoteHandle] = link - s.muxFrameToLink(link, fr.Body) + s.muxFrameToLink(link, fr) case *frames.PerformTransfer: s.needFlowCount++ @@ -405,23 +485,25 @@ func (s *Session) mux(remoteBegin *frames.PerformBegin) { } link, ok := links[body.Handle] if !ok { + closeWithError(&Error{ + Condition: ErrCondUnattachedHandle, + Description: "received transfer frame referencing a handle that's not in use", + }, fmt.Errorf("received transfer frame with unknown link handle %d", body.Handle)) continue } - select { - case <-s.conn.Done: - case link.RX <- fr.Body: - } + s.muxFrameToLink(link, fr) + debug.Log(2, "RX (Session): mux transfer to link: %s", fr) // if this message is received unsettled and link rcv-settle-mode == second, add to handlesByRemoteDeliveryID - if !body.Settled && body.DeliveryID != nil && link.ReceiverSettleMode != nil && *link.ReceiverSettleMode == ModeSecond { - log.Debug(1, "TX(Session): adding handle to handlesByRemoteDeliveryID. delivery ID: %d", *body.DeliveryID) + if !body.Settled && body.DeliveryID != nil && link.receiverSettleMode != nil && *link.receiverSettleMode == ReceiverSettleModeSecond { + debug.Log(1, "RX (Session): adding handle to handlesByRemoteDeliveryID. delivery ID: %d", *body.DeliveryID) handlesByRemoteDeliveryID[*body.DeliveryID] = body.Handle } // Update peer's outgoing window if half has been consumed. if s.needFlowCount >= s.incomingWindow/2 { - log.Debug(3, "TX(Session %d) Flow s.needFlowCount(%d) >= s.incomingWindow(%d)/2\n", s.channel, s.needFlowCount, s.incomingWindow) + debug.Log(3, "RX (Session): channel %d: flow - s.needFlowCount(%d) >= s.incomingWindow(%d)/2\n", s.channel, s.needFlowCount, s.incomingWindow) s.needFlowCount = 0 nID := nextIncomingID flow := &frames.PerformFlow{ @@ -430,33 +512,71 @@ func (s *Session) mux(remoteBegin *frames.PerformBegin) { NextOutgoingID: nextOutgoingID, OutgoingWindow: s.outgoingWindow, } - log.Debug(1, "TX(Session): %s", flow) _ = s.txFrame(flow, nil) } case *frames.PerformDetach: link, ok := links[body.Handle] if !ok { + closeWithError(&Error{ + Condition: ErrCondUnattachedHandle, + Description: "received detach frame referencing a handle that's not in use", + }, fmt.Errorf("received detach frame with unknown link handle %d", body.Handle)) continue } - s.muxFrameToLink(link, fr.Body) + s.muxFrameToLink(link, fr) + + // we received a detach frame and sent it to the link. + // this was either the response to a client-side initiated + // detach or our peer detached us. either way, now that + // the link has processed the frame it's detached so we + // are safe to clean up its state. + delete(links, link.remoteHandle) + delete(deliveryIDByHandle, link.handle) case *frames.PerformEnd: - _ = s.txFrame(&frames.PerformEnd{}, nil) - s.err = fmt.Errorf("session ended by server: %s", body.Error) + // there are two possibilities: + // - this is the ack to a client-side Close() + // - the peer is ending the session so we must ack + + if closeInProgress { + return + } + + // peer detached us with an error, save it and send the ack + if body.Error != nil { + s.doneErr = body.Error + } + + fr := frames.PerformEnd{} + _ = s.txFrame(&fr, nil) + + // per spec, when end is received, we're no longer allowed to receive frames return default: - // TODO: evaluate - log.Debug(1, "session mux: unexpected frame: %s\n", body) + debug.Log(1, "RX (Session): unexpected frame: %s\n", body) + closeWithError(&Error{ + Condition: ErrCondInternalError, + Description: "session received unexpected frame", + }, fmt.Errorf("internal error: unexpected frame %T", body)) } case fr := <-txTransfer: + if closeInProgress { + // now that the end performative has been sent we're + // not allowed to send any more frames. + debug.Log(1, "TX (Session): discarding transfer: %s\n", fr) + continue + } + debug.Log(2, "TX (Session): %d, %s", s.channel, fr) // record current delivery ID var deliveryID uint32 - if fr.DeliveryID != nil { - deliveryID = *fr.DeliveryID + if fr.DeliveryID == &needsDeliveryID { + deliveryID = nextDeliveryID + fr.DeliveryID = &deliveryID + nextDeliveryID++ deliveryIDByHandle[fr.Handle] = deliveryID // add to handleByDeliveryID if not sender-settled @@ -481,7 +601,6 @@ func (s *Session) mux(remoteBegin *frames.PerformBegin) { fr.Done = nil } - log.Debug(2, "TX(Session) - txtransfer: %s", fr) _ = s.txFrame(fr, fr.Done) // "Upon sending a transfer, the sending endpoint will increment @@ -494,33 +613,92 @@ func (s *Session) mux(remoteBegin *frames.PerformBegin) { } case fr := <-s.tx: + if closeInProgress { + // now that the end performative has been sent we're + // not allowed to send any more frames. + debug.Log(1, "TX (Session): discarding frame: %s\n", fr) + continue + } + + debug.Log(2, "TX (Session): %d, %s", s.channel, fr) switch fr := fr.(type) { + case *frames.PerformDisposition: + if fr.Settled && fr.Role == encoding.RoleSender { + // sender with a peer that's in mode second; sending confirmation of disposition. + // disposition frames can reference a range of delivery IDs, although it's highly + // likely in this case there will only be one. + start := fr.First + end := start + if fr.Last != nil { + end = *fr.Last + } + for deliveryID := start; deliveryID <= end; deliveryID++ { + // send delivery state to the channel and close it to signal + // that the delivery has completed. + if done, ok := settlementByDeliveryID[deliveryID]; ok { + delete(settlementByDeliveryID, deliveryID) + select { + case done <- fr.State: + default: + } + close(done) + } + } + } + _ = s.txFrame(fr, nil) case *frames.PerformFlow: niID := nextIncomingID fr.NextIncomingID = &niID fr.IncomingWindow = s.incomingWindow fr.NextOutgoingID = nextOutgoingID fr.OutgoingWindow = s.outgoingWindow - log.Debug(1, "TX(Session) - tx: %s", fr) _ = s.txFrame(fr, nil) case *frames.PerformTransfer: panic("transfer frames must use txTransfer") default: - log.Debug(1, "TX(Session) - default: %s", fr) _ = s.txFrame(fr, nil) } } } } -func (s *Session) muxFrameToLink(l *link, fr frames.FrameBody) { - select { - case l.RX <- fr: - // frame successfully sent to link - case <-l.Detached: - // link is closed - // this should be impossible to hit as the link has been removed from the session once Detached is closed - case <-s.conn.Done: - // conn is closed +func (s *Session) allocateHandle(l *link) error { + s.linksMu.Lock() + defer s.linksMu.Unlock() + + // Check if link name already exists, if so then an error should be returned + existing := s.linksByKey[l.key] + if existing != nil { + return fmt.Errorf("link with name '%v' already exists", l.key.name) } + + next, ok := s.handles.Next() + if !ok { + // handle numbers are zero-based, report the actual count + return fmt.Errorf("reached session handle max (%d)", s.handleMax+1) + } + + l.handle = next // allocate handle to the link + s.linksByKey[l.key] = l // add to mapping + + return nil } + +func (s *Session) deallocateHandle(l *link) { + s.linksMu.Lock() + defer s.linksMu.Unlock() + + delete(s.linksByKey, l.key) + s.handles.Remove(l.handle) +} + +func (s *Session) muxFrameToLink(l *link, fr frames.FrameBody) { + q := l.rxQ.Acquire() + q.Enqueue(fr) + l.rxQ.Release(q) + debug.Log(2, "RX (Session): mux frame to link: %s, %s", l.key.name, fr) +} + +// the address of this var is a sentinel value indicating +// that a transfer frame is in need of a delivery ID +var needsDeliveryID uint32 diff --git a/sdk/messaging/azeventhubs/internal/links_test.go b/sdk/messaging/azeventhubs/internal/links_test.go index 5170eca157b8..6b0fa8b812b2 100644 --- a/sdk/messaging/azeventhubs/internal/links_test.go +++ b/sdk/messaging/azeventhubs/internal/links_test.go @@ -35,9 +35,8 @@ func TestLinksCBSLinkStillOpen(t *testing.T) { newLinkFn := func(ctx context.Context, session amqpwrap.AMQPSession, entityPath string) (AMQPSenderCloser, error) { return session.NewSender(ctx, entityPath, &amqp.SenderOptions{ - SettlementMode: to.Ptr(amqp.ModeMixed), - RequestedReceiverSettleMode: to.Ptr(amqp.ModeFirst), - IgnoreDispositionErrors: true, + SettlementMode: to.Ptr(amqp.SenderSettleModeMixed), + RequestedReceiverSettleMode: to.Ptr(amqp.ReceiverSettleModeFirst), }) } diff --git a/sdk/messaging/azeventhubs/internal/links_unit_test.go b/sdk/messaging/azeventhubs/internal/links_unit_test.go index cce5e7b5ea14..5e489df83c33 100644 --- a/sdk/messaging/azeventhubs/internal/links_unit_test.go +++ b/sdk/messaging/azeventhubs/internal/links_unit_test.go @@ -53,13 +53,13 @@ func TestLinks_LinkStale(t *testing.T) { // we'll recover first, but our lwid (after this recovery) is stale since // the link cache will be updated after this is done. - err = links.RecoverIfNeeded(context.Background(), "0", staleLWID, &amqp.DetachError{}) + err = links.RecoverIfNeeded(context.Background(), "0", staleLWID, &amqp.LinkError{}) require.NoError(t, err) require.Nil(t, links.links["0"], "closed link is removed from the cache") require.Equal(t, 1, receivers[0].CloseCalled, "original receiver is closed, and replaced") // trying to recover again is a no-op (if nothing is in the cache) - err = links.RecoverIfNeeded(context.Background(), "0", staleLWID, &amqp.DetachError{}) + err = links.RecoverIfNeeded(context.Background(), "0", staleLWID, &amqp.LinkError{}) require.NoError(t, err) require.Nil(t, links.links["0"], "closed link is removed from the cache") require.Equal(t, 1, receivers[0].CloseCalled, "original receiver is closed, and replaced") @@ -73,7 +73,7 @@ func TestLinks_LinkStale(t *testing.T) { require.NotNil(t, newLWID) require.Equal(t, (*links.links["0"].Link).LinkName(), newLWID.Link.LinkName(), "cache contains the newly created link for partition 0") - err = links.RecoverIfNeeded(context.Background(), "0", staleLWID, &amqp.DetachError{}) + err = links.RecoverIfNeeded(context.Background(), "0", staleLWID, &amqp.LinkError{}) require.NoError(t, err) require.Equal(t, 0, receivers[0].CloseCalled, "receiver is NOT closed - we didn't need to replace it since the lwid with the error was stale") } @@ -100,7 +100,7 @@ func TestLinks_LinkRecoveryOnly(t *testing.T) { require.NotNil(t, lwid) require.NotNil(t, links.links["0"], "cache contains the newly created link for partition 0") - err = links.RecoverIfNeeded(context.Background(), "0", lwid, &amqp.DetachError{}) + err = links.RecoverIfNeeded(context.Background(), "0", lwid, &amqp.LinkError{}) require.NoError(t, err) require.Nil(t, links.links["0"], "cache will no longer a link for partition 0") @@ -151,12 +151,12 @@ func TestLinks_ConnectionRecovery(t *testing.T) { // if the connection has closed in response to an error then it'll propagate it's error to // the children, including receivers. Which means closing the receiver here will _also_ return // a connection error. - receiver.EXPECT().Close(mock.NotCancelledAndHasTimeout).Return(&amqp.ConnectionError{}) + receiver.EXPECT().Close(mock.NotCancelledAndHasTimeout).Return(&amqp.ConnError{}) ns.EXPECT().Recover(mock.NotCancelledAndHasTimeout, gomock.Any()).Return(nil) // initiate a connection level recovery - err = links.RecoverIfNeeded(context.Background(), "0", lwid, &amqp.ConnectionError{}) + err = links.RecoverIfNeeded(context.Background(), "0", lwid, &amqp.ConnError{}) require.NoError(t, err) // we still cleanup what we can (including cancelling our background negotiate claim loop) @@ -200,7 +200,7 @@ func TestLinks_closeWithTimeout(t *testing.T) { // purposefully recover with what should be a link level recovery. However, the Close() failing // means we end up "upgrading" to a connection reset instead. - err = links.RecoverIfNeeded(context.Background(), "0", lwid, &amqp.DetachError{}) + err = links.RecoverIfNeeded(context.Background(), "0", lwid, &amqp.LinkError{}) require.ErrorIs(t, err, errConnResetNeeded) // the error that comes back when the link times out being closed can only @@ -243,7 +243,7 @@ func TestLinks_linkRecoveryOnly(t *testing.T) { lwid, err := links.GetLink(context.Background(), "0") require.NoError(t, err) - err = links.RecoverIfNeeded(context.Background(), "0", lwid, &amqp.DetachError{}) + err = links.RecoverIfNeeded(context.Background(), "0", lwid, &amqp.LinkError{}) require.NoError(t, err) // we still cleanup what we can (including cancelling our background negotiate claim loop) diff --git a/sdk/messaging/azeventhubs/internal/mock/mock_amqp.go b/sdk/messaging/azeventhubs/internal/mock/mock_amqp.go index 78a87cee5959..4b67e349957a 100644 --- a/sdk/messaging/azeventhubs/internal/mock/mock_amqp.go +++ b/sdk/messaging/azeventhubs/internal/mock/mock_amqp.go @@ -139,18 +139,18 @@ func (mr *MockAMQPReceiverMockRecorder) Prefetched() *gomock.Call { } // Receive mocks base method. -func (m *MockAMQPReceiver) Receive(ctx context.Context) (*amqp.Message, error) { +func (m *MockAMQPReceiver) Receive(ctx context.Context, o *amqp.ReceiveOptions) (*amqp.Message, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Receive", ctx) + ret := m.ctrl.Call(m, "Receive", ctx, o) ret0, _ := ret[0].(*amqp.Message) ret1, _ := ret[1].(error) return ret0, ret1 } // Receive indicates an expected call of Receive. -func (mr *MockAMQPReceiverMockRecorder) Receive(ctx interface{}) *gomock.Call { +func (mr *MockAMQPReceiverMockRecorder) Receive(ctx, o interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Receive", reflect.TypeOf((*MockAMQPReceiver)(nil).Receive), ctx) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Receive", reflect.TypeOf((*MockAMQPReceiver)(nil).Receive), ctx, o) } // RejectMessage mocks base method. @@ -317,18 +317,18 @@ func (mr *MockAMQPReceiverCloserMockRecorder) Prefetched() *gomock.Call { } // Receive mocks base method. -func (m *MockAMQPReceiverCloser) Receive(ctx context.Context) (*amqp.Message, error) { +func (m *MockAMQPReceiverCloser) Receive(ctx context.Context, o *amqp.ReceiveOptions) (*amqp.Message, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Receive", ctx) + ret := m.ctrl.Call(m, "Receive", ctx, o) ret0, _ := ret[0].(*amqp.Message) ret1, _ := ret[1].(error) return ret0, ret1 } // Receive indicates an expected call of Receive. -func (mr *MockAMQPReceiverCloserMockRecorder) Receive(ctx interface{}) *gomock.Call { +func (mr *MockAMQPReceiverCloserMockRecorder) Receive(ctx, o interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Receive", reflect.TypeOf((*MockAMQPReceiverCloser)(nil).Receive), ctx) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Receive", reflect.TypeOf((*MockAMQPReceiverCloser)(nil).Receive), ctx, o) } // RejectMessage mocks base method. @@ -411,17 +411,17 @@ func (mr *MockAMQPSenderMockRecorder) MaxMessageSize() *gomock.Call { } // Send mocks base method. -func (m *MockAMQPSender) Send(ctx context.Context, msg *amqp.Message) error { +func (m *MockAMQPSender) Send(ctx context.Context, msg *amqp.Message, o *amqp.SendOptions) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Send", ctx, msg) + ret := m.ctrl.Call(m, "Send", ctx, msg, o) ret0, _ := ret[0].(error) return ret0 } // Send indicates an expected call of Send. -func (mr *MockAMQPSenderMockRecorder) Send(ctx, msg interface{}) *gomock.Call { +func (mr *MockAMQPSenderMockRecorder) Send(ctx, msg, o interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Send", reflect.TypeOf((*MockAMQPSender)(nil).Send), ctx, msg) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Send", reflect.TypeOf((*MockAMQPSender)(nil).Send), ctx, msg, o) } // MockAMQPSenderCloser is a mock of AMQPSenderCloser interface. @@ -490,17 +490,17 @@ func (mr *MockAMQPSenderCloserMockRecorder) MaxMessageSize() *gomock.Call { } // Send mocks base method. -func (m *MockAMQPSenderCloser) Send(ctx context.Context, msg *amqp.Message) error { +func (m *MockAMQPSenderCloser) Send(ctx context.Context, msg *amqp.Message, o *amqp.SendOptions) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Send", ctx, msg) + ret := m.ctrl.Call(m, "Send", ctx, msg, o) ret0, _ := ret[0].(error) return ret0 } // Send indicates an expected call of Send. -func (mr *MockAMQPSenderCloserMockRecorder) Send(ctx, msg interface{}) *gomock.Call { +func (mr *MockAMQPSenderCloserMockRecorder) Send(ctx, msg, o interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Send", reflect.TypeOf((*MockAMQPSenderCloser)(nil).Send), ctx, msg) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Send", reflect.TypeOf((*MockAMQPSenderCloser)(nil).Send), ctx, msg, o) } // MockAMQPSession is a mock of AMQPSession interface. diff --git a/sdk/messaging/azeventhubs/internal/mock/mock_helpers.go b/sdk/messaging/azeventhubs/internal/mock/mock_helpers.go index 0bb98b910b23..32577acb9db5 100644 --- a/sdk/messaging/azeventhubs/internal/mock/mock_helpers.go +++ b/sdk/messaging/azeventhubs/internal/mock/mock_helpers.go @@ -19,7 +19,7 @@ func SetupRPC(sender *MockAMQPSenderCloser, receiver *MockAMQPReceiverCloser, ex ch := make(chan *amqp.Message, 1000) for i := 0; i < expectedCount; i++ { - sender.EXPECT().Send(gomock.Any(), gomock.Any()).Do(func(ctx context.Context, msg *amqp.Message) error { + sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Nil()).Do(func(ctx context.Context, msg *amqp.Message, o *amqp.SendOptions) error { ch <- msg return nil }) @@ -27,7 +27,7 @@ func SetupRPC(sender *MockAMQPSenderCloser, receiver *MockAMQPReceiverCloser, ex // RPC loops forever. We get one extra Receive() call here (the one that waits on the ctx.Done()) for i := 0; i < expectedCount+1; i++ { - receiver.EXPECT().Receive(gomock.Any()).DoAndReturn(func(ctx context.Context) (*amqp.Message, error) { + receiver.EXPECT().Receive(gomock.Any(), gomock.Nil()).DoAndReturn(func(ctx context.Context, o *amqp.ReceiveOptions) (*amqp.Message, error) { select { case <-ctx.Done(): return nil, ctx.Err() diff --git a/sdk/messaging/azeventhubs/internal/namespace.go b/sdk/messaging/azeventhubs/internal/namespace.go index 08c3b9bbb1c9..06e8a2e6bf42 100644 --- a/sdk/messaging/azeventhubs/internal/namespace.go +++ b/sdk/messaging/azeventhubs/internal/namespace.go @@ -180,11 +180,11 @@ func (ns *Namespace) newClientImpl(ctx context.Context) (amqpwrap.AMQPClient, er } connOptions.HostName = ns.FQDN - client, err := amqp.New(nConn, &connOptions) + client, err := amqp.NewConn(ctx, nConn, &connOptions) return &amqpwrap.AMQPClientWrapper{Inner: client}, err } - client, err := amqp.Dial(ns.getAMQPHostURI(), &connOptions) + client, err := amqp.Dial(ctx, ns.getAMQPHostURI(), &connOptions) return &amqpwrap.AMQPClientWrapper{Inner: client}, err } diff --git a/sdk/messaging/azeventhubs/internal/namespace_test.go b/sdk/messaging/azeventhubs/internal/namespace_test.go index 0f76710faeb8..c9a1910be27a 100644 --- a/sdk/messaging/azeventhubs/internal/namespace_test.go +++ b/sdk/messaging/azeventhubs/internal/namespace_test.go @@ -120,7 +120,7 @@ func TestNamespaceNegotiateClaimRenewal(t *testing.T) { nextRefreshDurationChecks := 0 ns.newClientFn = func(ctx context.Context) (amqpwrap.AMQPClient, error) { - return &amqpwrap.AMQPClientWrapper{Inner: &amqp.Client{}}, nil + return &amqpwrap.AMQPClientWrapper{Inner: &amqp.Conn{}}, nil } cancel, _, err := ns.startNegotiateClaimRenewer( @@ -188,7 +188,7 @@ func TestNamespaceNegotiateClaimNonRenewableToken(t *testing.T) { } ns.newClientFn = func(ctx context.Context) (amqpwrap.AMQPClient, error) { - return &amqpwrap.AMQPClientWrapper{Inner: &amqp.Client{}}, nil + return &amqpwrap.AMQPClientWrapper{Inner: &amqp.Conn{}}, nil } // since the token is non-renewable we will just do the single cbsNegotiateClaim call and never renew. @@ -255,7 +255,7 @@ func TestNamespaceNegotiateClaimFatalErrors(t *testing.T) { defer endCapture() ns.newClientFn = func(ctx context.Context) (amqpwrap.AMQPClient, error) { - return &amqpwrap.AMQPClientWrapper{Inner: &amqp.Client{}}, nil + return &amqpwrap.AMQPClientWrapper{Inner: &amqp.Conn{}}, nil } _, done, err := ns.startNegotiateClaimRenewer( diff --git a/sdk/messaging/azeventhubs/internal/rpc.go b/sdk/messaging/azeventhubs/internal/rpc.go index 9edb28a9405c..7aed68dc54c3 100644 --- a/sdk/messaging/azeventhubs/internal/rpc.go +++ b/sdk/messaging/azeventhubs/internal/rpc.go @@ -136,9 +136,9 @@ func NewRPCLink(ctx context.Context, args RPCLinkArgs) (*rpcLink, error) { const name = "com.microsoft:session-filter" const code = uint64(0x00000137000000C) if link.sessionID == nil { - receiverOpts.Filters = append(receiverOpts.Filters, amqp.LinkFilterSource(name, code, nil)) + receiverOpts.Filters = append(receiverOpts.Filters, amqp.NewLinkFilter(name, code, nil)) } else { - receiverOpts.Filters = append(receiverOpts.Filters, amqp.LinkFilterSource(name, code, link.sessionID)) + receiverOpts.Filters = append(receiverOpts.Filters, amqp.NewLinkFilter(name, code, link.sessionID)) } } @@ -165,7 +165,7 @@ func (l *rpcLink) startResponseRouter() { defer azlog.Writef(l.logEvent, responseRouterShutdownMessage) for { - res, err := l.receiver.Receive(l.rpcLinkCtx) + res, err := l.receiver.Receive(l.rpcLinkCtx, nil) if err != nil { // if the link or connection has a malfunction that would require it to restart then @@ -243,7 +243,7 @@ func (l *rpcLink) RPC(ctx context.Context, msg *amqp.Message) (*amqpwrap.RPCResp return nil, l.broadcastErr } - err = l.sender.Send(ctx, msg) + err = l.sender.Send(ctx, msg, nil) if err != nil { l.deleteChannelFromMap(messageID) diff --git a/sdk/messaging/azeventhubs/internal/utils/retrier_test.go b/sdk/messaging/azeventhubs/internal/utils/retrier_test.go index 240c4a9e0cfb..acae2d432949 100644 --- a/sdk/messaging/azeventhubs/internal/utils/retrier_test.go +++ b/sdk/messaging/azeventhubs/internal/utils/retrier_test.go @@ -383,7 +383,7 @@ func TestRetryLogging(t *testing.T) { azlog.Writef("TestFunc", "Attempt %d, resetting", args.I) args.ResetAttempts() reset = true - return &amqp.DetachError{} + return &amqp.LinkError{} } if reset { @@ -393,8 +393,8 @@ func TestRetryLogging(t *testing.T) { return errors.New("custom fatal error") }, func(err error) bool { - var de amqp.DetachError - return errors.Is(err, &de) + var de *amqp.LinkError + return !errors.As(err, &de) }) require.Nil(t, err) @@ -402,7 +402,7 @@ func TestRetryLogging(t *testing.T) { "[TestFunc ] Attempt 0, within test func", "[TestFunc ] Attempt 0, resetting", "[testLogEvent] (test_operation) Resetting retry attempts", - "[testLogEvent] (test_operation) Retry attempt -1 returned retryable error: link detached, reason: *Error(nil)", + "[testLogEvent] (test_operation) Retry attempt -1 returned retryable error: amqp: link closed", "[TestFunc ] Attempt 0, within test func", "[TestFunc ] Attempt 0, return nil", }, normalizeRetryLogLines(logs)) diff --git a/sdk/messaging/azeventhubs/partition_client.go b/sdk/messaging/azeventhubs/partition_client.go index bff81b6ad29b..8c5ab53fe8a3 100644 --- a/sdk/messaging/azeventhubs/partition_client.go +++ b/sdk/messaging/azeventhubs/partition_client.go @@ -20,7 +20,7 @@ import ( // DefaultConsumerGroup is the name of the default consumer group in the Event Hubs service. const DefaultConsumerGroup = "$Default" -const defaultPrefetchSize = uint32(300) +const defaultPrefetchSize = int32(300) // defaultLinkRxBuffer is the maximum number of transfer frames we can handle // on the Receiver. This matches the current default window size that go-amqp @@ -130,7 +130,7 @@ func (pc *PartitionClient) ReceiveEvents(ctx context.Context, count int, options } for { - amqpMessage, err := lwid.Link.Receive(ctx) + amqpMessage, err := lwid.Link.Receive(ctx, nil) if internal.IsOwnershipLostError(err) { log.Writef(EventConsumer, "(%s) Error, link ownership lost: %s", lwid.String(), err) @@ -215,9 +215,9 @@ func (pc *PartitionClient) newEventHubConsumerLink(ctx context.Context, session } receiverOptions := &amqp.ReceiverOptions{ - SettlementMode: to.Ptr(amqp.ModeFirst), + SettlementMode: to.Ptr(amqp.ReceiverSettleModeFirst), Filters: []amqp.LinkFilter{ - amqp.LinkFilterSelector(pc.offsetExpression), + amqp.NewSelectorFilter(pc.offsetExpression), }, Properties: props, TargetAddress: pc.instanceID, @@ -225,7 +225,7 @@ func (pc *PartitionClient) newEventHubConsumerLink(ctx context.Context, session if pc.prefetch > 0 { log.Writef(EventConsumer, "Enabling prefetch with %d credits", pc.prefetch) - receiverOptions.Credit = uint32(pc.prefetch) + receiverOptions.Credit = pc.prefetch } else if pc.prefetch == 0 { log.Writef(EventConsumer, "Enabling prefetch with %d credits", defaultPrefetchSize) receiverOptions.Credit = defaultPrefetchSize @@ -233,8 +233,7 @@ func (pc *PartitionClient) newEventHubConsumerLink(ctx context.Context, session // prefetch is disabled, enable manual credits and enable // a reasonable default max for the buffer. log.Writef(EventConsumer, "Disabling prefetch") - receiverOptions.ManualCredits = true - receiverOptions.Credit = defaultMaxCreditSize + receiverOptions.Credit = -1 } log.Writef(EventConsumer, "Creating receiver:\n source:%s\n instanceID: %s\n owner level: %d\n offset: %s\n manual: %v\n prefetch: %d", @@ -242,7 +241,7 @@ func (pc *PartitionClient) newEventHubConsumerLink(ctx context.Context, session pc.instanceID, pc.ownerLevel, pc.offsetExpression, - receiverOptions.ManualCredits, + receiverOptions.Credit == -1, pc.prefetch) receiver, err := session.NewReceiver(ctx, entityPath, receiverOptions) diff --git a/sdk/messaging/azeventhubs/partition_client_unit_test.go b/sdk/messaging/azeventhubs/partition_client_unit_test.go index e03eb2d5eb99..cbb570fadd8b 100644 --- a/sdk/messaging/azeventhubs/partition_client_unit_test.go +++ b/sdk/messaging/azeventhubs/partition_client_unit_test.go @@ -39,7 +39,7 @@ func TestUnit_PartitionClient_PrefetchOff(t *testing.T) { require.NotEmpty(t, events) require.Equal(t, []uint32{uint32(3)}, ns.Receiver.IssuedCredit, "Non-prefetch scenarios will issue credit at the time of request") - require.Equal(t, uint32(0), ns.Receiver.ActiveCredits, "All messages should have been received") + require.EqualValues(t, 0, ns.Receiver.ActiveCredits, "All messages should have been received") require.True(t, ns.Receiver.ManualCreditsSetFromOptions) } @@ -83,7 +83,7 @@ func TestUnit_PartitionClient_PrefetchOff_CreditLimits(t *testing.T) { func TestUnit_PartitionClient_PrefetchOffOnlyBackfillsCredits(t *testing.T) { testData := []struct { Name string - Initial uint32 + Initial int32 Issued []uint32 }{ {"Need some more credits", 2, []uint32{uint32(1)}}, @@ -118,7 +118,7 @@ func TestUnit_PartitionClient_PrefetchOffOnlyBackfillsCredits(t *testing.T) { require.NotEmpty(t, events) require.Equal(t, td.Issued, ns.Receiver.IssuedCredit, "Only issue credits to backfill missing credits") - require.Equal(t, uint32(0), ns.Receiver.ActiveCredits, "All messages should have been received") + require.EqualValues(t, 0, ns.Receiver.ActiveCredits, "All messages should have been received") require.True(t, ns.Receiver.ManualCreditsSetFromOptions) }) } @@ -149,10 +149,10 @@ func TestUnit_PartitionClient_PrefetchOn(t *testing.T) { require.NoError(t, err) require.NotEmpty(t, events) - require.Equal(t, td.initialCredits, ns.Receiver.CreditsSetFromOptions, "All messages should have been received") + require.EqualValues(t, td.initialCredits, ns.Receiver.CreditsSetFromOptions, "All messages should have been received") require.Nil(t, ns.Receiver.IssuedCredit, "prefetching doesn't manually issue credits") - require.Equal(t, uint32(td.initialCredits-3), ns.Receiver.ActiveCredits, "All messages should have been received") + require.EqualValues(t, td.initialCredits-3, ns.Receiver.ActiveCredits, "All messages should have been received") } } diff --git a/sdk/messaging/azeventhubs/producer_client.go b/sdk/messaging/azeventhubs/producer_client.go index ec3653fe039e..2e17baddadea 100644 --- a/sdk/messaging/azeventhubs/producer_client.go +++ b/sdk/messaging/azeventhubs/producer_client.go @@ -162,7 +162,7 @@ type SendEventDataBatchOptions struct { // SendEventDataBatch sends an event data batch to Event Hubs. func (pc *ProducerClient) SendEventDataBatch(ctx context.Context, batch *EventDataBatch, options *SendEventDataBatchOptions) error { err := pc.links.Retry(ctx, exported.EventProducer, "SendEventDataBatch", getPartitionID(batch.partitionID), pc.retryOptions, func(ctx context.Context, lwid internal.LinkWithID[amqpwrap.AMQPSenderCloser]) error { - return lwid.Link.Send(ctx, batch.toAMQPMessage()) + return lwid.Link.Send(ctx, batch.toAMQPMessage(), nil) }) return internal.TransformError(err) } @@ -211,9 +211,8 @@ func (pc *ProducerClient) getEntityPath(partitionID string) string { func (pc *ProducerClient) newEventHubProducerLink(ctx context.Context, session amqpwrap.AMQPSession, entityPath string) (amqpwrap.AMQPSenderCloser, error) { sender, err := session.NewSender(ctx, entityPath, &amqp.SenderOptions{ - SettlementMode: to.Ptr(amqp.ModeMixed), - RequestedReceiverSettleMode: to.Ptr(amqp.ModeFirst), - IgnoreDispositionErrors: true, + SettlementMode: to.Ptr(amqp.SenderSettleModeMixed), + RequestedReceiverSettleMode: to.Ptr(amqp.ReceiverSettleModeFirst), }) if err != nil {