Skip to content

Commit

Permalink
Cleanup all async operation during shutdown (#249)
Browse files Browse the repository at this point in the history
## Changes
- Cleanup async calls by wrapping all async calls with shutdown CTS that gets cancelled and disposed on PostStop
  • Loading branch information
Arkatufus authored Sep 23, 2022
1 parent 7d73823 commit c22a034
Show file tree
Hide file tree
Showing 2 changed files with 156 additions and 104 deletions.
94 changes: 57 additions & 37 deletions src/Akka.Persistence.Azure/Journal/AzureTableStorageJournal.cs
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ public class AzureTableStorageJournal : AsyncWriteJournal
private readonly TableServiceClient _tableServiceClient;
private TableClient _tableStorage_DoNotUseDirectly;
private readonly Dictionary<string, ISet<IActorRef>> _tagSubscribers = new Dictionary<string, ISet<IActorRef>>();
private readonly CancellationTokenSource _shutdownCts;

public AzureTableStorageJournal(Config config = null)
{
Expand Down Expand Up @@ -81,6 +82,8 @@ public AzureTableStorageJournal(Config config = null)
options: _settings.TableClientOptions)
: new TableServiceClient(connectionString: _settings.ConnectionString);
}

_shutdownCts = new CancellationTokenSource();
}

public TableClient Table
Expand All @@ -107,9 +110,9 @@ public override async Task<long> ReadHighestSequenceNrAsync(

_log.Debug("Entering method ReadHighestSequenceNrAsync");

var seqNo = await HighestSequenceNumberQuery(persistenceId)
var seqNo = await HighestSequenceNumberQuery(persistenceId, null, _shutdownCts.Token)
.Select(entity => entity.GetInt64(HighestSequenceNrEntry.HighestSequenceNrKey).Value)
.AggregateAsync(0L, Math.Max);
.AggregateAsync(0L, Math.Max, cancellationToken: _shutdownCts.Token);

_log.Debug("Leaving method ReadHighestSequenceNrAsync with SeqNo [{0}] for PersistentId [{1}]", seqNo, persistenceId);

Expand All @@ -132,7 +135,8 @@ public override async Task ReplayMessagesAsync(
if (max == 0)
return;

var pages = PersistentJournalEntryReplayQuery(persistenceId, fromSequenceNr, toSequenceNr).AsPages().GetAsyncEnumerator();
var pages = PersistentJournalEntryReplayQuery(persistenceId, fromSequenceNr, toSequenceNr, null, _shutdownCts.Token)
.AsPages().GetAsyncEnumerator(_shutdownCts.Token);

ValueTask<bool>? nextTask = pages.MoveNextAsync();
var count = 0L;
Expand Down Expand Up @@ -206,7 +210,8 @@ protected override async Task DeleteMessagesToAsync(string persistenceId, long t

_log.Debug("Entering method DeleteMessagesToAsync for persistentId [{0}] and up to seqNo [{1}]", persistenceId, toSequenceNr);

var pages = PersistentJournalEntryDeleteQuery(persistenceId, toSequenceNr).AsPages().GetAsyncEnumerator();
var pages = PersistentJournalEntryDeleteQuery(persistenceId, toSequenceNr, null, _shutdownCts.Token)
.AsPages().GetAsyncEnumerator(_shutdownCts.Token);

ValueTask<bool>? nextTask = pages.MoveNextAsync();
while (nextTask.HasValue)
Expand All @@ -227,7 +232,7 @@ protected override async Task DeleteMessagesToAsync(string persistenceId, long t
if (currentPage.Values.Count > 0)
{
await Table.SubmitTransactionAsync(currentPage.Values
.Select(entity => new TableTransactionAction(TableTransactionActionType.Delete, entity)));
.Select(entity => new TableTransactionAction(TableTransactionActionType.Delete, entity)), _shutdownCts.Token);
}
}

Expand All @@ -238,7 +243,7 @@ protected override void PreStart()
{
_log.Debug("Initializing Azure Table Storage...");

InitCloudStorage(5)
InitCloudStorage(5, _shutdownCts.Token)
.ConfigureAwait(false).GetAwaiter().GetResult();

_log.Debug("Successfully started Azure Table Storage!");
Expand All @@ -247,20 +252,27 @@ protected override void PreStart()
base.PreStart();
}

protected override void PostStop()
{
_shutdownCts.Cancel();
_shutdownCts.Dispose();
base.PostStop();
}

protected override bool ReceivePluginInternal(object message)
{
switch (message)
{
case ReplayTaggedMessages replay:
ReplayTaggedMessagesAsync(replay)
ReplayTaggedMessagesAsync(replay, _shutdownCts.Token)
.PipeTo(replay.ReplyTo, success: h => new RecoverySuccess(h), failure: e => new ReplayMessagesFailure(e));
break;
case SubscribePersistenceId subscribe:
AddPersistenceIdSubscriber(Sender, subscribe.PersistenceId);
Context.Watch(Sender);
break;
case SubscribeAllPersistenceIds subscribe:
AddAllPersistenceIdSubscriber(Sender);
case SubscribeAllPersistenceIds _:
AddAllPersistenceIdSubscriber(Sender, _shutdownCts.Token);
Context.Watch(Sender);
break;
case SubscribeTag subscribe:
Expand Down Expand Up @@ -353,7 +365,7 @@ protected override async Task<IImmutableList<Exception>> WriteMessagesAsync(IEnu
if (_log.IsDebugEnabled && _settings.VerboseLogging)
_log.Debug("Attempting to write batch of {0} messages to Azure storage", batchItems.Count);

var response = await Table.SubmitTransactionAsync(batchItems);
var response = await Table.SubmitTransactionAsync(batchItems, _shutdownCts.Token);
if (_log.IsDebugEnabled && _settings.VerboseLogging)
{
foreach (var r in response.Value)
Expand Down Expand Up @@ -383,7 +395,7 @@ protected override async Task<IImmutableList<Exception>> WriteMessagesAsync(IEnu
new AllPersistenceIdsEntry(PartitionKeyEscapeHelper.Escape(item.Key)).WriteEntity()));
}

var allPersistenceResponse = await Table.SubmitTransactionAsync(allPersistenceIdsBatch);
var allPersistenceResponse = await Table.SubmitTransactionAsync(allPersistenceIdsBatch, _shutdownCts.Token);

if (_log.IsDebugEnabled && _settings.VerboseLogging)
foreach (var r in allPersistenceResponse.Value)
Expand All @@ -405,7 +417,7 @@ protected override async Task<IImmutableList<Exception>> WriteMessagesAsync(IEnu
eventTagsBatch.Add(new TableTransactionAction(TableTransactionActionType.UpsertReplace, item.WriteEntity()));
}

var eventTagsResponse = await Table.SubmitTransactionAsync(eventTagsBatch);
var eventTagsResponse = await Table.SubmitTransactionAsync(eventTagsBatch, _shutdownCts.Token);

if (_log.IsDebugEnabled && _settings.VerboseLogging)
foreach (var r in eventTagsResponse.Value)
Expand Down Expand Up @@ -439,8 +451,8 @@ protected override async Task<IImmutableList<Exception>> WriteMessagesAsync(IEnu
}

private AsyncPageable<TableEntity> GenerateAllPersistenceIdsQuery(
int? maxPerPage = null,
CancellationToken cancellationToken = default)
int? maxPerPage,
CancellationToken cancellationToken)
{
return Table.QueryAsync<TableEntity>(
filter: $"PartitionKey eq '{AllPersistenceIdsEntry.PartitionKeyValue}'",
Expand All @@ -452,8 +464,8 @@ private AsyncPageable<TableEntity> GenerateAllPersistenceIdsQuery(

private AsyncPageable<TableEntity> HighestSequenceNumberQuery(
string persistenceId,
int? maxPerPage = null,
CancellationToken cancellationToken = default)
int? maxPerPage,
CancellationToken cancellationToken)
{
return Table.QueryAsync<TableEntity>(
filter: $"PartitionKey eq '{PartitionKeyEscapeHelper.Escape(persistenceId)}' and " +
Expand All @@ -467,8 +479,8 @@ private AsyncPageable<TableEntity> HighestSequenceNumberQuery(
private AsyncPageable<TableEntity> PersistentJournalEntryDeleteQuery(
string persistenceId,
long toSequenceNr,
int? maxPerPage = null,
CancellationToken cancellationToken = default)
int? maxPerPage,
CancellationToken cancellationToken)
{
return Table.QueryAsync<TableEntity>(
filter: $"PartitionKey eq '{PartitionKeyEscapeHelper.Escape(persistenceId)}' and " +
Expand All @@ -483,8 +495,8 @@ private AsyncPageable<TableEntity> EventTagEntryDeleteQuery(
string persistenceId,
long fromSequenceNr,
long toSequenceNr,
int? maxPerPage = null,
CancellationToken cancellationToken = default)
int? maxPerPage,
CancellationToken cancellationToken)
{
return Table.QueryAsync<TableEntity>(
filter: $"PartitionKey eq '{EventTagEntry.PartitionKeyValue}' and " +
Expand All @@ -500,8 +512,8 @@ private AsyncPageable<TableEntity> PersistentJournalEntryReplayQuery(
string persistentId,
long fromSequenceNumber,
long toSequenceNumber,
int? maxPerPage = null,
CancellationToken cancellationToken = default)
int? maxPerPage,
CancellationToken cancellationToken)
{
var filter = $"PartitionKey eq '{PartitionKeyEscapeHelper.Escape(persistentId)}' and " +
$"RowKey ne '{HighestSequenceNrEntry.RowKeyValue}'";
Expand All @@ -517,8 +529,8 @@ private AsyncPageable<TableEntity> PersistentJournalEntryReplayQuery(

private AsyncPageable<TableEntity> TaggedMessageQuery(
ReplayTaggedMessages replay,
int? maxPerPage = null,
CancellationToken cancellationToken = default)
int? maxPerPage,
CancellationToken cancellationToken)
{
return Table.QueryAsync<TableEntity>(
filter: $"PartitionKey eq '{PartitionKeyEscapeHelper.Escape(EventTagEntry.GetPartitionKey(replay.Tag))}' and " +
Expand All @@ -529,13 +541,13 @@ private AsyncPageable<TableEntity> TaggedMessageQuery(
cancellationToken: cancellationToken);
}

private async Task AddAllPersistenceIdSubscriber(IActorRef subscriber)
private async Task AddAllPersistenceIdSubscriber(IActorRef subscriber, CancellationToken cancellationToken)
{
lock (_allPersistenceIdSubscribers)
{
_allPersistenceIdSubscribers.Add(subscriber);
}
subscriber.Tell(new CurrentPersistenceIds(await GetAllPersistenceIds()));
subscriber.Tell(new CurrentPersistenceIds(await GetAllPersistenceIds(cancellationToken)));
}

private void AddPersistenceIdSubscriber(IActorRef subscriber, string persistenceId)
Expand All @@ -560,18 +572,21 @@ private void AddTagSubscriber(IActorRef subscriber, string tag)
subscriptions.Add(subscriber);
}

private async Task<IEnumerable<string>> GetAllPersistenceIds()
private async Task<IEnumerable<string>> GetAllPersistenceIds(CancellationToken cancellationToken)
{
return await GenerateAllPersistenceIdsQuery().Select(item => item.RowKey).ToListAsync();
return await GenerateAllPersistenceIdsQuery(null, cancellationToken)
.Select(item => item.RowKey).ToListAsync(cancellationToken);
}

private async Task InitCloudStorage(int remainingTries)
private async Task InitCloudStorage(int remainingTries, CancellationToken cancellationToken)
{
try
{
var tableClient = _tableServiceClient.GetTableClient(_settings.TableName);

using (var cts = new CancellationTokenSource(_settings.ConnectTimeout))
var cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
cts.CancelAfter(_settings.ConnectTimeout);
using (cts)
{
if (!_settings.AutoInitialize)
{
Expand Down Expand Up @@ -603,15 +618,19 @@ private async Task InitCloudStorage(int remainingTries)
_log.Error(ex, "[{0}] more tries to initialize table storage remaining...", remainingTries);
if (remainingTries == 0)
throw;
await Task.Delay(RetryInterval[remainingTries]);
await InitCloudStorage(remainingTries - 1);

await Task.Delay(RetryInterval[remainingTries], cancellationToken);
if (cancellationToken.IsCancellationRequested)
throw;

await InitCloudStorage(remainingTries - 1, cancellationToken);
}
}

private async Task<bool> IsTableExist(string name, CancellationToken token)
private async Task<bool> IsTableExist(string name, CancellationToken cancellationToken)
{
var tables = await _tableServiceClient.QueryAsync(t => t.Name == name, cancellationToken: token)
.ToListAsync(token)
var tables = await _tableServiceClient.QueryAsync(t => t.Name == name, cancellationToken: cancellationToken)
.ToListAsync(cancellationToken)
.ConfigureAwait(false);
return tables.Count > 0;
}
Expand Down Expand Up @@ -657,15 +676,16 @@ private void RemoveSubscriber(
/// Replays all events with given tag within provided boundaries from current database.
/// </summary>
/// <param name="replay">TBD</param>
/// <param name="cancellationToken"></param>
/// <returns>TBD</returns>
private async Task<long> ReplayTaggedMessagesAsync(ReplayTaggedMessages replay)
private async Task<long> ReplayTaggedMessagesAsync(ReplayTaggedMessages replay, CancellationToken cancellationToken)
{
// In order to actually break at the limit we ask for we have to
// keep a separate counter and track it ourselves.
var counter = 0;
var maxOrderingId = 0L;

var pages = TaggedMessageQuery(replay).AsPages().GetAsyncEnumerator();
var pages = TaggedMessageQuery(replay, null, cancellationToken).AsPages().GetAsyncEnumerator(cancellationToken);
ValueTask<bool>? nextTask = pages.MoveNextAsync();

while (nextTask != null)
Expand Down
Loading

0 comments on commit c22a034

Please sign in to comment.