Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update consume backpressure handling #508

Closed
wants to merge 10 commits into from
114 changes: 69 additions & 45 deletions src/NATS.Client.JetStream/Internal/NatsJSConsume.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
using Microsoft.Extensions.Logging;
using NATS.Client.Core;
using NATS.Client.Core.Commands;
using NATS.Client.Core.Internal;
using NATS.Client.JetStream.Models;

namespace NATS.Client.JetStream.Internal;
Expand Down Expand Up @@ -128,11 +127,9 @@ public NatsJSConsume(
Timeout.Infinite,
Timeout.Infinite);

// This channel is used to pass messages
// to the user from the subscription channel (which should be set to a
// sufficiently large value to avoid blocking socket reads in the
// NATS connection).
_userMsgs = Channel.CreateBounded<NatsJSMsg<TMsg>>(1000);
// This channel is used to pass messages to the user from the subscription channel.
// Capacity is bounded to a maximum of 1024 to avoid LOH allocations.
_userMsgs = Channel.CreateBounded<NatsJSMsg<TMsg>>(1024);
Msgs = _userMsgs.Reader;

// Capacity as 1 is enough here since it's used for signaling only.
Expand Down Expand Up @@ -164,6 +161,28 @@ public ValueTask CallMsgNextAsync(string origin, ConsumerGetnextRequest request,

public void ResetHeartbeatTimer() => _timer.Change(_hbTimeout, _hbTimeout);

public void Delivered(int msgSize)
{
lock (_pendingGate)
{
if (_pendingMsgs > 0)
_pendingMsgs--;
}

if (_maxBytes > 0)
{
if (_debug)
_logger.LogDebug(NatsJSLogEvents.MessageProperty, "Message size {Size}", msgSize);

lock (_pendingGate)
{
_pendingBytes -= msgSize;
}
}

CheckPending("delivered");
}

public override async ValueTask DisposeAsync()
{
Interlocked.Exchange(ref _disposed, 1);
Expand All @@ -183,26 +202,48 @@ public override async ValueTask DisposeAsync()
internal override async ValueTask WriteReconnectCommandsAsync(CommandWriter commandWriter, int sid)
{
await base.WriteReconnectCommandsAsync(commandWriter, sid);
ResetPending();

var request = new ConsumerGetnextRequest
{
Batch = _maxMsgs,
MaxBytes = _maxBytes,
IdleHeartbeat = _idle,
Expires = _expires,
};

if (_cancellationToken.IsCancellationRequested)
return;

await commandWriter.PublishAsync(
subject: $"{_context.Opts.Prefix}.CONSUMER.MSG.NEXT.{_stream}.{_consumer}",
value: request,
headers: default,
replyTo: Subject,
serializer: NatsJSJsonSerializer<ConsumerGetnextRequest>.Default,
cancellationToken: CancellationToken.None);
long maxMsgs = 0;
long maxBytes = 0;

// We have to do the pending check here because we can't access
// the publish method here since the connection state is not open yet
// and we're just writing the reconnect commands.
lock (_pendingGate)
{
if (_maxBytes > 0 && _pendingBytes <= _thresholdBytes)
{
maxBytes = _maxBytes - _pendingBytes;
}
else if (_maxBytes == 0 && _pendingMsgs <= _thresholdMsgs && _pendingMsgs < _maxMsgs)
{
maxMsgs = _maxMsgs - _pendingMsgs;
}
}

if (maxMsgs > 0 || maxBytes > 0)
{
var request = new ConsumerGetnextRequest
{
Batch = maxMsgs,
MaxBytes = maxBytes,
IdleHeartbeat = _idle,
Expires = _expires,
};

await commandWriter.PublishAsync(
subject: $"{_context.Opts.Prefix}.CONSUMER.MSG.NEXT.{_stream}.{_consumer}",
value: request,
headers: default,
replyTo: Subject,
serializer: NatsJSJsonSerializer<ConsumerGetnextRequest>.Default,
cancellationToken: CancellationToken.None);

ResetPending();
}
}

protected override async ValueTask ReceiveInternalAsync(
Expand Down Expand Up @@ -323,6 +364,8 @@ protected override async ValueTask ReceiveInternalAsync(
{
throw new NatsJSException("No header found");
}

CheckPending("control-msg");
}
else
{
Expand All @@ -337,35 +380,16 @@ protected override async ValueTask ReceiveInternalAsync(
_serializer),
_context);

lock (_pendingGate)
{
if (_pendingMsgs > 0)
_pendingMsgs--;
}

if (_maxBytes > 0)
{
if (_debug)
_logger.LogDebug(NatsJSLogEvents.MessageProperty, "Message size {Size}", msg.Size);

lock (_pendingGate)
{
_pendingBytes -= msg.Size;
}
}

// Stop feeding the user if we are disposed.
// We need to exit as soon as possible.
if (Volatile.Read(ref _disposed) == 0)
{
// We can't pass cancellation token here because we need to hand
// the message to the user to be processed. Writer will be completed
// when the user calls Stop() or when the subscription is closed.
await _userMsgs.Writer.WriteAsync(msg).ConfigureAwait(false);
await _userMsgs.Writer.WriteAsync(msg, CancellationToken.None).ConfigureAwait(false);
}
}

CheckPending();
}

protected override void TryComplete()
Expand All @@ -383,7 +407,7 @@ private void ResetPending()
}
}

private void CheckPending()
private void CheckPending(string origin)
{
lock (_pendingGate)
{
Expand All @@ -392,15 +416,15 @@ private void CheckPending()
if (_debug)
_logger.LogDebug(NatsJSLogEvents.PendingCount, "Check pending bytes {Pending}, {MaxBytes}", _pendingBytes, _maxBytes);

Pull("chk-bytes", _maxMsgs, _maxBytes - _pendingBytes);
Pull($"chk-bytes({origin})", _maxMsgs, _maxBytes - _pendingBytes);
ResetPending();
}
else if (_maxBytes == 0 && _pendingMsgs <= _thresholdMsgs && _pendingMsgs < _maxMsgs)
{
if (_debug)
_logger.LogDebug(NatsJSLogEvents.PendingCount, "Check pending messages {Pending}, {MaxMsgs}", _pendingMsgs, _maxMsgs);

Pull("chk-msgs", _maxMsgs - _pendingMsgs, 0);
Pull($"chk-msgs({origin})", _maxMsgs - _pendingMsgs, 0);
ResetPending();
}
}
Expand Down
1 change: 1 addition & 0 deletions src/NATS.Client.JetStream/NatsJSConsumer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ public async IAsyncEnumerable<NatsJSMsg<T>> ConsumeAsync<T>(
break;

yield return jsMsg;
cc.Delivered(jsMsg.Size);
}
}
}
Expand Down
74 changes: 64 additions & 10 deletions tests/NATS.Client.JetStream.Tests/ConsumerConsumeTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,7 @@ public async Task Consume_msgs_test()
var consumerOpts = new NatsJSConsumeOpts { MaxMsgs = 10 };
var consumer = (NatsJSConsumer)await js.GetConsumerAsync("s1", "c1", cts.Token);
var count = 0;
await using var cc = await consumer.ConsumeInternalAsync<TestData>(serializer: TestDataJsonSerializer<TestData>.Default, consumerOpts, cancellationToken: cts.Token);
await foreach (var msg in cc.Msgs.ReadAllAsync(cts.Token))
await foreach (var msg in consumer.ConsumeAsync(serializer: TestDataJsonSerializer<TestData>.Default, consumerOpts, cancellationToken: cts.Token))
{
await msg.AckAsync(cancellationToken: cts.Token);
Assert.Equal(count, msg.Data!.Test);
Expand All @@ -92,7 +91,7 @@ public async Task Consume_msgs_test()

await Retry.Until(
reason: "received enough pulls",
condition: () => PullCount() > 5,
condition: () => PullCount() >= 4,
action: () =>
{
_output.WriteLine($"### PullCount:{PullCount()}");
Expand Down Expand Up @@ -215,12 +214,10 @@ public async Task Consume_reconnect_test()
// Not interested in management messages sent upto this point
await proxy.FlushFramesAsync(nats);

var cc = await consumer.ConsumeInternalAsync<TestData>(serializer: TestDataJsonSerializer<TestData>.Default, consumerOpts, cancellationToken: cts.Token);

var readerTask = Task.Run(async () =>
{
var count = 0;
await foreach (var msg in cc.Msgs.ReadAllAsync(cts.Token))
await foreach (var msg in consumer.ConsumeAsync<TestData>(serializer: TestDataJsonSerializer<TestData>.Default, consumerOpts, cancellationToken: cts.Token))
{
await msg.AckAsync(cancellationToken: cts.Token);
Assert.Equal(count, msg.Data!.Test);
Expand All @@ -230,6 +227,8 @@ public async Task Consume_reconnect_test()
if (count == 2)
break;
}

return count;
});

// Send a message before reconnect
Expand Down Expand Up @@ -258,11 +257,9 @@ await Retry.Until(
ack.EnsureSuccess();
}

await Retry.Until(
"acked",
() => proxy.ClientFrames.Any(f => f.Message.Contains("CONSUMER.MSG.NEXT")));
var count = await readerTask;
Assert.Equal(2, count);

await readerTask;
await nats.DisposeAsync();
}

Expand Down Expand Up @@ -446,4 +443,61 @@ public async Task Serialization_errors()
break;
}
}

[Fact]
public async Task Consume_right_amount_of_messages()
{
await using var server = NatsServer.StartJS();
await using var nats = server.CreateClientConnection();

var cts = new CancellationTokenSource(TimeSpan.FromSeconds(30));

var js = new NatsJSContext(nats);
await js.CreateStreamAsync("s1", ["s1.*"], cts.Token);

var payload = new byte[1024];
for (var i = 0; i < 50; i++)
{
var ack = await js.PublishAsync("s1.foo", payload, cancellationToken: cts.Token);
ack.EnsureSuccess();
}

// Max messages
{
var consumer = await js.CreateOrUpdateConsumerAsync("s1", "c1", cancellationToken: cts.Token);
var opts = new NatsJSConsumeOpts { MaxMsgs = 10, };
var count = 0;
await foreach (var msg in consumer.ConsumeAsync<byte[]>(opts: opts, cancellationToken: cts.Token))
{
await msg.AckAsync(cancellationToken: cts.Token);
if (++count == 4)
break;
}

await Retry.Until("consumer stats updated", async () =>
{
var info = (await js.GetConsumerAsync("s1", "c1", cts.Token)).Info;
return info is { NumAckPending: 6, NumPending: 40 };
});
}

// Max bytes
{
var consumer = await js.CreateOrUpdateConsumerAsync("s1", "c2", cancellationToken: cts.Token);
var opts = new NatsJSConsumeOpts { MaxBytes = 10 * (1024 + 50), };
var count = 0;
await foreach (var msg in consumer.ConsumeAsync<byte[]>(opts: opts, cancellationToken: cts.Token))
{
await msg.AckAsync(cancellationToken: cts.Token);
if (++count == 4)
break;
}

await Retry.Until("consumer stats updated", async () =>
{
var info = (await js.GetConsumerAsync("s1", "c2", cts.Token)).Info;
return info is { NumAckPending: 6, NumPending: 40 };
});
}
}
}
Loading