From 715851791ce072765cd135abda49509de0797b33 Mon Sep 17 00:00:00 2001 From: Piotr Piotrowski Date: Tue, 3 Oct 2023 19:54:26 +0200 Subject: [PATCH] [FIXED] Invalid handling of heartbeats in Consume and Messages Signed-off-by: Piotr Piotrowski --- jetstream/options.go | 22 +-- jetstream/pull.go | 25 ++- jetstream/test/helper_test.go | 8 + jetstream/test/pull_test.go | 285 ++++++++++++++++++++++------------ 4 files changed, 222 insertions(+), 118 deletions(-) diff --git a/jetstream/options.go b/jetstream/options.go index 566f0dc2e..b47268e99 100644 --- a/jetstream/options.go +++ b/jetstream/options.go @@ -121,18 +121,20 @@ func (max PullMaxMessages) configureMessages(opts *consumeOpts) error { type PullExpiry time.Duration func (exp PullExpiry) configureConsume(opts *consumeOpts) error { - if exp < 0 { - return fmt.Errorf("%w: expires value must be positive", ErrInvalidOption) + expiry := time.Duration(exp) + if expiry < 1*time.Second { + return fmt.Errorf("%w: expires value must be at least 1s", ErrInvalidOption) } - opts.Expires = time.Duration(exp) + opts.Expires = expiry return nil } func (exp PullExpiry) configureMessages(opts *consumeOpts) error { - if exp < 0 { - return fmt.Errorf("%w: expires value must be positive", ErrInvalidOption) + expiry := time.Duration(exp) + if expiry < 0 { + return fmt.Errorf("%w: expires value must be at least 1s", ErrInvalidOption) } - opts.Expires = time.Duration(exp) + opts.Expires = expiry return nil } @@ -191,8 +193,8 @@ type PullHeartbeat time.Duration func (hb PullHeartbeat) configureConsume(opts *consumeOpts) error { hbTime := time.Duration(hb) - if hbTime < 1*time.Second || hbTime > 30*time.Second { - return fmt.Errorf("%w: idle_heartbeat value must be within 1s-30s range", ErrInvalidOption) + if hbTime < 500*time.Millisecond || hbTime > 30*time.Second { + return fmt.Errorf("%w: idle_heartbeat value must be within 500ms-30s range", ErrInvalidOption) } opts.Heartbeat = hbTime return nil @@ -200,8 +202,8 @@ func (hb PullHeartbeat) configureConsume(opts *consumeOpts) error { func (hb PullHeartbeat) configureMessages(opts *consumeOpts) error { hbTime := time.Duration(hb) - if hbTime < 1*time.Second || hbTime > 30*time.Second { - return fmt.Errorf("%w: idle_heartbeat value must be within 1s-30s range", ErrInvalidOption) + if hbTime < 500*time.Millisecond || hbTime > 30*time.Second { + return fmt.Errorf("%w: idle_heartbeat value must be within 500ms-30s range", ErrInvalidOption) } opts.Heartbeat = hbTime return nil diff --git a/jetstream/pull.go b/jetstream/pull.go index 7029bb05c..a3601ece8 100644 --- a/jetstream/pull.go +++ b/jetstream/pull.go @@ -132,7 +132,6 @@ type ( const ( DefaultMaxMessages = 500 DefaultExpires = 30 * time.Second - DefaultHeartbeat = 5 * time.Second unset = -1 ) @@ -192,12 +191,17 @@ func (p *pullConsumer) Consume(handler MessageHandler, opts ...PullConsumeOpt) ( } userMsg, msgErr := checkMsg(msg) if !userMsg && msgErr == nil { + if sub.hbMonitor != nil { + sub.hbMonitor.Reset(2 * consumeOpts.Heartbeat) + } return } defer func() { sub.Lock() sub.checkPending() - sub.hbMonitor.Reset(2 * consumeOpts.Heartbeat) + if sub.hbMonitor != nil { + sub.hbMonitor.Reset(2 * consumeOpts.Heartbeat) + } sub.Unlock() }() if !userMsg { @@ -305,6 +309,9 @@ func (p *pullConsumer) Consume(handler MessageHandler, opts ...PullConsumeOpt) ( MaxBytes: sub.consumeOpts.MaxBytes, Heartbeat: sub.consumeOpts.Heartbeat, } + if sub.hbMonitor != nil { + sub.hbMonitor.Reset(2 * sub.consumeOpts.Heartbeat) + } sub.resetPendingMsgs() } sub.Unlock() @@ -325,6 +332,9 @@ func (p *pullConsumer) Consume(handler MessageHandler, opts ...PullConsumeOpt) ( MaxBytes: sub.consumeOpts.MaxBytes, Heartbeat: sub.consumeOpts.Heartbeat, } + if sub.hbMonitor != nil { + sub.hbMonitor.Reset(2 * sub.consumeOpts.Heartbeat) + } sub.resetPendingMsgs() } sub.Unlock() @@ -465,7 +475,7 @@ func (s *pullSubscription) Next() (Msg, error) { if atomic.LoadUint32(&s.closed) == 1 { return nil, ErrMsgIteratorClosed } - hbMonitor := s.scheduleHeartbeatCheck(s.consumeOpts.Heartbeat) + hbMonitor := s.scheduleHeartbeatCheck(2 * s.consumeOpts.Heartbeat) defer func() { if hbMonitor != nil { hbMonitor.Stop() @@ -509,6 +519,9 @@ func (s *pullSubscription) Next() (Msg, error) { if s.consumeOpts.ReportMissingHeartbeats { return nil, err } + if hbMonitor != nil { + hbMonitor.Reset(2 * s.consumeOpts.Heartbeat) + } } if errors.Is(err, errConnected) { if !isConnected { @@ -531,12 +544,14 @@ func (s *pullSubscription) Next() (Msg, error) { } s.pending.msgCount = 0 s.pending.byteCount = 0 - hbMonitor = s.scheduleHeartbeatCheck(s.consumeOpts.Heartbeat) + if hbMonitor != nil { + hbMonitor.Reset(2 * s.consumeOpts.Heartbeat) + } } } if errors.Is(err, errDisconnected) { if hbMonitor != nil { - hbMonitor.Stop() + hbMonitor.Reset(2 * s.consumeOpts.Heartbeat) } isConnected = false } diff --git a/jetstream/test/helper_test.go b/jetstream/test/helper_test.go index ce7909dbe..a9dbae222 100644 --- a/jetstream/test/helper_test.go +++ b/jetstream/test/helper_test.go @@ -39,6 +39,14 @@ type jsServer struct { restart sync.Mutex } +// Restart can be used to start again a server +// using the same listen address as before. +func (srv *jsServer) Restart() { + srv.restart.Lock() + defer srv.restart.Unlock() + srv.Server = natsserver.RunServer(srv.myopts) +} + // Dumb wait program to sync on callbacks, etc... Will timeout func Wait(ch chan bool) error { return WaitTime(ch, 5*time.Second) diff --git a/jetstream/test/pull_test.go b/jetstream/test/pull_test.go index 981a44fc5..20354128c 100644 --- a/jetstream/test/pull_test.go +++ b/jetstream/test/pull_test.go @@ -2045,141 +2045,220 @@ func TestPullConsumerConsume(t *testing.T) { func TestPullConsumerConsume_WithCluster(t *testing.T) { testSubject := "FOO.123" testMsgs := []string{"m1", "m2", "m3", "m4", "m5"} - publishTestMsgs := func(t *testing.T, nc *nats.Conn) { + publishTestMsgs := func(t *testing.T, js jetstream.JetStream) { for _, msg := range testMsgs { - if err := nc.Publish(testSubject, []byte(msg)); err != nil { + if _, err := js.Publish(context.Background(), testSubject, []byte(msg)); err != nil { t.Fatalf("Unexpected error during publish: %s", err) } } } name := "cluster" - stream := jetstream.StreamConfig{ + singleStream := jetstream.StreamConfig{ Name: name, Replicas: 1, Subjects: []string{"FOO.*"}, } - t.Run("no options", func(t *testing.T) { - withJSClusterAndStream(t, name, 3, stream, func(t *testing.T, subject string, srvs ...*jsServer) { - nc, err := nats.Connect(srvs[0].ClientURL()) - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } + streamWithReplicas := jetstream.StreamConfig{ + Name: name, + Replicas: 3, + Subjects: []string{"FOO.*"}, + } - js, err := jetstream.New(nc) - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } - defer nc.Close() + for _, stream := range []jetstream.StreamConfig{singleStream, streamWithReplicas} { + t.Run(fmt.Sprintf("num replicas: %d, no options", stream.Replicas), func(t *testing.T) { + withJSClusterAndStream(t, name, 3, stream, func(t *testing.T, subject string, srvs ...*jsServer) { + nc, err := nats.Connect(srvs[0].ClientURL()) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - s, err := js.Stream(ctx, stream.Name) - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } - c, err := s.CreateOrUpdateConsumer(ctx, jetstream.ConsumerConfig{AckPolicy: jetstream.AckExplicitPolicy}) - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } + js, err := jetstream.New(nc) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + defer nc.Close() - msgs := make([]jetstream.Msg, 0) - wg := &sync.WaitGroup{} - wg.Add(len(testMsgs)) - l, err := c.Consume(func(msg jetstream.Msg) { - msgs = append(msgs, msg) - wg.Done() - }) - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } - defer l.Stop() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + s, err := js.Stream(ctx, stream.Name) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + c, err := s.CreateOrUpdateConsumer(ctx, jetstream.ConsumerConfig{AckPolicy: jetstream.AckExplicitPolicy}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } - publishTestMsgs(t, nc) - wg.Wait() - if len(msgs) != len(testMsgs) { - t.Fatalf("Unexpected received message count; want %d; got %d", len(testMsgs), len(msgs)) - } - for i, msg := range msgs { - if string(msg.Data()) != testMsgs[i] { - t.Fatalf("Invalid msg on index %d; expected: %s; got: %s", i, testMsgs[i], string(msg.Data())) + msgs := make([]jetstream.Msg, 0) + wg := &sync.WaitGroup{} + wg.Add(len(testMsgs)) + l, err := c.Consume(func(msg jetstream.Msg) { + msgs = append(msgs, msg) + wg.Done() + }) + if err != nil { + t.Fatalf("Unexpected error: %v", err) } - } + defer l.Stop() + + publishTestMsgs(t, js) + wg.Wait() + if len(msgs) != len(testMsgs) { + t.Fatalf("Unexpected received message count; want %d; got %d", len(testMsgs), len(msgs)) + } + for i, msg := range msgs { + if string(msg.Data()) != testMsgs[i] { + t.Fatalf("Invalid msg on index %d; expected: %s; got: %s", i, testMsgs[i], string(msg.Data())) + } + } + }) }) - }) - t.Run("subscribe, cancel subscription, then subscribe again", func(t *testing.T) { - withJSClusterAndStream(t, name, 3, stream, func(t *testing.T, subject string, srvs ...*jsServer) { - nc, err := nats.Connect(srvs[0].ClientURL()) - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } + t.Run(fmt.Sprintf("num replicas: %d, subscribe, cancel subscription, then subscribe again", stream.Replicas), func(t *testing.T) { + withJSClusterAndStream(t, name, 3, stream, func(t *testing.T, subject string, srvs ...*jsServer) { + nc, err := nats.Connect(srvs[0].ClientURL()) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } - js, err := jetstream.New(nc) - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } - defer nc.Close() + js, err := jetstream.New(nc) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + defer nc.Close() - ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) - s, err := js.Stream(ctx, stream.Name) - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } - c, err := s.CreateOrUpdateConsumer(ctx, jetstream.ConsumerConfig{AckPolicy: jetstream.AckExplicitPolicy}) - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) + defer cancel() + s, err := js.Stream(ctx, stream.Name) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + c, err := s.CreateOrUpdateConsumer(ctx, jetstream.ConsumerConfig{AckPolicy: jetstream.AckExplicitPolicy}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } - wg := sync.WaitGroup{} - wg.Add(len(testMsgs)) - msgs := make([]jetstream.Msg, 0) - l, err := c.Consume(func(msg jetstream.Msg) { - if err := msg.Ack(); err != nil { + wg := sync.WaitGroup{} + wg.Add(len(testMsgs)) + msgs := make([]jetstream.Msg, 0) + l, err := c.Consume(func(msg jetstream.Msg) { + if err := msg.Ack(); err != nil { + t.Fatalf("Unexpected error: %v", err) + } + msgs = append(msgs, msg) + if len(msgs) == 5 { + cancel() + } + wg.Done() + }) + if err != nil { t.Fatalf("Unexpected error: %v", err) } - msgs = append(msgs, msg) - if len(msgs) == 5 { - cancel() + + publishTestMsgs(t, js) + wg.Wait() + l.Stop() + + time.Sleep(10 * time.Millisecond) + wg.Add(len(testMsgs)) + l, err = c.Consume(func(msg jetstream.Msg) { + if err := msg.Ack(); err != nil { + t.Fatalf("Unexpected error: %v", err) + } + msgs = append(msgs, msg) + wg.Done() + }) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + defer l.Stop() + publishTestMsgs(t, js) + wg.Wait() + if len(msgs) != 2*len(testMsgs) { + t.Fatalf("Unexpected received message count; want %d; got %d", len(testMsgs), len(msgs)) + } + expectedMsgs := append(testMsgs, testMsgs...) + for i, msg := range msgs { + if string(msg.Data()) != expectedMsgs[i] { + t.Fatalf("Invalid msg on index %d; expected: %s; got: %s", i, testMsgs[i], string(msg.Data())) + } } - wg.Done() }) - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } + }) - publishTestMsgs(t, nc) - wg.Wait() - l.Stop() + t.Run(fmt.Sprintf("num replicas: %d, recover consume after server restart", stream.Replicas), func(t *testing.T) { + withJSClusterAndStream(t, name, 3, stream, func(t *testing.T, subject string, srvs ...*jsServer) { + nc, err := nats.Connect(srvs[0].ClientURL()) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } - time.Sleep(10 * time.Millisecond) - wg.Add(len(testMsgs)) - defer cancel() - l, err = c.Consume(func(msg jetstream.Msg) { - if err := msg.Ack(); err != nil { + js, err := jetstream.New(nc) + if err != nil { t.Fatalf("Unexpected error: %v", err) } - msgs = append(msgs, msg) - wg.Done() - }) - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } - defer l.Stop() - publishTestMsgs(t, nc) - wg.Wait() - if len(msgs) != 2*len(testMsgs) { - t.Fatalf("Unexpected received message count; want %d; got %d", len(testMsgs), len(msgs)) - } - expectedMsgs := append(testMsgs, testMsgs...) - for i, msg := range msgs { - if string(msg.Data()) != expectedMsgs[i] { - t.Fatalf("Invalid msg on index %d; expected: %s; got: %s", i, testMsgs[i], string(msg.Data())) + defer nc.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) + defer cancel() + s, err := js.Stream(ctx, streamWithReplicas.Name) + if err != nil { + t.Fatalf("Unexpected error: %v", err) } - } + c, err := s.CreateOrUpdateConsumer(ctx, jetstream.ConsumerConfig{AckPolicy: jetstream.AckExplicitPolicy, InactiveThreshold: 10 * time.Second}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + wg := sync.WaitGroup{} + wg.Add(len(testMsgs)) + msgs := make([]jetstream.Msg, 0) + l, err := c.Consume(func(msg jetstream.Msg) { + if err := msg.Ack(); err != nil { + t.Fatalf("Unexpected error: %v", err) + } + msgs = append(msgs, msg) + wg.Done() + }, jetstream.PullExpiry(1*time.Second), jetstream.PullHeartbeat(500*time.Millisecond)) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + defer l.Stop() + + publishTestMsgs(t, js) + wg.Wait() + + time.Sleep(10 * time.Millisecond) + srvs[0].Shutdown() + srvs[1].Shutdown() + srvs[0].Restart() + srvs[1].Restart() + wg.Add(len(testMsgs)) + + for i := 0; i < 10; i++ { + time.Sleep(500 * time.Millisecond) + if _, err := js.Stream(context.Background(), stream.Name); err == nil { + break + } else if i == 9 { + t.Fatal("JetStream not recovered: ", err) + } + } + publishTestMsgs(t, js) + wg.Wait() + if len(msgs) != 2*len(testMsgs) { + t.Fatalf("Unexpected received message count; want %d; got %d", len(testMsgs), len(msgs)) + } + expectedMsgs := append(testMsgs, testMsgs...) + for i, msg := range msgs { + if string(msg.Data()) != expectedMsgs[i] { + t.Fatalf("Invalid msg on index %d; expected: %s; got: %s", i, testMsgs[i], string(msg.Data())) + } + } + }) }) - }) + } } func TestPullConsumerNext(t *testing.T) {