From afb80096eae340697e1153d7af9a5a418ba75067 Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Tue, 29 May 2018 15:43:34 -0400 Subject: [PATCH] pubsub: implement synchronous mode for Receive Add a boolean ReceiveSetting called Synchronous, false by default. When true, Receive uses the ordinary Pull RPC instead of StreamingPull, and is careful to ensure that the number of messages in the client never exceeds MaxOutstandingMessages. Fixes #1088. Change-Id: I2ef9d06263d6487c14786e7ed98580e52254cd47 Reviewed-on: https://code-review.googlesource.com/33330 Reviewed-by: kokoro Reviewed-by: Michael Darakananda --- pubsub/flow_controller.go | 12 ++++ pubsub/flow_controller_test.go | 26 +++++-- pubsub/integration_test.go | 109 +++++++++++++++------------- pubsub/iterator.go | 127 ++++++++++++++++++++++----------- pubsub/subscription.go | 77 +++++++++++++++++--- pubsub/subscription_test.go | 34 +++++++++ 6 files changed, 282 insertions(+), 103 deletions(-) diff --git a/pubsub/flow_controller.go b/pubsub/flow_controller.go index 594b7a962f22..a5aeb19f7f90 100644 --- a/pubsub/flow_controller.go +++ b/pubsub/flow_controller.go @@ -15,14 +15,18 @@ package pubsub import ( + "sync/atomic" + "golang.org/x/net/context" "golang.org/x/sync/semaphore" ) // flowController implements flow control for Subscription.Receive. type flowController struct { + maxCount int maxSize int // max total size of messages semCount, semSize *semaphore.Weighted // enforces max number and size of messages + count_ int64 // acquires - releases (atomic) } // newFlowController creates a new flowController that ensures no more than @@ -31,6 +35,7 @@ type flowController struct { // respectively. func newFlowController(maxCount, maxSize int) *flowController { fc := &flowController{ + maxCount: maxCount, maxSize: maxSize, semCount: nil, semSize: nil, @@ -63,6 +68,7 @@ func (f *flowController) acquire(ctx context.Context, size int) error { return err } } + atomic.AddInt64(&f.count_, 1) return nil } @@ -85,11 +91,13 @@ func (f *flowController) tryAcquire(size int) bool { return false } } + atomic.AddInt64(&f.count_, 1) return true } // release notes that one message of size bytes is no longer outstanding. func (f *flowController) release(size int) { + atomic.AddInt64(&f.count_, -1) if f.semCount != nil { f.semCount.Release(1) } @@ -104,3 +112,7 @@ func (f *flowController) bound(size int) int64 { } return int64(size) } + +func (f *flowController) count() int { + return int(atomic.LoadInt64(&f.count_)) +} diff --git a/pubsub/flow_controller_test.go b/pubsub/flow_controller_test.go index 1449803c2bf9..6ca8e22e36e8 100644 --- a/pubsub/flow_controller_test.go +++ b/pubsub/flow_controller_test.go @@ -122,7 +122,8 @@ func TestFlowControllerSaturation(t *testing.T) { } { fc := newFlowController(maxCount, maxSize) // Atomically track flow controller state. - var curCount, curSize int64 + // The flowController itself tracks count. + var curSize int64 success := errors.New("") // Time out if wantSize or wantCount is never reached. 
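The flow_controller.go hunk above boils down to one piece of bookkeeping: count() reports acquires minus releases, tracked atomically. A minimal standalone sketch of just that accounting (a simplified type, not the library's flowController, which also keeps the count and size semaphores):

package main

import (
	"fmt"
	"sync/atomic"
)

// counter captures only the bookkeeping the new count_ field adds: every
// acquire adds one, every release subtracts one, and count reports the
// difference, i.e. the number of messages currently outstanding.
type counter struct{ n int64 }

func (c *counter) acquire()   { atomic.AddInt64(&c.n, 1) }
func (c *counter) release()   { atomic.AddInt64(&c.n, -1) }
func (c *counter) count() int { return int(atomic.LoadInt64(&c.n)) }

func main() {
	var c counter
	c.acquire()
	c.acquire()
	c.release()
	fmt.Println(c.count()) // 1: one message still outstanding
	// When the flow controller is unbounded, nothing stops releases from
	// outnumbering acquires, so the count can go negative; that is the case
	// TestFlowControllerUnboundedCount2 below exercises.
	c.release()
	c.release()
	fmt.Println(c.count()) // -1
}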
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) @@ -143,7 +144,7 @@ func TestFlowControllerSaturation(t *testing.T) { if err := fc.acquire(ctx, test.acquireSize); err != nil { return err } - c := atomic.AddInt64(&curCount, 1) + c := int64(fc.count()) if c > test.wantCount { return fmt.Errorf("count %d exceeds want %d", c, test.wantCount) } @@ -158,9 +159,6 @@ func TestFlowControllerSaturation(t *testing.T) { hitSize = true } time.Sleep(5 * time.Millisecond) // Let other goroutines make progress. - if atomic.AddInt64(&curCount, -1) < 0 { - return errors.New("negative count") - } if atomic.AddInt64(&curSize, -int64(test.acquireSize)) < 0 { return errors.New("negative size") } @@ -217,6 +215,24 @@ func TestFlowControllerUnboundedCount(t *testing.T) { } } +func TestFlowControllerUnboundedCount2(t *testing.T) { + t.Parallel() + ctx := context.Background() + fc := newFlowController(0, 0) + // Successfully acquire 4 bytes. + if err := fc.acquire(ctx, 4); err != nil { + t.Errorf("got %v, wanted no error", err) + } + fc.release(1) + fc.release(1) + fc.release(1) + wantCount := int64(-2) + c := int64(fc.count()) + if c != wantCount { + t.Fatalf("got count %d, want %d", c, wantCount) + } +} + func TestFlowControllerUnboundedBytes(t *testing.T) { t.Parallel() ctx := context.Background() diff --git a/pubsub/integration_test.go b/pubsub/integration_test.go index d673b37515b5..a3bd54ed73f7 100644 --- a/pubsub/integration_test.go +++ b/pubsub/integration_test.go @@ -106,55 +106,11 @@ func TestIntegration_All(t *testing.T) { t.Errorf("subscription %s should exist, but it doesn't", sub.ID()) } - var msgs []*Message - for i := 0; i < 10; i++ { - text := fmt.Sprintf("a message with an index %d", i) - attrs := make(map[string]string) - attrs["foo"] = "bar" - msgs = append(msgs, &Message{ - Data: []byte(text), - Attributes: attrs, - }) - } - - // Publish the messages. - type pubResult struct { - m *Message - r *PublishResult - } - var rs []pubResult - for _, m := range msgs { - r := topic.Publish(ctx, m) - rs = append(rs, pubResult{m, r}) - } - want := make(map[string]*messageData) - for _, res := range rs { - id, err := res.r.Get(ctx) - if err != nil { - t.Fatal(err) + for _, sync := range []bool{false, true} { + for _, maxMsgs := range []int{0, 3, -1} { // MaxOutstandingMessages = default, 3, unlimited + testPublishAndReceive(t, topic, sub, maxMsgs, sync) } - md := extractMessageData(res.m) - md.ID = id - want[md.ID] = md } - - // Use a timeout to ensure that Pull does not block indefinitely if there are unexpectedly few messages available. 
- timeoutCtx, _ := context.WithTimeout(ctx, time.Minute) - gotMsgs, err := pullN(timeoutCtx, sub, len(want), func(ctx context.Context, m *Message) { - m.Ack() - }) - if err != nil { - t.Fatalf("Pull: %v", err) - } - got := make(map[string]*messageData) - for _, m := range gotMsgs { - md := extractMessageData(m) - got[md.ID] = md - } - if !testutil.Equal(got, want) { - t.Errorf("messages: got: %v ; want: %v", got, want) - } - if msg, ok := testIAM(ctx, topic.IAM(), "pubsub.topics.get"); !ok { t.Errorf("topic IAM: %s", msg) } @@ -167,7 +123,8 @@ func TestIntegration_All(t *testing.T) { t.Fatalf("CreateSnapshot error: %v", err) } - timeoutCtx, _ = context.WithTimeout(ctx, time.Minute) + timeoutCtx, cancel := context.WithTimeout(ctx, time.Minute) + defer cancel() err = internal.Retry(timeoutCtx, gax.Backoff{}, func() (bool, error) { snapIt := client.Snapshots(timeoutCtx) for { @@ -221,6 +178,62 @@ func TestIntegration_All(t *testing.T) { } } +func testPublishAndReceive(t *testing.T, topic *Topic, sub *Subscription, maxMsgs int, synchronous bool) { + ctx := context.Background() + var msgs []*Message + for i := 0; i < 10; i++ { + text := fmt.Sprintf("a message with an index %d", i) + attrs := make(map[string]string) + attrs["foo"] = "bar" + msgs = append(msgs, &Message{ + Data: []byte(text), + Attributes: attrs, + }) + } + + // Publish some messages. + type pubResult struct { + m *Message + r *PublishResult + } + var rs []pubResult + for _, m := range msgs { + r := topic.Publish(ctx, m) + rs = append(rs, pubResult{m, r}) + } + want := make(map[string]*messageData) + for _, res := range rs { + id, err := res.r.Get(ctx) + if err != nil { + t.Fatal(err) + } + md := extractMessageData(res.m) + md.ID = id + want[md.ID] = md + } + + sub.ReceiveSettings.MaxOutstandingMessages = maxMsgs + sub.ReceiveSettings.Synchronous = synchronous + // Use a timeout to ensure that Pull does not block indefinitely if there are + // unexpectedly few messages available. + timeoutCtx, _ := context.WithTimeout(ctx, time.Minute) + gotMsgs, err := pullN(timeoutCtx, sub, len(want), func(ctx context.Context, m *Message) { + m.Ack() + }) + if err != nil { + t.Fatalf("Pull: %v", err) + } + got := make(map[string]*messageData) + for _, m := range gotMsgs { + md := extractMessageData(m) + got[md.ID] = md + } + if !testutil.Equal(got, want) { + t.Errorf("MaxOutstandingMessages=%d, Synchronous=%t: messages: got: %v ; want: %v", + maxMsgs, synchronous, got, want) + } +} + // IAM tests. // NOTE: for these to succeed, the test runner identity must have the Pub/Sub Admin or Owner roles. // To set, visit https://console.developers.google.com, select "IAM & Admin" from the top-left diff --git a/pubsub/iterator.go b/pubsub/iterator.go index ca474f69385b..b531ab9422da 100644 --- a/pubsub/iterator.go +++ b/pubsub/iterator.go @@ -31,16 +31,7 @@ import ( // of the actual deadline. const gracePeriod = 5 * time.Second -// newMessageIterator starts a new streamingMessageIterator. Stop must be called on the messageIterator -// when it is no longer needed. -// subName is the full name of the subscription to pull messages from. -// ctx is the context to use for acking messages and extending message deadlines. 
-func newMessageIterator(ctx context.Context, subc *vkit.SubscriberClient, subName string, po *pullOptions) *streamingMessageIterator { - ps := newPullStream(ctx, subc.StreamingPull, subName) - return newStreamingMessageIterator(ctx, ps, po, subc, subName, po.minAckDeadline) -} - -type streamingMessageIterator struct { +type messageIterator struct { ctx context.Context po *pullOptions ps *pullStream @@ -73,7 +64,15 @@ type streamingMessageIterator struct { minAckDeadline time.Duration } -func newStreamingMessageIterator(ctx context.Context, ps *pullStream, po *pullOptions, subc *vkit.SubscriberClient, subName string, minAckDeadline time.Duration) *streamingMessageIterator { +// newMessageIterator starts and returns a new messageIterator. +// ctx is the context to use for acking messages and extending message deadlines. +// subName is the full name of the subscription to pull messages from. +// Stop must be called on the messageIterator when it is no longer needed. +func newMessageIterator(ctx context.Context, subc *vkit.SubscriberClient, subName string, po *pullOptions) *messageIterator { + var ps *pullStream + if !po.synchronous { + ps = newPullStream(ctx, subc.StreamingPull, subName) + } // The period will update each tick based on the distribution of acks. We'll start by arbitrarily sending // the first keepAlive halfway towards the minimum ack deadline. keepAlivePeriod := minAckDeadline / 2 @@ -82,7 +81,7 @@ func newStreamingMessageIterator(ctx context.Context, ps *pullStream, po *pullOp ackTicker := time.NewTicker(100 * time.Millisecond) nackTicker := time.NewTicker(100 * time.Millisecond) pingTicker := time.NewTicker(30 * time.Second) - it := &streamingMessageIterator{ + it := &messageIterator{ ctx: ctx, ps: ps, po: po, @@ -100,7 +99,6 @@ func newStreamingMessageIterator(ctx context.Context, ps *pullStream, po *pullOp pendingAcks: map[string]bool{}, pendingNacks: map[string]bool{}, pendingModAcks: map[string]bool{}, - minAckDeadline: minAckDeadline, } it.wg.Add(1) go it.sender() @@ -111,7 +109,7 @@ func newStreamingMessageIterator(ctx context.Context, ps *pullStream, po *pullOp // Stop will block until Done has been called on all Messages that have been // returned by Next, or until the context with which the messageIterator was created // is cancelled or exceeds its deadline. -func (it *streamingMessageIterator) stop() { +func (it *messageIterator) stop() { it.mu.Lock() select { case <-it.stopped: @@ -127,7 +125,7 @@ func (it *streamingMessageIterator) stop() { // pending messages have either been n/acked or expired. // // Called with the lock held. -func (it *streamingMessageIterator) checkDrained() { +func (it *messageIterator) checkDrained() { select { case <-it.drained: return @@ -143,7 +141,7 @@ func (it *streamingMessageIterator) checkDrained() { } // Called when a message is acked/nacked. -func (it *streamingMessageIterator) done(ackID string, ack bool, receiveTime time.Time) { +func (it *messageIterator) done(ackID string, ack bool, receiveTime time.Time) { it.ackTimeDist.Record(int(time.Since(receiveTime) / time.Second)) it.mu.Lock() defer it.mu.Unlock() @@ -159,7 +157,7 @@ func (it *streamingMessageIterator) done(ackID string, ack bool, receiveTime tim // fail is called when a stream method returns a permanent error. // fail returns it.err. This may be err, or it may be the error // set by an earlier call to fail. 
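The stop/checkDrained logic above follows a small concurrency pattern: close a drained channel only once stop has been requested and no messages remain outstanding, so stop can block until the iterator has fully quiesced. A self-contained sketch of that pattern, with simplified names that are not the library's types:

package main

import (
	"fmt"
	"sync"
)

// tracker is a standalone stand-in for the iterator's stopped/drained
// bookkeeping: drained is closed only after stop has been requested and
// nothing is outstanding.
type tracker struct {
	mu          sync.Mutex
	outstanding int
	stopped     chan struct{}
	drained     chan struct{}
}

func newTracker() *tracker {
	return &tracker{stopped: make(chan struct{}), drained: make(chan struct{})}
}

func (t *tracker) add() {
	t.mu.Lock()
	t.outstanding++
	t.mu.Unlock()
}

// done is the analogue of a message being acked or nacked.
func (t *tracker) done() {
	t.mu.Lock()
	t.outstanding--
	t.checkDrained()
	t.mu.Unlock()
}

// stop requests shutdown and blocks until all outstanding messages are done.
func (t *tracker) stop() {
	t.mu.Lock()
	select {
	case <-t.stopped:
	default:
		close(t.stopped)
	}
	t.checkDrained()
	t.mu.Unlock()
	<-t.drained
}

// checkDrained closes drained once we are stopped with nothing outstanding.
// Must be called with mu held.
func (t *tracker) checkDrained() {
	select {
	case <-t.drained:
		return
	default:
	}
	select {
	case <-t.stopped:
		if t.outstanding == 0 {
			close(t.drained)
		}
	default:
	}
}

func main() {
	t := newTracker()
	t.add()
	t.add()
	go func() {
		t.done()
		t.done()
	}()
	t.stop() // returns only after both messages are done
	fmt.Println("drained")
}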
-func (it *streamingMessageIterator) fail(err error) error { +func (it *messageIterator) fail(err error) error { it.mu.Lock() defer it.mu.Unlock() if it.err == nil { @@ -169,9 +167,10 @@ func (it *streamingMessageIterator) fail(err error) error { return it.err } -// receive makes a call to the stream's Recv method and returns +// receive makes a call to the stream's Recv method, or the Pull RPC, and returns // its messages. -func (it *streamingMessageIterator) receive() ([]*Message, error) { +// maxToPull is the maximum number of messages for the Pull RPC. +func (it *messageIterator) receive(maxToPull int32) ([]*Message, error) { it.mu.Lock() if it.err != nil { return nil, it.err @@ -186,18 +185,23 @@ func (it *streamingMessageIterator) receive() ([]*Message, error) { default: } - // Receive messages from stream. This may block indefinitely. - res, err := it.ps.Recv() - // The pullStream handles retries, so any error here is fatal. + var rmsgs []*pb.ReceivedMessage + var err error + if it.po.synchronous { + rmsgs, err = it.pullMessages(maxToPull) + } else { + rmsgs, err = it.recvMessages() + } + // Any error here is fatal. if err != nil { return nil, it.fail(err) } - msgs, err := convertMessages(res.ReceivedMessages) + msgs, err := convertMessages(rmsgs) if err != nil { return nil, it.fail(err) } // We received some messages. Remember them so we can keep them alive. Also, - // do a receipt mod-ack. + // do a receipt mod-ack when streaming. maxExt := time.Now().Add(it.po.maxExtension) ackIDs := map[string]bool{} it.mu.Lock() @@ -207,28 +211,71 @@ func (it *streamingMessageIterator) receive() ([]*Message, error) { addRecv(m.ID, m.ackID, now) m.doneFunc = it.done it.keepAliveDeadlines[m.ackID] = maxExt - // Don't change the mod-ack if the message is going to be nacked. This is // possible if there are retries. - if !it.pendingNacks[m.ackID] { + if !it.pendingNacks[m.ackID] && !it.po.synchronous { ackIDs[m.ackID] = true } } deadline := it.ackDeadline() it.mu.Unlock() - if !it.sendModAck(ackIDs, deadline) { - return nil, it.err + if len(ackIDs) > 0 { + if !it.sendModAck(ackIDs, deadline) { + return nil, it.err + } } return msgs, nil } +// Get messages using the Pull RPC. +// This may block indefinitely. It may also return zero messages, after some time waiting. +func (it *messageIterator) pullMessages(maxToPull int32) ([]*pb.ReceivedMessage, error) { + // Use a different context so this RPC can be canceled. + cctx, cancel := context.WithCancel(it.ctx) + defer cancel() + go func() { + // Turn a stop into a cancel. + // TODO(jba): replace the stopped channel with a context. + select { + case <-it.stopped: + cancel() + case <-cctx.Done(): + // We end up here when the deferred cancel runs at the end of the function. + } + }() + res, err := it.subc.Pull(cctx, &pb.PullRequest{ + Subscription: it.subName, + MaxMessages: maxToPull, + }) + switch { + case err == context.Canceled: + return nil, nil + case err != nil: + return nil, err + default: + return res.ReceivedMessages, nil + } +} + +func (it *messageIterator) recvMessages() ([]*pb.ReceivedMessage, error) { + res, err := it.ps.Recv() + if err != nil { + return nil, err + } + return res.ReceivedMessages, nil +} + // sender runs in a goroutine and handles all sends to the stream. 
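pullMessages above turns a stop request into cancellation of the in-flight Pull RPC by deriving a separate context and canceling it from a helper goroutine. A standalone sketch of that pattern; the stop channel and the simulated blocking call are stand-ins, not the library's code:

package main

import (
	"context"
	"fmt"
	"time"
)

// contextForStop derives a context that is canceled when either the parent is
// canceled or the stop channel is closed, mirroring how pullMessages turns a
// stop signal into RPC cancellation.
func contextForStop(parent context.Context, stop <-chan struct{}) (context.Context, context.CancelFunc) {
	ctx, cancel := context.WithCancel(parent)
	go func() {
		select {
		case <-stop:
			cancel()
		case <-ctx.Done():
			// The caller's deferred cancel ran; nothing more to do.
		}
	}()
	return ctx, cancel
}

func main() {
	stop := make(chan struct{})
	ctx, cancel := contextForStop(context.Background(), stop)
	defer cancel()

	// Simulate a long-blocking call (e.g. a unary Pull with no messages available)
	// being interrupted by a stop request.
	go func() {
		time.Sleep(50 * time.Millisecond)
		close(stop)
	}()

	select {
	case <-ctx.Done():
		fmt.Println("call canceled:", ctx.Err()) // context canceled
	case <-time.After(time.Second):
		fmt.Println("call completed")
	}
}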
-func (it *streamingMessageIterator) sender() { +func (it *messageIterator) sender() { defer it.wg.Done() defer it.ackTicker.Stop() defer it.nackTicker.Stop() defer it.pingTicker.Stop() - defer it.ps.CloseSend() + defer func() { + if it.ps != nil { + it.ps.CloseSend() + } + }() done := false for !done { @@ -281,8 +328,8 @@ func (it *streamingMessageIterator) sender() { case <-it.pingTicker.C: it.mu.Lock() - // Ping only if we are processing messages. - sendPing = (len(it.keepAliveDeadlines) > 0) + // Ping only if we are processing messages via streaming. + sendPing = !it.po.synchronous && (len(it.keepAliveDeadlines) > 0) } // Lock is held here. var acks, nacks, modAcks map[string]bool @@ -326,7 +373,7 @@ func (it *streamingMessageIterator) sender() { // for live messages. It also purges expired messages. // // Called with the lock held. -func (it *streamingMessageIterator) handleKeepAlives() { +func (it *messageIterator) handleKeepAlives() { now := time.Now() for id, expiry := range it.keepAliveDeadlines { if expiry.Before(now) { @@ -343,7 +390,7 @@ func (it *streamingMessageIterator) handleKeepAlives() { it.checkDrained() } -func (it *streamingMessageIterator) sendAck(m map[string]bool) bool { +func (it *messageIterator) sendAck(m map[string]bool) bool { return it.sendAckIDRPC(m, func(ids []string) error { addAcks(ids) return it.subc.Acknowledge(it.ctx, &pb.AcknowledgeRequest{ @@ -357,7 +404,7 @@ func (it *streamingMessageIterator) sendAck(m map[string]bool) bool { // on the time it takes to process messages. The percentile chosen is the 99%th // percentile in order to capture the highest amount of time necessary without // considering 1% outliers. -func (it *streamingMessageIterator) sendModAck(m map[string]bool, deadline time.Duration) bool { +func (it *messageIterator) sendModAck(m map[string]bool, deadline time.Duration) bool { return it.sendAckIDRPC(m, func(ids []string) error { addModAcks(ids, int32(deadline/time.Second)) return it.subc.ModifyAckDeadline(it.ctx, &pb.ModifyAckDeadlineRequest{ @@ -368,7 +415,7 @@ func (it *streamingMessageIterator) sendModAck(m map[string]bool, deadline time. }) } -func (it *streamingMessageIterator) sendAckIDRPC(ackIDSet map[string]bool, call func([]string) error) bool { +func (it *messageIterator) sendAckIDRPC(ackIDSet map[string]bool, call func([]string) error) bool { ackIDs := make([]string, 0, len(ackIDSet)) for k := range ackIDSet { ackIDs = append(ackIDs, k) @@ -392,7 +439,7 @@ func (it *streamingMessageIterator) sendAckIDRPC(ackIDSet map[string]bool, call // network. This matters if it takes a long time to process messages relative to the // default ack deadline, and if the messages are small enough so that many can fit // into the buffer. -func (it *streamingMessageIterator) pingStream() { +func (it *messageIterator) pingStream() { // Ignore error; if the stream is broken, this doesn't matter anyway. _ = it.ps.Send(&pb.StreamingPullRequest{}) } @@ -416,14 +463,14 @@ func splitRequestIDs(ids []string, maxSize int) (prefix, remainder []string) { // times should be safe. The highest 1% may expire. This number was chosen // as a way to cover most users' usecases without losing the value of // expiration. 
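The ackDeadline comment above describes taking the 99th percentile of observed processing times and clamping it between the minimum and maximum ack deadlines. A rough standalone sketch with made-up sample values; the real iterator keeps a streaming distribution rather than a sorted slice, and the 10-minute cap here is only assumed for illustration:

package main

import (
	"fmt"
	"sort"
	"time"
)

// ackDeadlineFor sketches the deadline choice: take roughly the 99th
// percentile of processing times and clamp it to [minD, maxD].
func ackDeadlineFor(samples []time.Duration, minD, maxD time.Duration) time.Duration {
	sorted := append([]time.Duration(nil), samples...)
	sort.Slice(sorted, func(i, j int) bool { return sorted[i] < sorted[j] })
	p99 := sorted[(len(sorted)*99)/100]
	if p99 > maxD {
		return maxD
	}
	if p99 < minD {
		return minD
	}
	return p99
}

func main() {
	fast := []time.Duration{1 * time.Second, 2 * time.Second, 2 * time.Second, 3 * time.Second}
	fmt.Println(ackDeadlineFor(fast, 10*time.Second, 10*time.Minute)) // 10s: clamped up to the minimum
	slow := []time.Duration{20 * time.Second, 30 * time.Second, 45 * time.Second, 60 * time.Second}
	fmt.Println(ackDeadlineFor(slow, 10*time.Second, 10*time.Minute)) // 1m0s: the percentile itself
}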
-func (it *streamingMessageIterator) ackDeadline() time.Duration { +func (it *messageIterator) ackDeadline() time.Duration { pt := time.Duration(it.ackTimeDist.Percentile(.99)) * time.Second if pt > maxAckDeadline { return maxAckDeadline } - if pt < it.minAckDeadline { - return it.minAckDeadline + if pt < minAckDeadline { + return minAckDeadline } return pt } diff --git a/pubsub/subscription.go b/pubsub/subscription.go index 4e4f21af5078..6e9350c534e8 100644 --- a/pubsub/subscription.go +++ b/pubsub/subscription.go @@ -26,6 +26,7 @@ import ( "cloud.google.com/go/internal/optional" "github.com/golang/protobuf/ptypes" durpb "github.com/golang/protobuf/ptypes/duration" + gax "github.com/googleapis/gax-go" "golang.org/x/net/context" "golang.org/x/sync/errgroup" pb "google.golang.org/genproto/googleapis/pubsub/v1" @@ -229,8 +230,35 @@ type ReceiveSettings struct { // function passed to Receive on them. To limit the number of messages being // processed concurrently, set MaxOutstandingMessages. NumGoroutines int + + // If Synchronous is true, then no more than MaxOutstandingMessages will be in + // memory at one time. (In contrast, when Synchronous is false, more than + // MaxOutstandingMessages may have been received from the service and in memory + // before being processed.) MaxOutstandingBytes still refers to the total bytes + // processed, rather than in memory. NumGoroutines is ignored. + // The default is false. + Synchronous bool } +// For synchronous receive, the time to wait if we are already processing +// MaxOutstandingMessages. There is no point calling Pull and asking for zero +// messages, so we pause to allow some message-processing callbacks to finish. +// +// The wait time is large enough to avoid consuming significant CPU, but +// small enough to provide decent throughput. Users who want better +// throughput should not be using synchronous mode. +// +// Waiting might seem like polling, so it's natural to think we could do better by +// noticing when a callback is finished and immediately calling Pull. But if +// callbacks finish in quick succession, this will result in frequent Pull RPCs that +// request a single message, which wastes network bandwidth. Better to wait for a few +// callbacks to finish, so we make fewer RPCs fetching more messages. +// +// This value is unexported so the user doesn't have another knob to think about. Note that +// it is the same value as the one used for nackTicker, so it matches this client's +// idea of a duration that is short, but not so short that we perform excessive RPCs. +const synchronousWaitTime = 100 * time.Millisecond + // This is a var so that tests can change it. var minAckDeadline = 10 * time.Second @@ -436,15 +464,20 @@ func (s *Subscription) Receive(ctx context.Context, f func(context.Context, *Mes // If MaxExtension is negative, disable automatic extension. maxExt = 0 } - numGoroutines := s.ReceiveSettings.NumGoroutines - if numGoroutines < 1 { + var numGoroutines int + switch { + case s.ReceiveSettings.Synchronous: + numGoroutines = 1 + case s.ReceiveSettings.NumGoroutines >= 1: + numGoroutines = s.ReceiveSettings.NumGoroutines + default: numGoroutines = DefaultReceiveSettings.NumGoroutines } // TODO(jba): add tests that verify that ReceiveSettings are correctly processed. 
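For reference, a sketch of how a caller might opt into the new setting once this change lands; the project and subscription IDs are placeholders:

package main

import (
	"context"
	"fmt"
	"log"

	"cloud.google.com/go/pubsub"
)

func main() {
	ctx := context.Background()
	// "my-project" and "my-sub" are placeholders.
	client, err := pubsub.NewClient(ctx, "my-project")
	if err != nil {
		log.Fatal(err)
	}
	defer client.Close()

	sub := client.Subscription("my-sub")
	// With Synchronous set, Receive uses the unary Pull RPC and keeps at most
	// MaxOutstandingMessages messages in memory at a time.
	sub.ReceiveSettings.Synchronous = true
	sub.ReceiveSettings.MaxOutstandingMessages = 10

	err = sub.Receive(ctx, func(ctx context.Context, m *pubsub.Message) {
		fmt.Printf("got message: %s\n", m.Data)
		m.Ack()
	})
	if err != nil {
		log.Fatal(err)
	}
}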
po := &pullOptions{ - minAckDeadline: minAckDeadline, - maxExtension: maxExt, - maxPrefetch: trunc32(int64(maxCount)), + maxExtension: maxExt, + maxPrefetch: trunc32(int64(maxCount)), + synchronous: s.ReceiveSettings.Synchronous, } fc := newFlowController(maxCount, maxBytes) @@ -487,7 +520,29 @@ func (s *Subscription) receive(ctx context.Context, po *pullOptions, fc *flowCon defer cancel() for { - msgs, err := iter.receive() + var maxToPull int32 // maximum number of messages to pull + if po.synchronous { + if po.maxPrefetch < 0 { + // If there is no limit on the number of messages to pull, use a reasonable default. + maxToPull = 1000 + } else { + // Limit the number of messages in memory to MaxOutstandingMessages + // (here, po.maxPrefetch). For each message currently in memory, we have + // called fc.acquire but not fc.release: this is fc.count(). The next + // call to Pull should fetch no more than the difference between these + // values. + maxToPull = po.maxPrefetch - int32(fc.count()) + if maxToPull <= 0 { + // Wait for some callbacks to finish. + if err := gax.Sleep(ctx, synchronousWaitTime); err != nil { + // Return nil if the context is done, not err. + return nil + } + continue + } + } + } + msgs, err := iter.receive(maxToPull) if err == io.EOF { return nil } @@ -502,6 +557,7 @@ func (s *Subscription) receive(ctx context.Context, po *pullOptions, fc *flowCon for _, m := range msgs[i:] { m.Nack() } + // Return nil if the context is done, not err. return nil } old := msg.doneFunc @@ -519,9 +575,10 @@ func (s *Subscription) receive(ctx context.Context, po *pullOptions, fc *flowCon } } -// TODO(jba): remove when we delete messageIterator. type pullOptions struct { - minAckDeadline time.Duration - maxExtension time.Duration - maxPrefetch int32 + maxExtension time.Duration + maxPrefetch int32 + // If true, use unary Pull instead of StreamingPull, and never pull more + // than maxPrefetch messages. + synchronous bool } diff --git a/pubsub/subscription_test.go b/pubsub/subscription_test.go index 122c70832b90..d76fd8619bf0 100644 --- a/pubsub/subscription_test.go +++ b/pubsub/subscription_test.go @@ -178,6 +178,40 @@ func TestUpdateSubscription(t *testing.T) { } } +func TestReceive(t *testing.T) { + testReceive(t, true) + testReceive(t, false) +} + +func testReceive(t *testing.T, synchronous bool) { + ctx := context.Background() + client, srv := newFake(t) + defer client.Close() + + topic := mustCreateTopic(t, client, "t") + sub, err := client.CreateSubscription(ctx, "s", SubscriptionConfig{Topic: topic}) + if err != nil { + t.Fatal(err) + } + for i := 0; i < 256; i++ { + srv.Publish(topic.name, []byte{byte(i)}, nil) + } + sub.ReceiveSettings.Synchronous = synchronous + msgs, err := pullN(ctx, sub, 256, func(_ context.Context, m *Message) { m.Ack() }) + if err != nil { + t.Fatal(err) + } + var seen [256]bool + for _, m := range msgs { + seen[m.Data[0]] = true + } + for i, saw := range seen { + if !saw { + t.Errorf("sync=%t: did not see message #%d", synchronous, i) + } + } +} + func (t1 *Topic) Equal(t2 *Topic) bool { if t1 == nil && t2 == nil { return true
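The maxToPull logic in receive above reduces to a small calculation: ask for MaxOutstandingMessages minus the number of messages currently held (fc.count()), and if that difference is not positive, wait roughly synchronousWaitTime and try again. A simplified sketch of just that decision; the helper name and values are illustrative, and the real code also substitutes a default of 1000 when there is no limit:

package main

import "fmt"

// nextPullSize mirrors the synchronous-mode accounting: never ask the server
// for more messages than the limit minus the number already in memory; a
// non-positive result means "pause (about 100ms) and retry" instead of
// issuing a Pull for zero messages.
func nextPullSize(maxOutstanding, inFlight int) (n int, wait bool) {
	n = maxOutstanding - inFlight
	if n <= 0 {
		return 0, true
	}
	return n, false
}

func main() {
	fmt.Println(nextPullSize(10, 3))  // 7 false: pull up to 7 more messages
	fmt.Println(nextPullSize(10, 10)) // 0 true: at the limit, wait and retry
}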