From 895e54285c00c18f3ea67a1e3eda6e51a0942658 Mon Sep 17 00:00:00 2001 From: Piotr Piotrowski Date: Mon, 25 Sep 2023 12:43:14 +0200 Subject: [PATCH] [FIXED] Fix pull heartbeat validation Signed-off-by: Piotr Piotrowski --- js.go | 22 +++++++++---- test/js_test.go | 87 ++++++++++++++++++++++++++++++++++++++++++++++--- 2 files changed, 99 insertions(+), 10 deletions(-) diff --git a/js.go b/js.go index b31839bf2..7fdb0131c 100644 --- a/js.go +++ b/js.go @@ -2756,9 +2756,6 @@ func (sub *Subscription) Fetch(batch int, opts ...PullOpt) ([]*Msg, error) { ttl = js.opts.wait } sub.mu.Unlock() - if o.hb != 0 && 2*o.hb >= ttl { - return nil, fmt.Errorf("%w: idle heartbeat value too large", ErrInvalidArg) - } // Use the given context or setup a default one for the span // of the pull batch request. @@ -2784,6 +2781,14 @@ func (sub *Subscription) Fetch(batch int, opts ...PullOpt) ([]*Msg, error) { } defer cancel() + // if heartbeat is set, validate it against the context timeout + if o.hb > 0 { + deadline, _ := ctx.Deadline() + if 2*o.hb >= time.Until(deadline) { + return nil, fmt.Errorf("%w: idle heartbeat value too large", ErrInvalidArg) + } + } + // Check if context not done already before making the request. select { case <-ctx.Done(): @@ -3017,9 +3022,6 @@ func (sub *Subscription) FetchBatch(batch int, opts ...PullOpt) (MessageBatch, e ttl = js.opts.wait } sub.mu.Unlock() - if o.hb != 0 && 2*o.hb >= ttl { - return nil, fmt.Errorf("%w: idle heartbeat value too large", ErrInvalidArg) - } // Use the given context or setup a default one for the span // of the pull batch request. @@ -3050,6 +3052,14 @@ func (sub *Subscription) FetchBatch(batch int, opts ...PullOpt) (MessageBatch, e } }() + // if heartbeat is set, validate it against the context timeout + if o.hb > 0 { + deadline, _ := ctx.Deadline() + if 2*o.hb >= time.Until(deadline) { + return nil, fmt.Errorf("%w: idle heartbeat value too large", ErrInvalidArg) + } + } + // Check if context not done already before making the request. select { case <-ctx.Done(): diff --git a/test/js_test.go b/test/js_test.go index 83e7e3d2a..aad1c5bc8 100644 --- a/test/js_test.go +++ b/test/js_test.go @@ -1143,6 +1143,7 @@ func TestPullSubscribeFetchWithHeartbeat(t *testing.T) { if err != nil { t.Fatalf("Unexpected error: %s", err) } + defer sub.Unsubscribe() for i := 0; i < 5; i++ { if _, err := js.Publish("foo", []byte("msg")); err != nil { t.Fatalf("Unexpected error: %s", err) @@ -1184,13 +1185,51 @@ func TestPullSubscribeFetchWithHeartbeat(t *testing.T) { // heartbeat value too large _, err = sub.Fetch(5, nats.PullHeartbeat(200*time.Millisecond), nats.MaxWait(300*time.Millisecond)) if !errors.Is(err, nats.ErrInvalidArg) { - t.Fatalf("Expected no heartbeat error; got: %v", err) + t.Fatalf("Expected invalid arg error; got: %v", err) } // heartbeat value invalid _, err = sub.Fetch(5, nats.PullHeartbeat(-1)) if !errors.Is(err, nats.ErrInvalidArg) { - t.Fatalf("Expected no heartbeat error; got: %v", err) + t.Fatalf("Expected invalid arg error; got: %v", err) + } + + // set short timeout on JetStream context + js, err = nc.JetStream(nats.MaxWait(100 * time.Millisecond)) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + sub1, err := js.PullSubscribe("foo", "") + if err != nil { + t.Fatalf("Unexpected error: %s", err) + } + defer sub.Unsubscribe() + + // should produce invalid arg error based on default timeout from JetStream context + _, err = sub1.Fetch(5, nats.PullHeartbeat(100*time.Millisecond)) + if !errors.Is(err, nats.ErrInvalidArg) { + t.Fatalf("Expected invalid arg error; got: %v", err) + } + + // overwrite default timeout with context timeout, fetch available messages + ctx, cancel := context.WithTimeout(context.Background(), 300*time.Millisecond) + defer cancel() + msgs, err = sub1.Fetch(10, nats.PullHeartbeat(100*time.Millisecond), nats.Context(ctx)) + if err != nil { + t.Fatalf("Unexpected error: %s", err) + } + if len(msgs) != 5 { + t.Fatalf("Expected %d messages; got: %d", 5, len(msgs)) + } + for _, msg := range msgs { + msg.Ack() + } + + // overwrite default timeout with max wait, should time out because no messages are available + _, err = sub1.Fetch(5, nats.PullHeartbeat(100*time.Millisecond), nats.MaxWait(300*time.Millisecond)) + if !errors.Is(err, nats.ErrTimeout) { + t.Fatalf("Expected timeout error; got: %v", err) } } @@ -1213,6 +1252,7 @@ func TestPullSubscribeFetchBatchWithHeartbeat(t *testing.T) { if err != nil { t.Fatalf("Unexpected error: %s", err) } + defer sub.Unsubscribe() for i := 0; i < 5; i++ { if _, err := js.Publish("foo", []byte("msg")); err != nil { t.Fatalf("Unexpected error: %s", err) @@ -1282,16 +1322,55 @@ func TestPullSubscribeFetchBatchWithHeartbeat(t *testing.T) { } // heartbeat value too large - _, err = sub.Fetch(5, nats.PullHeartbeat(200*time.Millisecond), nats.MaxWait(300*time.Millisecond)) + _, err = sub.FetchBatch(5, nats.PullHeartbeat(200*time.Millisecond), nats.MaxWait(300*time.Millisecond)) if !errors.Is(err, nats.ErrInvalidArg) { t.Fatalf("Expected no heartbeat error; got: %v", err) } // heartbeat value invalid - _, err = sub.Fetch(5, nats.PullHeartbeat(-1)) + _, err = sub.FetchBatch(5, nats.PullHeartbeat(-1)) if !errors.Is(err, nats.ErrInvalidArg) { t.Fatalf("Expected no heartbeat error; got: %v", err) } + + // set short timeout on JetStream context + js, err = nc.JetStream(nats.MaxWait(100 * time.Millisecond)) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + sub1, err := js.PullSubscribe("foo", "") + if err != nil { + t.Fatalf("Unexpected error: %s", err) + } + defer sub.Unsubscribe() + + // should produce invalid arg error based on default timeout from JetStream context + _, err = sub1.Fetch(5, nats.PullHeartbeat(100*time.Millisecond)) + if !errors.Is(err, nats.ErrInvalidArg) { + t.Fatalf("Expected invalid arg error; got: %v", err) + } + + // overwrite default timeout with context timeout, fetch available messages + ctx, cancel := context.WithTimeout(context.Background(), 300*time.Millisecond) + defer cancel() + msgs, err = sub1.FetchBatch(10, nats.PullHeartbeat(100*time.Millisecond), nats.Context(ctx)) + if err != nil { + t.Fatalf("Unexpected error: %s", err) + } + for msg := range msgs.Messages() { + msg.Ack() + } + + // overwrite default timeout with max wait, should time out because no messages are available + msgs, err = sub1.FetchBatch(5, nats.PullHeartbeat(100*time.Millisecond), nats.MaxWait(300*time.Millisecond)) + if err != nil { + t.Fatalf("Unexpected error: %s", err) + } + <-msgs.Done() + if msgs.Error() != nil { + t.Fatalf("Unexpected error: %s", msgs.Error()) + } } func TestPullSubscribeFetchBatch(t *testing.T) {