From 0eb279beaf5175d5b4e8342f6d313411c748eae7 Mon Sep 17 00:00:00 2001 From: Varun Puranik Date: Fri, 30 Nov 2018 18:44:34 -0800 Subject: [PATCH] Add timeout / cancellation support to Store apis (#535) * Add cancellation token and timed store * Execute until cancelled changes * Fix build * Cleanup code.. * Adding tests * Remove private * Fix InMemoryDbStore * Fix setup * Don't dispose the tasks.. * Fix ExecuteUntilCancelled * Fix formatting --- ...osoft.Azure.Devices.Edge.Agent.Core.csproj | 2 +- .../storage/CheckpointStore.cs | 9 +- .../modules/RoutingModule.cs | 7 +- .../routing/RoutingTest.cs | 2 +- .../storage/CheckpointStoreTest.cs | 6 +- .../storage/MessageStoreTest.cs | 4 +- .../ColumnFamilyDbStore.cs | 82 +++++++---- .../ColumnFamilyStorageRocksDbWrapper.cs | 8 +- .../DbStoreProvider.cs | 9 +- .../IRocksDb.cs | 1 + .../RocksDbWrapper.cs | 4 +- .../EncryptedStore.cs | 48 +++++-- .../EntityStore.cs | 103 +++++++++----- .../IDbStore.cs | 5 +- .../IEntityStore.cs | 9 ++ .../IKeyValueStore.cs | 24 +++- .../ISequentialStore.cs | 9 +- .../IStoreProvider.cs | 2 +- .../InMemoryDbStore.cs | 103 ++++++-------- ...icrosoft.Azure.Devices.Edge.Storage.csproj | 3 + .../NullKeyValueStore.cs | 34 +++-- .../SequentialStore.cs | 44 +++--- .../SerDeExtensions.cs | 1 + .../StoreProvider.cs | 12 +- .../StoreUtils.cs | 5 +- .../TimedEntityStore.cs | 134 ++++++++++++++++++ .../TaskEx.cs | 121 ++++++++++++++-- .../ColumnFamilyStoreTest.cs | 2 +- .../TaskExTest.cs | 111 +++++++++++++++ 29 files changed, 703 insertions(+), 201 deletions(-) create mode 100644 edge-util/src/Microsoft.Azure.Devices.Edge.Storage/TimedEntityStore.cs diff --git a/edge-agent/src/Microsoft.Azure.Devices.Edge.Agent.Core/Microsoft.Azure.Devices.Edge.Agent.Core.csproj b/edge-agent/src/Microsoft.Azure.Devices.Edge.Agent.Core/Microsoft.Azure.Devices.Edge.Agent.Core.csproj index 493fc28315f..0854625358d 100644 --- a/edge-agent/src/Microsoft.Azure.Devices.Edge.Agent.Core/Microsoft.Azure.Devices.Edge.Agent.Core.csproj +++ b/edge-agent/src/Microsoft.Azure.Devices.Edge.Agent.Core/Microsoft.Azure.Devices.Edge.Agent.Core.csproj @@ -22,7 +22,7 @@ - + diff --git a/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Core/storage/CheckpointStore.cs b/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Core/storage/CheckpointStore.cs index 68e4d3a0c5e..cac78ada2ad 100644 --- a/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Core/storage/CheckpointStore.cs +++ b/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Core/storage/CheckpointStore.cs @@ -1,4 +1,4 @@ -// Copyright (c) Microsoft. All rights reserved. +// Copyright (c) Microsoft. All rights reserved. namespace Microsoft.Azure.Devices.Edge.Hub.Core.Storage { using System; @@ -20,10 +20,9 @@ public class CheckpointStore : ICheckpointStore this.underlyingStore = underlyingStore; } - public static CheckpointStore Create(IDbStoreProvider dbStoreProvider) + public static CheckpointStore Create(IStoreProvider storeProvider) { - IDbStore dbStore = Preconditions.CheckNotNull(dbStoreProvider, nameof(dbStoreProvider)).GetDbStore(Constants.CheckpointStorePartitionKey); - IEntityStore underlyingStore = new EntityStore(dbStore, nameof(CheckpointEntity), 12); + IEntityStore underlyingStore = Preconditions.CheckNotNull(storeProvider, nameof(storeProvider)).GetEntityStore(Constants.CheckpointStorePartitionKey); return new CheckpointStore(underlyingStore); } @@ -89,4 +88,4 @@ public CheckpointEntity(long offset, DateTime? lastFailedRevivalTime, DateTime? public DateTime? UnhealthySince { get; } } } -} \ No newline at end of file +} diff --git a/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Service/modules/RoutingModule.cs b/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Service/modules/RoutingModule.cs index cae54d10665..243c02a115a 100644 --- a/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Service/modules/RoutingModule.cs +++ b/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Service/modules/RoutingModule.cs @@ -307,7 +307,12 @@ protected override void Load(ContainerBuilder builder) .SingleInstance(); // ICheckpointStore - builder.Register(c => CheckpointStore.Create(c.Resolve())) + builder.Register(c => + { + var dbStoreProvider = c.Resolve(); + IStoreProvider storeProvider = new StoreProvider(dbStoreProvider); + return CheckpointStore.Create(storeProvider); + }) .As() .SingleInstance(); diff --git a/edge-hub/test/Microsoft.Azure.Devices.Edge.Hub.Core.Test/routing/RoutingTest.cs b/edge-hub/test/Microsoft.Azure.Devices.Edge.Hub.Core.Test/routing/RoutingTest.cs index f8f79ad3954..cce5ac2cbf2 100644 --- a/edge-hub/test/Microsoft.Azure.Devices.Edge.Hub.Core.Test/routing/RoutingTest.cs +++ b/edge-hub/test/Microsoft.Azure.Devices.Edge.Hub.Core.Test/routing/RoutingTest.cs @@ -434,7 +434,7 @@ public async Task TestRoutingTwinChangeNotificationFromModule() var routerConfig = new RouterConfig(endpoints, routesList); IDbStoreProvider dbStoreProvider = new InMemoryDbStoreProvider(); IStoreProvider storeProvider = new StoreProvider(dbStoreProvider); - IMessageStore messageStore = new MessageStore(storeProvider, CheckpointStore.Create(dbStoreProvider), TimeSpan.MaxValue); + IMessageStore messageStore = new MessageStore(storeProvider, CheckpointStore.Create(storeProvider), TimeSpan.MaxValue); IEndpointExecutorFactory endpointExecutorFactory = new StoringAsyncEndpointExecutorFactory(endpointExecutorConfig, new AsyncEndpointExecutorOptions(1, TimeSpan.FromMilliseconds(10)), messageStore); Router router = await Router.CreateAsync(Guid.NewGuid().ToString(), iotHubName, routerConfig, endpointExecutorFactory); ITwinManager twinManager = new TwinManager(connectionManager, new TwinCollectionMessageConverter(), new TwinMessageConverter(), Option.None>()); diff --git a/edge-hub/test/Microsoft.Azure.Devices.Edge.Hub.Core.Test/storage/CheckpointStoreTest.cs b/edge-hub/test/Microsoft.Azure.Devices.Edge.Hub.Core.Test/storage/CheckpointStoreTest.cs index 27bc281a120..4540ebf1f1c 100644 --- a/edge-hub/test/Microsoft.Azure.Devices.Edge.Hub.Core.Test/storage/CheckpointStoreTest.cs +++ b/edge-hub/test/Microsoft.Azure.Devices.Edge.Hub.Core.Test/storage/CheckpointStoreTest.cs @@ -1,4 +1,4 @@ -// Copyright (c) Microsoft. All rights reserved. +// Copyright (c) Microsoft. All rights reserved. namespace Microsoft.Azure.Devices.Edge.Hub.Core.Test.Storage { using System; @@ -18,7 +18,7 @@ public class CheckpointStoreTest [Fact] public async Task CheckpointStoreBasicTest() { - ICheckpointStore checkpointStore = CheckpointStore.Create(new InMemoryDbStoreProvider()); + ICheckpointStore checkpointStore = CheckpointStore.Create(new StoreProvider(new InMemoryDbStoreProvider())); for (long i = 0; i < 10; i++) { @@ -88,4 +88,4 @@ public void GetCheckpointDataTest() Assert.Equal(unhealthySinceTime, checkpointData2.UnhealthySince.OrDefault()); } } -} \ No newline at end of file +} diff --git a/edge-hub/test/Microsoft.Azure.Devices.Edge.Hub.Core.Test/storage/MessageStoreTest.cs b/edge-hub/test/Microsoft.Azure.Devices.Edge.Hub.Core.Test/storage/MessageStoreTest.cs index ab5ac8cbf64..1c91e910acb 100644 --- a/edge-hub/test/Microsoft.Azure.Devices.Edge.Hub.Core.Test/storage/MessageStoreTest.cs +++ b/edge-hub/test/Microsoft.Azure.Devices.Edge.Hub.Core.Test/storage/MessageStoreTest.cs @@ -209,7 +209,7 @@ public async Task MessageStoreAddRemoveEndpointTest() // Arrange var dbStoreProvider = new InMemoryDbStoreProvider(); IStoreProvider storeProvider = new StoreProvider(dbStoreProvider); - ICheckpointStore checkpointStore = CheckpointStore.Create(dbStoreProvider); + ICheckpointStore checkpointStore = CheckpointStore.Create(storeProvider); IMessageStore messageStore = new MessageStore(storeProvider, checkpointStore, TimeSpan.FromHours(1)); // Act @@ -278,7 +278,7 @@ IMessage GetMessage(int i) { var dbStoreProvider = new InMemoryDbStoreProvider(); IStoreProvider storeProvider = new StoreProvider(dbStoreProvider); - ICheckpointStore checkpointStore = CheckpointStore.Create(dbStoreProvider); + ICheckpointStore checkpointStore = CheckpointStore.Create(storeProvider); IMessageStore messageStore = new MessageStore(storeProvider, checkpointStore, TimeSpan.FromSeconds(ttlSecs)); await messageStore.AddEndpoint("module1"); await messageStore.AddEndpoint("module2"); diff --git a/edge-util/src/Microsoft.Azure.Devices.Edge.Storage.RocksDb/ColumnFamilyDbStore.cs b/edge-util/src/Microsoft.Azure.Devices.Edge.Storage.RocksDb/ColumnFamilyDbStore.cs index aad7f5d4984..f2170691eb5 100644 --- a/edge-util/src/Microsoft.Azure.Devices.Edge.Storage.RocksDb/ColumnFamilyDbStore.cs +++ b/edge-util/src/Microsoft.Azure.Devices.Edge.Storage.RocksDb/ColumnFamilyDbStore.cs @@ -3,6 +3,7 @@ namespace Microsoft.Azure.Devices.Edge.Storage.RocksDb { using System; + using System.Threading; using System.Threading.Tasks; using App.Metrics; using App.Metrics.Timer; @@ -13,108 +14,129 @@ namespace Microsoft.Azure.Devices.Edge.Storage.RocksDb class ColumnFamilyDbStore : IDbStore { readonly IRocksDb db; - + public ColumnFamilyDbStore(IRocksDb db, ColumnFamilyHandle handle) { this.db = Preconditions.CheckNotNull(db, nameof(db)); this.Handle = Preconditions.CheckNotNull(handle, nameof(handle)); } - internal ColumnFamilyHandle Handle { get; } + internal ColumnFamilyHandle Handle { get; } + + public Task Put(byte[] key, byte[] value) => this.Put(key, value, CancellationToken.None); + + public Task> Get(byte[] key) => this.Get(key, CancellationToken.None); + + public Task Remove(byte[] key) => this.Remove(key, CancellationToken.None); + + public Task Contains(byte[] key) => this.Contains(key, CancellationToken.None); + + public Task> GetFirstEntry() => this.GetFirstEntry(CancellationToken.None); - public Task> Get(byte[] key) + public Task> GetLastEntry() => this.GetLastEntry(CancellationToken.None); + + public Task IterateBatch(int batchSize, Func perEntityCallback) => this.IterateBatch(batchSize, perEntityCallback, CancellationToken.None); + + public Task IterateBatch(byte[] startKey, int batchSize, Func perEntityCallback) => this.IterateBatch(startKey, batchSize, perEntityCallback, CancellationToken.None); + + public async Task> Get(byte[] key, CancellationToken cancellationToken) { Preconditions.CheckNotNull(key, nameof(key)); Option returnValue; using (Metrics.DbGetLatency("all")) { - byte[] value = this.db.Get(key, this.Handle); + Func operation = () => this.db.Get(key, this.Handle); + byte[] value = await operation.ExecuteUntilCancelled(cancellationToken); returnValue = value != null ? Option.Some(value) : Option.None(); } - return Task.FromResult(returnValue); + + return returnValue; } - public Task Put(byte[] key, byte[] value) + public Task Put(byte[] key, byte[] value, CancellationToken cancellationToken) { Preconditions.CheckNotNull(key, nameof(key)); Preconditions.CheckNotNull(value, nameof(value)); using (Metrics.DbPutLatency("all")) { - this.db.Put(key, value, this.Handle); + Action operation = () => this.db.Put(key, value, this.Handle); + return operation.ExecuteUntilCancelled(cancellationToken); } - return Task.CompletedTask; } - public Task Remove(byte[] key) + public Task Remove(byte[] key, CancellationToken cancellationToken) { Preconditions.CheckNotNull(key, nameof(key)); - this.db.Remove(key, this.Handle); - return Task.CompletedTask; - } + Action operation = () => this.db.Remove(key, this.Handle); + return operation.ExecuteUntilCancelled(cancellationToken); + } - public Task> GetLastEntry() + public async Task> GetLastEntry(CancellationToken cancellationToken) { using (Iterator iterator = this.db.NewIterator(this.Handle)) { - iterator.SeekToLast(); + Action operation = () => iterator.SeekToLast(); + await operation.ExecuteUntilCancelled(cancellationToken); if (iterator.Valid()) { byte[] key = iterator.Key(); byte[] value = iterator.Value(); - return Task.FromResult(Option.Some((key, value))); + return Option.Some((key, value)); } else { - return Task.FromResult(Option.None<(byte[], byte[])>()); + return Option.None<(byte[], byte[])>(); } } } - public Task> GetFirstEntry() + public async Task> GetFirstEntry(CancellationToken cancellationToken) { using (Iterator iterator = this.db.NewIterator(this.Handle)) { - iterator.SeekToFirst(); + Action operation = () => iterator.SeekToFirst(); + await operation.ExecuteUntilCancelled(cancellationToken); if (iterator.Valid()) { byte[] key = iterator.Key(); byte[] value = iterator.Value(); - return Task.FromResult(Option.Some((key, value))); + return Option.Some((key, value)); } else { - return Task.FromResult(Option.None<(byte[], byte[])>()); + return Option.None<(byte[], byte[])>(); } } } - public Task Contains(byte[] key) + public async Task Contains(byte[] key, CancellationToken cancellationToken) { Preconditions.CheckNotNull(key, nameof(key)); - byte[] value = this.db.Get(key, this.Handle); - return Task.FromResult(value != null); + Func operation = () => this.db.Get(key, this.Handle); + byte[] value = await operation.ExecuteUntilCancelled(cancellationToken); + return value != null; } - public Task IterateBatch(byte[] startKey, int batchSize, Func callback) + public Task IterateBatch(byte[] startKey, int batchSize, Func callback, CancellationToken cancellationToken) { Preconditions.CheckNotNull(startKey, nameof(startKey)); Preconditions.CheckRange(batchSize, 1, nameof(batchSize)); Preconditions.CheckNotNull(callback, nameof(callback)); - return this.IterateBatch(iterator => iterator.Seek(startKey), batchSize, callback); + return this.IterateBatch(iterator => iterator.Seek(startKey), batchSize, callback, cancellationToken); } - public Task IterateBatch(int batchSize, Func callback) + public Task IterateBatch(int batchSize, Func callback, CancellationToken cancellationToken) { Preconditions.CheckRange(batchSize, 1, nameof(batchSize)); Preconditions.CheckNotNull(callback, nameof(callback)); - return this.IterateBatch(iterator => iterator.SeekToFirst(), batchSize, callback); + return this.IterateBatch(iterator => iterator.SeekToFirst(), batchSize, callback, cancellationToken); } - async Task IterateBatch(Action seeker, int batchSize, Func callback) + async Task IterateBatch(Action seeker, int batchSize, Func callback, CancellationToken cancellationToken) { // Use tailing iterator to prevent creating a snapshot. var readOptions = new ReadOptions(); @@ -128,7 +150,7 @@ async Task IterateBatch(Action seeker, int batchSize, Func ListColumnFamilies() { return this.columnFamiliesProvider.ListColumnFamilies(); } - } + } public ColumnFamilyHandle GetColumnFamily(string columnFamilyName) => this.db.GetColumnFamily(columnFamilyName); @@ -78,10 +80,10 @@ public ColumnFamilyHandle CreateColumnFamily(ColumnFamilyOptions columnFamilyOpt lock (ColumnFamiliesLock) { this.columnFamiliesProvider.AddColumnFamily(entityName); - ColumnFamilyHandle handle = this.db.CreateColumnFamily(columnFamilyOptions, entityName); + ColumnFamilyHandle handle = this.db.CreateColumnFamily(columnFamilyOptions, entityName); return handle; } - } + } public void DropColumnFamily(string columnFamilyName) { diff --git a/edge-util/src/Microsoft.Azure.Devices.Edge.Storage.RocksDb/DbStoreProvider.cs b/edge-util/src/Microsoft.Azure.Devices.Edge.Storage.RocksDb/DbStoreProvider.cs index fab8abb2de0..712c117b17f 100644 --- a/edge-util/src/Microsoft.Azure.Devices.Edge.Storage.RocksDb/DbStoreProvider.cs +++ b/edge-util/src/Microsoft.Azure.Devices.Edge.Storage.RocksDb/DbStoreProvider.cs @@ -29,12 +29,12 @@ public class DbStoreProvider : IDbStoreProvider this.compactionTimer = new Timer(this.RunCompaction, null, CompactionPeriod, CompactionPeriod); } - private void RunCompaction(object state) + void RunCompaction(object state) { Events.StartingCompaction(); foreach (KeyValuePair entityDbStore in this.entityDbStoreDictionary) { - if(entityDbStore.Value is ColumnFamilyDbStore cfDbStore) + if (entityDbStore.Value is ColumnFamilyDbStore cfDbStore) { Events.CompactingStore(entityDbStore.Key); this.db.Compact(cfDbStore.Handle); @@ -53,6 +53,7 @@ public static DbStoreProvider Create(IRocksDbOptionsProvider optionsProvider, st var dbStorePartition = new ColumnFamilyDbStore(db, handle); entityDbStoreDictionary[columnFamilyName] = dbStorePartition; } + var dbStore = new DbStoreProvider(optionsProvider, db, entityDbStoreDictionary); return dbStore; } @@ -61,11 +62,12 @@ public IDbStore GetDbStore(string partitionName) { Preconditions.CheckNonWhiteSpace(partitionName, nameof(partitionName)); if (!this.entityDbStoreDictionary.TryGetValue(partitionName, out IDbStore entityDbStore)) - { + { ColumnFamilyHandle handle = this.db.CreateColumnFamily(this.optionsProvider.GetColumnFamilyOptions(), partitionName); entityDbStore = new ColumnFamilyDbStore(this.db, handle); entityDbStore = this.entityDbStoreDictionary.GetOrAdd(partitionName, entityDbStore); } + return entityDbStore; } @@ -102,6 +104,7 @@ public void Dispose() static class Events { static readonly ILogger Log = Logger.Factory.CreateLogger(); + // Use an ID not used by other components const int IdStart = 4000; diff --git a/edge-util/src/Microsoft.Azure.Devices.Edge.Storage.RocksDb/IRocksDb.cs b/edge-util/src/Microsoft.Azure.Devices.Edge.Storage.RocksDb/IRocksDb.cs index 91c6a01f977..12212411a64 100644 --- a/edge-util/src/Microsoft.Azure.Devices.Edge.Storage.RocksDb/IRocksDb.cs +++ b/edge-util/src/Microsoft.Azure.Devices.Edge.Storage.RocksDb/IRocksDb.cs @@ -1,4 +1,5 @@ // Copyright (c) Microsoft. All rights reserved. + namespace Microsoft.Azure.Devices.Edge.Storage.RocksDb { using System; diff --git a/edge-util/src/Microsoft.Azure.Devices.Edge.Storage.RocksDb/RocksDbWrapper.cs b/edge-util/src/Microsoft.Azure.Devices.Edge.Storage.RocksDb/RocksDbWrapper.cs index 3c6b434d537..40a7d974a54 100644 --- a/edge-util/src/Microsoft.Azure.Devices.Edge.Storage.RocksDb/RocksDbWrapper.cs +++ b/edge-util/src/Microsoft.Azure.Devices.Edge.Storage.RocksDb/RocksDbWrapper.cs @@ -1,4 +1,5 @@ // Copyright (c) Microsoft. All rights reserved. + namespace Microsoft.Azure.Devices.Edge.Storage.RocksDb { using System; @@ -14,7 +15,6 @@ namespace Microsoft.Azure.Devices.Edge.Storage.RocksDb /// sealed class RocksDbWrapper : IRocksDb { - readonly AtomicBoolean isDisposed = new AtomicBoolean(false); readonly RocksDb db; readonly string path; @@ -27,10 +27,8 @@ sealed class RocksDbWrapper : IRocksDb this.dbOptions = dbOptions; } - public static RocksDbWrapper Create(IRocksDbOptionsProvider optionsProvider, string path, IEnumerable partitionsList) { - Preconditions.CheckNonWhiteSpace(path, nameof(path)); Preconditions.CheckNotNull(optionsProvider, nameof(optionsProvider)); DbOptions dbOptions = Preconditions.CheckNotNull(optionsProvider.GetDbOptions()); diff --git a/edge-util/src/Microsoft.Azure.Devices.Edge.Storage/EncryptedStore.cs b/edge-util/src/Microsoft.Azure.Devices.Edge.Storage/EncryptedStore.cs index ed2c381b0f6..7c8573eb5d1 100644 --- a/edge-util/src/Microsoft.Azure.Devices.Edge.Storage/EncryptedStore.cs +++ b/edge-util/src/Microsoft.Azure.Devices.Edge.Storage/EncryptedStore.cs @@ -1,7 +1,9 @@ // Copyright (c) Microsoft. All rights reserved. + namespace Microsoft.Azure.Devices.Edge.Storage { using System; + using System.Threading; using System.Threading.Tasks; using Microsoft.Azure.Devices.Edge.Util; @@ -16,16 +18,32 @@ public EncryptedStore(IKeyValueStore entityStore, IEncryptionProvide this.encryptionProvider = Preconditions.CheckNotNull(encryptionProvider, nameof(encryptionProvider)); } - public async Task Put(TK key, TV value) + public Task Put(TK key, TV value) => this.Put(key, value, CancellationToken.None); + + public Task> Get(TK key) => this.Get(key, CancellationToken.None); + + public Task Remove(TK key) => this.Remove(key, CancellationToken.None); + + public Task Contains(TK key) => this.Contains(key, CancellationToken.None); + + public Task> GetFirstEntry() => this.GetFirstEntry(CancellationToken.None); + + public Task> GetLastEntry() => this.GetLastEntry(CancellationToken.None); + + public Task IterateBatch(int batchSize, Func perEntityCallback) => this.IterateBatch(batchSize, perEntityCallback, CancellationToken.None); + + public Task IterateBatch(TK startKey, int batchSize, Func perEntityCallback) => this.IterateBatch(startKey, batchSize, perEntityCallback, CancellationToken.None); + + public async Task Put(TK key, TV value, CancellationToken cancellationToken) { string valueString = value.ToJson(); string encryptedString = await this.encryptionProvider.EncryptAsync(valueString); - await this.entityStore.Put(key, encryptedString); + await this.entityStore.Put(key, encryptedString, cancellationToken); } - public async Task> Get(TK key) + public async Task> Get(TK key, CancellationToken cancellationToken) { - Option encryptedValue = await this.entityStore.Get(key); + Option encryptedValue = await this.entityStore.Get(key, cancellationToken); return await encryptedValue.Map( async e => { @@ -35,13 +53,13 @@ public async Task> Get(TK key) .GetOrElse(() => Task.FromResult(Option.None())); } - public Task Remove(TK key) => this.entityStore.Remove(key); + public Task Remove(TK key, CancellationToken cancellationToken) => this.entityStore.Remove(key, cancellationToken); - public Task Contains(TK key) => this.entityStore.Contains(key); + public Task Contains(TK key, CancellationToken cancellationToken) => this.entityStore.Contains(key, cancellationToken); - public async Task> GetFirstEntry() + public async Task> GetFirstEntry(CancellationToken cancellationToken) { - Option<(TK key, string value)> encryptedValue = await this.entityStore.GetFirstEntry(); + Option<(TK key, string value)> encryptedValue = await this.entityStore.GetFirstEntry(cancellationToken); return await encryptedValue.Map( async e => { @@ -51,9 +69,9 @@ public async Task> Get(TK key) .GetOrElse(() => Task.FromResult(Option.None<(TK key, TV value)>())); } - public async Task> GetLastEntry() + public async Task> GetLastEntry(CancellationToken cancellationToken) { - Option<(TK key, string value)> encryptedValue = await this.entityStore.GetLastEntry(); + Option<(TK key, string value)> encryptedValue = await this.entityStore.GetLastEntry(cancellationToken); return await encryptedValue.Map( async e => { @@ -63,7 +81,7 @@ public async Task> Get(TK key) .GetOrElse(() => Task.FromResult(Option.None<(TK key, TV value)>())); } - public Task IterateBatch(int batchSize, Func perEntityCallback) + public Task IterateBatch(int batchSize, Func perEntityCallback, CancellationToken cancellationToken) { return this.entityStore.IterateBatch( batchSize, @@ -72,10 +90,11 @@ public Task IterateBatch(int batchSize, Func perEntityCallback) string decryptedValue = await this.encryptionProvider.DecryptAsync(stringValue); var value = decryptedValue.FromJson(); await perEntityCallback(key, value); - }); + }, + cancellationToken); } - public Task IterateBatch(TK startKey, int batchSize, Func perEntityCallback) + public Task IterateBatch(TK startKey, int batchSize, Func perEntityCallback, CancellationToken cancellationToken) { return this.entityStore.IterateBatch( startKey, @@ -85,7 +104,8 @@ public Task IterateBatch(TK startKey, int batchSize, Func perEntit string decryptedValue = await this.encryptionProvider.DecryptAsync(stringValue); var value = decryptedValue.FromJson(); await perEntityCallback(key, value); - }); + }, + cancellationToken); } protected virtual void Dispose(bool disposing) diff --git a/edge-util/src/Microsoft.Azure.Devices.Edge.Storage/EntityStore.cs b/edge-util/src/Microsoft.Azure.Devices.Edge.Storage/EntityStore.cs index e354a1de35b..76c07866847 100644 --- a/edge-util/src/Microsoft.Azure.Devices.Edge.Storage/EntityStore.cs +++ b/edge-util/src/Microsoft.Azure.Devices.Edge.Storage/EntityStore.cs @@ -3,9 +3,9 @@ namespace Microsoft.Azure.Devices.Edge.Storage { using System; + using System.Threading; using System.Threading.Tasks; using Microsoft.Azure.Devices.Edge.Util; - using Microsoft.Azure.Devices.Edge.Util.Concurrency; /// /// Store for particular Key/Value pair. This provides additional functionality on top of Db Key/Value store such as - @@ -28,111 +28,143 @@ public EntityStore(IDbStore dbStore, string entityName, int keyShardCount = 1) public string EntityName { get; } - public async Task> Get(TK key) + public async Task> Get(TK key, CancellationToken cancellationToken) { - Option valueBytes = await this.dbStore.Get(key.ToBytes()); + Option valueBytes = await this.dbStore.Get(key.ToBytes(), cancellationToken); Option value = valueBytes.Map(v => v.FromBytes()); return value; } - public async Task Put(TK key, TV value) + public Task Put(TK key, TV value) => this.Put(key, value, CancellationToken.None); + + public Task> Get(TK key) => this.Get(key, CancellationToken.None); + + public Task Remove(TK key) => this.Remove(key, CancellationToken.None); + + public Task Contains(TK key) => this.Contains(key, CancellationToken.None); + + public Task> GetFirstEntry() => this.GetFirstEntry(CancellationToken.None); + + public Task> GetLastEntry() => this.GetLastEntry(CancellationToken.None); + + public Task IterateBatch(int batchSize, Func perEntityCallback) => this.IterateBatch(batchSize, perEntityCallback, CancellationToken.None); + + public Task IterateBatch(TK startKey, int batchSize, Func perEntityCallback) => this.IterateBatch(startKey, batchSize, perEntityCallback, CancellationToken.None); + + public async Task Put(TK key, TV value, CancellationToken cancellationToken) { - using (await this.keyLockProvider.GetLock(key).LockAsync()) + using (await this.keyLockProvider.GetLock(key).LockAsync(cancellationToken)) { - await this.dbStore.Put(key.ToBytes(), value.ToBytes()); + await this.dbStore.Put(key.ToBytes(), value.ToBytes(), cancellationToken); } } - public virtual Task Remove(TK key) + public virtual Task Remove(TK key, CancellationToken cancellationToken) { - return this.dbStore.Remove(key.ToBytes()); + return this.dbStore.Remove(key.ToBytes(), cancellationToken); } - public async Task Remove(TK key, Func predicate) + public Task Remove(TK key, Func predicate) => + this.Remove(key, predicate, CancellationToken.None); + + public async Task Remove(TK key, Func predicate, CancellationToken cancellationToken) { Preconditions.CheckNotNull(predicate, nameof(predicate)); - using (await this.keyLockProvider.GetLock(key).LockAsync()) + using (await this.keyLockProvider.GetLock(key).LockAsync(cancellationToken)) { - Option value = await this.Get(key); + Option value = await this.Get(key, cancellationToken); return await value.Filter(v => predicate(v)).Match( async v => { - await this.Remove(key); + await this.Remove(key, cancellationToken); return true; }, () => Task.FromResult(false)); } } - public async Task Update(TK key, Func updator) + public Task Update(TK key, Func updator) => + this.Update(key, updator, CancellationToken.None); + + public async Task Update(TK key, Func updator, CancellationToken cancellationToken) { Preconditions.CheckNotNull(updator, nameof(updator)); - using (await this.keyLockProvider.GetLock(key).LockAsync()) + using (await this.keyLockProvider.GetLock(key).LockAsync(cancellationToken)) { byte[] keyBytes = key.ToBytes(); - byte[] existingValueBytes = (await this.dbStore.Get(keyBytes)).Expect(() => new InvalidOperationException("Value not found in store")); + byte[] existingValueBytes = (await this.dbStore.Get(keyBytes, cancellationToken)) + .Expect(() => new InvalidOperationException("Value not found in store")); var existingValue = existingValueBytes.FromBytes(); TV updatedValue = updator(existingValue); - await this.dbStore.Put(keyBytes, updatedValue.ToBytes()); + await this.dbStore.Put(keyBytes, updatedValue.ToBytes(), cancellationToken); return updatedValue; } } - public async Task PutOrUpdate(TK key, TV value, Func updator) + public Task PutOrUpdate(TK key, TV putValue, Func valueUpdator) => + this.PutOrUpdate(key, putValue, valueUpdator, CancellationToken.None); + + public async Task PutOrUpdate(TK key, TV value, Func updator, CancellationToken cancellationToken) { Preconditions.CheckNotNull(updator, nameof(updator)); - using (await this.keyLockProvider.GetLock(key).LockAsync()) + using (await this.keyLockProvider.GetLock(key).LockAsync(cancellationToken)) { byte[] keyBytes = key.ToBytes(); - Option existingValueBytes = await this.dbStore.Get(keyBytes); + Option existingValueBytes = await this.dbStore.Get(keyBytes, cancellationToken); TV newValue = await existingValueBytes.Map( async e => { var existingValue = e.FromBytes(); TV updatedValue = updator(existingValue); - await this.dbStore.Put(keyBytes, updatedValue.ToBytes()); + await this.dbStore.Put(keyBytes, updatedValue.ToBytes(), cancellationToken); return updatedValue; }).GetOrElse( async () => { - await this.dbStore.Put(keyBytes, value.ToBytes()); + await this.dbStore.Put(keyBytes, value.ToBytes(), cancellationToken); return value; }); return newValue; } } - public async Task FindOrPut(TK key, TV value) + public Task FindOrPut(TK key, TV putValue) => + this.FindOrPut(key, putValue, CancellationToken.None); + + public async Task FindOrPut(TK key, TV value, CancellationToken cancellationToken) { - using (await this.keyLockProvider.GetLock(key).LockAsync()) + using (await this.keyLockProvider.GetLock(key).LockAsync(cancellationToken)) { byte[] keyBytes = key.ToBytes(); - Option existingValueBytes = await this.dbStore.Get(keyBytes); + Option existingValueBytes = await this.dbStore.Get(keyBytes, cancellationToken); if (!existingValueBytes.HasValue) { - await this.dbStore.Put(keyBytes, value.ToBytes()); + await this.dbStore.Put(keyBytes, value.ToBytes(), cancellationToken); } + return existingValueBytes.Map(e => e.FromBytes()).GetOrElse(value); } } - public async Task> GetFirstEntry() + public async Task> GetFirstEntry(CancellationToken cancellationToken) { - Option<(byte[] key, byte[] value)> firstEntry = await this.dbStore.GetFirstEntry(); + Option<(byte[] key, byte[] value)> firstEntry = await this.dbStore.GetFirstEntry(cancellationToken); return firstEntry.Map(e => (e.key.FromBytes(), e.value.FromBytes())); } - public async Task> GetLastEntry() + public async Task> GetLastEntry(CancellationToken cancellationToken) { - Option<(byte[] key, byte[] value)> lastEntry = await this.dbStore.GetLastEntry(); + Option<(byte[] key, byte[] value)> lastEntry = await this.dbStore.GetLastEntry(cancellationToken); return lastEntry.Map(e => (e.key.FromBytes(), e.value.FromBytes())); } - public Task IterateBatch(TK startKey, int batchSize, Func callback) => this.IterateBatch(Option.Some(startKey), batchSize, callback); + public Task IterateBatch(TK startKey, int batchSize, Func callback, CancellationToken cancellationToken) + => this.IterateBatch(Option.Some(startKey), batchSize, callback, cancellationToken); - public Task IterateBatch(int batchSize, Func callback) => this.IterateBatch(Option.None(), batchSize, callback); + public Task IterateBatch(int batchSize, Func callback, CancellationToken cancellationToken) + => this.IterateBatch(Option.None(), batchSize, callback, cancellationToken); - Task IterateBatch(Option startKey, int batchSize, Func callback) + Task IterateBatch(Option startKey, int batchSize, Func callback, CancellationToken cancellationToken) { Preconditions.CheckRange(batchSize, 1, nameof(batchSize)); Preconditions.CheckNotNull(callback, nameof(callback)); @@ -145,11 +177,12 @@ Task DeserializingCallback(byte[] keyBytes, byte[] valueBytes) } return startKey.Match( - k => this.dbStore.IterateBatch(k.ToBytes(), batchSize, DeserializingCallback), - () => this.dbStore.IterateBatch(batchSize, DeserializingCallback)); + k => this.dbStore.IterateBatch(k.ToBytes(), batchSize, DeserializingCallback, cancellationToken), + () => this.dbStore.IterateBatch(batchSize, DeserializingCallback, cancellationToken)); } - public Task Contains(TK key) => this.dbStore.Contains(key.ToBytes()); + public Task Contains(TK key, CancellationToken cancellationToken) + => this.dbStore.Contains(key.ToBytes(), cancellationToken); protected virtual void Dispose(bool disposing) { diff --git a/edge-util/src/Microsoft.Azure.Devices.Edge.Storage/IDbStore.cs b/edge-util/src/Microsoft.Azure.Devices.Edge.Storage/IDbStore.cs index 02235e37372..548b7cbbedd 100644 --- a/edge-util/src/Microsoft.Azure.Devices.Edge.Storage/IDbStore.cs +++ b/edge-util/src/Microsoft.Azure.Devices.Edge.Storage/IDbStore.cs @@ -1,7 +1,8 @@ -// Copyright (c) Microsoft. All rights reserved. +// Copyright (c) Microsoft. All rights reserved. namespace Microsoft.Azure.Devices.Edge.Storage { public interface IDbStore : IKeyValueStore - { } + { + } } diff --git a/edge-util/src/Microsoft.Azure.Devices.Edge.Storage/IEntityStore.cs b/edge-util/src/Microsoft.Azure.Devices.Edge.Storage/IEntityStore.cs index 1ea76eedc20..d625d717f07 100644 --- a/edge-util/src/Microsoft.Azure.Devices.Edge.Storage/IEntityStore.cs +++ b/edge-util/src/Microsoft.Azure.Devices.Edge.Storage/IEntityStore.cs @@ -3,6 +3,7 @@ namespace Microsoft.Azure.Devices.Edge.Storage { using System; + using System.Threading; using System.Threading.Tasks; /// @@ -14,10 +15,18 @@ public interface IEntityStore : IKeyValueStore Task Remove(TK key, Func predicate); + Task Remove(TK key, Func predicate, CancellationToken cancellationToken); + Task Update(TK key, Func updator); + Task Update(TK key, Func updator, CancellationToken cancellationToken); + Task PutOrUpdate(TK key, TV putValue, Func valueUpdator); + Task PutOrUpdate(TK key, TV putValue, Func valueUpdator, CancellationToken cancellationToken); + Task FindOrPut(TK key, TV putValue); + + Task FindOrPut(TK key, TV putValue, CancellationToken cancellationToken); } } diff --git a/edge-util/src/Microsoft.Azure.Devices.Edge.Storage/IKeyValueStore.cs b/edge-util/src/Microsoft.Azure.Devices.Edge.Storage/IKeyValueStore.cs index 0ab78952803..d1993f8b41b 100644 --- a/edge-util/src/Microsoft.Azure.Devices.Edge.Storage/IKeyValueStore.cs +++ b/edge-util/src/Microsoft.Azure.Devices.Edge.Storage/IKeyValueStore.cs @@ -1,7 +1,9 @@ -// Copyright (c) Microsoft. All rights reserved. +// Copyright (c) Microsoft. All rights reserved. + namespace Microsoft.Azure.Devices.Edge.Storage { using System; + using System.Threading; using System.Threading.Tasks; using Microsoft.Azure.Devices.Edge.Util; @@ -24,6 +26,22 @@ public interface IKeyValueStore : IDisposable Task IterateBatch(int batchSize, Func perEntityCallback); - Task IterateBatch(TK startKey, int batchSize, Func perEntityCallback); + Task IterateBatch(TK startKey, int batchSize, Func perEntityCallback); + + Task Put(TK key, TV value, CancellationToken cancellationToken); + + Task> Get(TK key, CancellationToken cancellationToken); + + Task Remove(TK key, CancellationToken cancellationToken); + + Task Contains(TK key, CancellationToken cancellationToken); + + Task> GetFirstEntry(CancellationToken cancellationToken); + + Task> GetLastEntry(CancellationToken cancellationToken); + + Task IterateBatch(int batchSize, Func perEntityCallback, CancellationToken cancellationToken); + + Task IterateBatch(TK startKey, int batchSize, Func perEntityCallback, CancellationToken cancellationToken); } -} \ No newline at end of file +} diff --git a/edge-util/src/Microsoft.Azure.Devices.Edge.Storage/ISequentialStore.cs b/edge-util/src/Microsoft.Azure.Devices.Edge.Storage/ISequentialStore.cs index 96bf333cd46..02d02e76812 100644 --- a/edge-util/src/Microsoft.Azure.Devices.Edge.Storage/ISequentialStore.cs +++ b/edge-util/src/Microsoft.Azure.Devices.Edge.Storage/ISequentialStore.cs @@ -4,6 +4,7 @@ namespace Microsoft.Azure.Devices.Edge.Storage { using System; using System.Collections.Generic; + using System.Threading; using System.Threading.Tasks; /// @@ -19,6 +20,12 @@ public interface ISequentialStore : IDisposable Task RemoveFirst(Func> predicate); - Task> GetBatch(long startingOffset, int batchSize); + Task> GetBatch(long startingOffset, int batchSize); + + Task Append(T item, CancellationToken cancellationToken); + + Task RemoveFirst(Func> predicate, CancellationToken cancellationToken); + + Task> GetBatch(long startingOffset, int batchSize, CancellationToken cancellationToken); } } diff --git a/edge-util/src/Microsoft.Azure.Devices.Edge.Storage/IStoreProvider.cs b/edge-util/src/Microsoft.Azure.Devices.Edge.Storage/IStoreProvider.cs index 37d9ea7fe04..293752fb0b1 100644 --- a/edge-util/src/Microsoft.Azure.Devices.Edge.Storage/IStoreProvider.cs +++ b/edge-util/src/Microsoft.Azure.Devices.Edge.Storage/IStoreProvider.cs @@ -1,4 +1,4 @@ -// Copyright (c) Microsoft. All rights reserved. +// Copyright (c) Microsoft. All rights reserved. namespace Microsoft.Azure.Devices.Edge.Storage { diff --git a/edge-util/src/Microsoft.Azure.Devices.Edge.Storage/InMemoryDbStore.cs b/edge-util/src/Microsoft.Azure.Devices.Edge.Storage/InMemoryDbStore.cs index 7d241d288ba..44acc01d1d4 100644 --- a/edge-util/src/Microsoft.Azure.Devices.Edge.Storage/InMemoryDbStore.cs +++ b/edge-util/src/Microsoft.Azure.Devices.Edge.Storage/InMemoryDbStore.cs @@ -9,48 +9,60 @@ namespace Microsoft.Azure.Devices.Edge.Storage using System.Threading; using System.Threading.Tasks; using Microsoft.Azure.Devices.Edge.Util; + using Nito.AsyncEx; /// - /// Provices in memory implementation of a Db store + /// Provides an in memory implementation of the IDbStore /// class InMemoryDbStore : IDbStore { readonly ItemKeyedCollection keyValues; - readonly ReaderWriterLockSlim listLock = new ReaderWriterLockSlim(); + readonly AsyncReaderWriterLock listLock = new AsyncReaderWriterLock(); public InMemoryDbStore() { this.keyValues = new ItemKeyedCollection(new ByteArrayComparer()); } - public Task Contains(byte[] key) => Task.FromResult(this.keyValues.Contains(key)); + public Task Put(byte[] key, byte[] value) => this.Put(key, value, CancellationToken.None); - public Task> Get(byte[] key) + public Task> Get(byte[] key) => this.Get(key, CancellationToken.None); + + public Task Remove(byte[] key) => this.Remove(key, CancellationToken.None); + + public Task Contains(byte[] key) => this.Contains(key, CancellationToken.None); + + public Task> GetFirstEntry() => this.GetFirstEntry(CancellationToken.None); + + public Task> GetLastEntry() => this.GetLastEntry(CancellationToken.None); + + public Task IterateBatch(int batchSize, Func perEntityCallback) => this.IterateBatch(batchSize, perEntityCallback, CancellationToken.None); + + public Task IterateBatch(byte[] startKey, int batchSize, Func perEntityCallback) => this.IterateBatch(startKey, batchSize, perEntityCallback, CancellationToken.None); + + public Task Contains(byte[] key, CancellationToken cancellationToken) => Task.FromResult(this.keyValues.Contains(key)); + + public async Task> Get(byte[] key, CancellationToken cancellationToken) { - this.listLock.EnterReadLock(); - try + using (await this.listLock.ReaderLockAsync(cancellationToken)) { Option value = this.keyValues.Contains(key) ? Option.Some(this.keyValues[key].Value) : Option.None(); - return Task.FromResult(value); - } - finally - { - this.listLock.ExitReadLock(); + return value; } } - public Task IterateBatch(int batchSize, Func callback) + public async Task IterateBatch(int batchSize, Func callback, CancellationToken cancellationToken) { int index = 0; - List<(byte[] key, byte[] value)> snapshot = this.GetSnapshot(); - return this.IterateBatch(snapshot, index, batchSize, callback); + List<(byte[] key, byte[] value)> snapshot = await this.GetSnapshot(cancellationToken); + await this.IterateBatch(snapshot, index, batchSize, callback, cancellationToken); } - public Task IterateBatch(byte[] startKey, int batchSize, Func callback) + public async Task IterateBatch(byte[] startKey, int batchSize, Func callback, CancellationToken cancellationToken) { - List<(byte[] key, byte[] value)> snapshot = this.GetSnapshot(); + List<(byte[] key, byte[] value)> snapshot = await this.GetSnapshot(cancellationToken); int i = 0; for (; i < snapshot.Count; i++) { @@ -60,14 +72,15 @@ public Task IterateBatch(byte[] startKey, int batchSize, Func snapshot, int index, int batchSize, Func callback) + async Task IterateBatch(List<(byte[] key, byte[] value)> snapshot, int index, int batchSize, Func callback, CancellationToken cancellationToken) { if (index >= 0) { - for (int i = index; i < index + batchSize && i < snapshot.Count; i++) + for (int i = index; i < index + batchSize && i < snapshot.Count && !cancellationToken.IsCancellationRequested; i++) { var keyClone = snapshot[i].key.Clone() as byte[]; var valueClone = snapshot[i].value.Clone() as byte[]; @@ -76,42 +89,31 @@ async Task IterateBatch(List<(byte[] key, byte[] value)> snapshot, int index, in } } - public Task> GetFirstEntry() + public async Task> GetFirstEntry(CancellationToken cancellationToken) { - this.listLock.EnterReadLock(); - try + using (await this.listLock.ReaderLockAsync(cancellationToken)) { Option<(byte[], byte[])> firstEntry = this.keyValues.Count > 0 ? Option.Some((this.keyValues[0].Key, this.keyValues[0].Value)) : Option.None<(byte[], byte[])>(); - return Task.FromResult(firstEntry); - } - finally - { - this.listLock.ExitReadLock(); + return firstEntry; } } - public Task> GetLastEntry() + public async Task> GetLastEntry(CancellationToken cancellationToken) { - this.listLock.EnterReadLock(); - try + using (await this.listLock.ReaderLockAsync(cancellationToken)) { Option<(byte[], byte[])> lastEntry = (this.keyValues.Count > 0) ? Option.Some((this.keyValues[this.keyValues.Count - 1].Key, this.keyValues[this.keyValues.Count - 1].Value)) : Option.None<(byte[], byte[])>(); - return Task.FromResult(lastEntry); - } - finally - { - this.listLock.ExitReadLock(); + return lastEntry; } } - public Task Put(byte[] key, byte[] value) + public async Task Put(byte[] key, byte[] value, CancellationToken cancellationToken) { - this.listLock.EnterWriteLock(); - try + using (await this.listLock.WriterLockAsync(cancellationToken)) { if (!this.keyValues.Contains(key)) { @@ -121,39 +123,23 @@ public Task Put(byte[] key, byte[] value) { this.keyValues[key].Value = value; } - return Task.CompletedTask; - } - finally - { - this.listLock.ExitWriteLock(); } } - public Task Remove(byte[] key) + public async Task Remove(byte[] key, CancellationToken cancellationToken) { - this.listLock.EnterWriteLock(); - try + using (await this.listLock.WriterLockAsync(cancellationToken)) { this.keyValues.Remove(key); - return Task.CompletedTask; - } - finally - { - this.listLock.ExitWriteLock(); } } - List<(byte[], byte[])> GetSnapshot() + async Task> GetSnapshot(CancellationToken cancellationToken) { - this.listLock.EnterReadLock(); - try + using (await this.listLock.ReaderLockAsync(cancellationToken)) { return new List<(byte[], byte[])>(this.keyValues.ItemList); } - finally - { - this.listLock.ExitReadLock(); - } } public void Dispose() @@ -199,6 +185,7 @@ public int GetHashCode(byte[] obj) { hashCode = hashCode * -1521134295 + b.GetHashCode(); } + return hashCode; } } diff --git a/edge-util/src/Microsoft.Azure.Devices.Edge.Storage/Microsoft.Azure.Devices.Edge.Storage.csproj b/edge-util/src/Microsoft.Azure.Devices.Edge.Storage/Microsoft.Azure.Devices.Edge.Storage.csproj index 686e0bb6298..550f612b234 100644 --- a/edge-util/src/Microsoft.Azure.Devices.Edge.Storage/Microsoft.Azure.Devices.Edge.Storage.csproj +++ b/edge-util/src/Microsoft.Azure.Devices.Edge.Storage/Microsoft.Azure.Devices.Edge.Storage.csproj @@ -19,6 +19,9 @@ bin\CodeCoverage DEBUG;TRACE + + + diff --git a/edge-util/src/Microsoft.Azure.Devices.Edge.Storage/NullKeyValueStore.cs b/edge-util/src/Microsoft.Azure.Devices.Edge.Storage/NullKeyValueStore.cs index 8c481f9ffde..9bc98c7cbe9 100644 --- a/edge-util/src/Microsoft.Azure.Devices.Edge.Storage/NullKeyValueStore.cs +++ b/edge-util/src/Microsoft.Azure.Devices.Edge.Storage/NullKeyValueStore.cs @@ -1,7 +1,9 @@ // Copyright (c) Microsoft. All rights reserved. + namespace Microsoft.Azure.Devices.Edge.Storage { using System; + using System.Threading; using System.Threading.Tasks; using Microsoft.Azure.Devices.Edge.Util; @@ -11,20 +13,36 @@ public void Dispose() { } - public Task Put(TK key, TV value) => Task.CompletedTask; + public Task Put(TK key, TV value) => this.Put(key, value, CancellationToken.None); + + public Task> Get(TK key) => this.Get(key, CancellationToken.None); + + public Task Remove(TK key) => this.Remove(key, CancellationToken.None); + + public Task Contains(TK key) => this.Contains(key, CancellationToken.None); + + public Task> GetFirstEntry() => this.GetFirstEntry(CancellationToken.None); + + public Task> GetLastEntry() => this.GetLastEntry(CancellationToken.None); + + public Task IterateBatch(int batchSize, Func perEntityCallback) => this.IterateBatch(batchSize, perEntityCallback, CancellationToken.None); + + public Task IterateBatch(TK startKey, int batchSize, Func perEntityCallback) => this.IterateBatch(startKey, batchSize, perEntityCallback, CancellationToken.None); + + public Task Put(TK key, TV value, CancellationToken cancellationToken) => Task.CompletedTask; - public Task> Get(TK key) => Task.FromResult(Option.None()); + public Task> Get(TK key, CancellationToken cancellationToken) => Task.FromResult(Option.None()); - public Task Remove(TK key) => Task.CompletedTask; + public Task Remove(TK key, CancellationToken cancellationToken) => Task.CompletedTask; - public Task Contains(TK key) => Task.FromResult(false); + public Task Contains(TK key, CancellationToken cancellationToken) => Task.FromResult(false); - public Task> GetFirstEntry() => Task.FromResult(Option.None<(TK key, TV value)>()); + public Task> GetFirstEntry(CancellationToken cancellationToken) => Task.FromResult(Option.None<(TK key, TV value)>()); - public Task> GetLastEntry() => Task.FromResult(Option.None<(TK key, TV value)>()); + public Task> GetLastEntry(CancellationToken cancellationToken) => Task.FromResult(Option.None<(TK key, TV value)>()); - public Task IterateBatch(int batchSize, Func perEntityCallback) => Task.CompletedTask; + public Task IterateBatch(int batchSize, Func perEntityCallback, CancellationToken cancellationToken) => Task.CompletedTask; - public Task IterateBatch(TK startKey, int batchSize, Func perEntityCallback) => Task.CompletedTask; + public Task IterateBatch(TK startKey, int batchSize, Func perEntityCallback, CancellationToken cancellationToken) => Task.CompletedTask; } } diff --git a/edge-util/src/Microsoft.Azure.Devices.Edge.Storage/SequentialStore.cs b/edge-util/src/Microsoft.Azure.Devices.Edge.Storage/SequentialStore.cs index a04bb344a32..e0737303307 100644 --- a/edge-util/src/Microsoft.Azure.Devices.Edge.Storage/SequentialStore.cs +++ b/edge-util/src/Microsoft.Azure.Devices.Edge.Storage/SequentialStore.cs @@ -5,6 +5,7 @@ namespace Microsoft.Azure.Devices.Edge.Storage using System; using System.Collections.Generic; using System.Linq; + using System.Threading; using System.Threading.Tasks; using Microsoft.Azure.Devices.Edge.Util; using Microsoft.Azure.Devices.Edge.Util.Concurrency; @@ -33,11 +34,17 @@ class SequentialStore : ISequentialStore public string EntityName => this.entityStore.EntityName; + public Task Append(T item) => this.Append(item, CancellationToken.None); + + public Task RemoveFirst(Func> predicate) => this.RemoveFirst(predicate, CancellationToken.None); + + public Task> GetBatch(long startingOffset, int batchSize) => this.GetBatch(startingOffset, batchSize, CancellationToken.None); + public static async Task> Create(IEntityStore entityStore) { Preconditions.CheckNotNull(entityStore, nameof(entityStore)); - Option<(byte[] key, T value)> firstEntry = await entityStore.GetFirstEntry(); - Option<(byte[] key, T value)> lastEntry = await entityStore.GetLastEntry(); + Option<(byte[] key, T value)> firstEntry = await entityStore.GetFirstEntry(CancellationToken.None); + Option<(byte[] key, T value)> lastEntry = await entityStore.GetLastEntry(CancellationToken.None); long MapLocalFunction((byte[] key, T) e) => StoreUtils.GetOffsetFromKey(e.key); long headOffset = firstEntry.Map(MapLocalFunction).GetOrElse(DefaultHeadOffset); long tailOffset = lastEntry.Map(MapLocalFunction).GetOrElse(DefaultTailOffset); @@ -45,21 +52,21 @@ public static async Task> Create(IEntityStore ent return sequentialStore; } - public async Task Append(T item) + public async Task Append(T item, CancellationToken cancellationToken) { - using (await this.tailLockObject.LockAsync()) + using (await this.tailLockObject.LockAsync(cancellationToken)) { long currentOffset = this.tailOffset + 1; byte[] key = StoreUtils.GetKeyFromOffset(currentOffset); - await this.entityStore.Put(key, item); + await this.entityStore.Put(key, item, cancellationToken); this.tailOffset = currentOffset; return currentOffset; } } - public async Task RemoveFirst(Func> predicate) + public async Task RemoveFirst(Func> predicate, CancellationToken cancellationToken) { - using (await this.headLockObject.LockAsync()) + using (await this.headLockObject.LockAsync(cancellationToken)) { // Tail offset could change here, but not holding a lock for efficiency. if (this.IsEmpty()) @@ -68,17 +75,18 @@ public async Task RemoveFirst(Func> predicate) } byte[] key = StoreUtils.GetKeyFromOffset(this.headOffset); - Option value = await this.entityStore.Get(key); + Option value = await this.entityStore.Get(key, cancellationToken); bool result = await value .Match( async v => { if (await predicate(this.headOffset, v)) { - await this.entityStore.Remove(key); + await this.entityStore.Remove(key, cancellationToken); this.headOffset++; return true; } + return false; }, () => Task.FromResult(false)); @@ -86,7 +94,7 @@ public async Task RemoveFirst(Func> predicate) } } - public async Task> GetBatch(long startingOffset, int batchSize) + public async Task> GetBatch(long startingOffset, int batchSize, CancellationToken cancellationToken) { Preconditions.CheckRange(batchSize, 1, nameof(batchSize)); @@ -102,12 +110,16 @@ public async Task RemoveFirst(Func> predicate) var batch = new List<(long, T)>(); byte[] startingKey = StoreUtils.GetKeyFromOffset(startingOffset); - await this.entityStore.IterateBatch(startingKey, batchSize, (k, v) => - { - long offsetFromKey = StoreUtils.GetOffsetFromKey(k); - batch.Add((offsetFromKey, v)); - return Task.CompletedTask; - }); + await this.entityStore.IterateBatch( + startingKey, + batchSize, + (k, v) => + { + long offsetFromKey = StoreUtils.GetOffsetFromKey(k); + batch.Add((offsetFromKey, v)); + return Task.CompletedTask; + }, + cancellationToken); return batch; } diff --git a/edge-util/src/Microsoft.Azure.Devices.Edge.Storage/SerDeExtensions.cs b/edge-util/src/Microsoft.Azure.Devices.Edge.Storage/SerDeExtensions.cs index 253d6c4d188..a476e7ba8a8 100644 --- a/edge-util/src/Microsoft.Azure.Devices.Edge.Storage/SerDeExtensions.cs +++ b/edge-util/src/Microsoft.Azure.Devices.Edge.Storage/SerDeExtensions.cs @@ -49,6 +49,7 @@ public static byte[] ToBytes(this object value) string json = value.ToJson(); bytes = json.ToBytes(); } + return bytes; } diff --git a/edge-util/src/Microsoft.Azure.Devices.Edge.Storage/StoreProvider.cs b/edge-util/src/Microsoft.Azure.Devices.Edge.Storage/StoreProvider.cs index 530f77027af..c5adf2a5f44 100644 --- a/edge-util/src/Microsoft.Azure.Devices.Edge.Storage/StoreProvider.cs +++ b/edge-util/src/Microsoft.Azure.Devices.Edge.Storage/StoreProvider.cs @@ -2,23 +2,33 @@ namespace Microsoft.Azure.Devices.Edge.Storage { + using System; using System.Threading.Tasks; using Microsoft.Azure.Devices.Edge.Util; public class StoreProvider : IStoreProvider { readonly IDbStoreProvider dbStoreProvider; + readonly TimeSpan operationTimeout; public StoreProvider(IDbStoreProvider dbStoreProvider) + : this(dbStoreProvider, TimeSpan.FromMinutes(2)) { this.dbStoreProvider = Preconditions.CheckNotNull(dbStoreProvider, nameof(dbStoreProvider)); } + public StoreProvider(IDbStoreProvider dbStoreProvider, TimeSpan operationTimeout) + { + this.dbStoreProvider = Preconditions.CheckNotNull(dbStoreProvider, nameof(dbStoreProvider)); + this.operationTimeout = operationTimeout; + } + public IEntityStore GetEntityStore(string entityName) { IDbStore entityDbStore = this.dbStoreProvider.GetDbStore(Preconditions.CheckNonWhiteSpace(entityName, nameof(entityName))); IEntityStore entityStore = new EntityStore(entityDbStore, entityName, 12); - return entityStore; + IEntityStore timedEntityStore = new TimedEntityStore(entityStore, this.operationTimeout); + return timedEntityStore; } public async Task> GetSequentialStore(string entityName) diff --git a/edge-util/src/Microsoft.Azure.Devices.Edge.Storage/StoreUtils.cs b/edge-util/src/Microsoft.Azure.Devices.Edge.Storage/StoreUtils.cs index 3a29f70ebe0..63ed94b30b3 100644 --- a/edge-util/src/Microsoft.Azure.Devices.Edge.Storage/StoreUtils.cs +++ b/edge-util/src/Microsoft.Azure.Devices.Edge.Storage/StoreUtils.cs @@ -15,6 +15,7 @@ public static long GetOffsetFromKey(byte[] key) { Array.Reverse(key); } + long offset = BitConverter.ToInt64(key, 0); return offset; } @@ -26,7 +27,8 @@ public static byte[] GetKeyFromOffset(long offset) if (BitConverter.IsLittleEndian) { Array.Reverse(bytes); - } + } + return bytes; } @@ -41,6 +43,7 @@ public static IDictionary ToDictionary(this IReadOnlyDictionary< properties.Add(item.Key, item.Value); } } + return properties; } } diff --git a/edge-util/src/Microsoft.Azure.Devices.Edge.Storage/TimedEntityStore.cs b/edge-util/src/Microsoft.Azure.Devices.Edge.Storage/TimedEntityStore.cs new file mode 100644 index 00000000000..d52d5c80d5b --- /dev/null +++ b/edge-util/src/Microsoft.Azure.Devices.Edge.Storage/TimedEntityStore.cs @@ -0,0 +1,134 @@ +// Copyright (c) Microsoft. All rights reserved. + +namespace Microsoft.Azure.Devices.Edge.Storage +{ + using System; + using System.Threading; + using System.Threading.Tasks; + using Microsoft.Azure.Devices.Edge.Util; + + public class TimedEntityStore : TimedKeyValueStore, IEntityStore + { + readonly IEntityStore underlyingEntityStore; + readonly TimeSpan timeout; + + public TimedEntityStore(IEntityStore underlyingEntityStore, TimeSpan timeout) + : base(underlyingEntityStore, timeout) + { + this.underlyingEntityStore = Preconditions.CheckNotNull(underlyingEntityStore, nameof(underlyingEntityStore)); + this.timeout = timeout; + } + + public string EntityName => this.underlyingEntityStore.EntityName; + + public Task Remove(TK key, Func predicate) => this.Remove(key, predicate, CancellationToken.None); + + public Task Remove(TK key, Func predicate, CancellationToken cancellationToken) + { + Func> containsWithTimeout = cts => this.underlyingEntityStore.Contains(key, cts); + return containsWithTimeout.TimeoutAfter(cancellationToken, this.timeout); + } + + public Task Update(TK key, Func updator) => this.Update(key, updator, CancellationToken.None); + + public Task Update(TK key, Func updator, CancellationToken cancellationToken) + { + Func> containsWithTimeout = cts => this.underlyingEntityStore.Update(key, updator, cts); + return containsWithTimeout.TimeoutAfter(cancellationToken, this.timeout); + } + + public Task PutOrUpdate(TK key, TV putValue, Func valueUpdator) => this.PutOrUpdate(key, putValue, valueUpdator, CancellationToken.None); + + public Task PutOrUpdate(TK key, TV putValue, Func valueUpdator, CancellationToken cancellationToken) + { + Func> containsWithTimeout = cts => this.underlyingEntityStore.PutOrUpdate(key, putValue, valueUpdator, cts); + return containsWithTimeout.TimeoutAfter(cancellationToken, this.timeout); + } + + public Task FindOrPut(TK key, TV putValue) => this.FindOrPut(key, putValue, CancellationToken.None); + + public Task FindOrPut(TK key, TV putValue, CancellationToken cancellationToken) + { + Func> containsWithTimeout = cts => this.underlyingEntityStore.FindOrPut(key, putValue, cts); + return containsWithTimeout.TimeoutAfter(cancellationToken, this.timeout); + } + } + + public class TimedKeyValueStore : IKeyValueStore + { + readonly IKeyValueStore underlyingKeyValueStore; + readonly TimeSpan timeout; + + public TimedKeyValueStore(IKeyValueStore underlyingKeyValueStore, TimeSpan timeout) + { + this.underlyingKeyValueStore = Preconditions.CheckNotNull(underlyingKeyValueStore, nameof(underlyingKeyValueStore)); + this.timeout = timeout; + } + + public void Dispose() => this.underlyingKeyValueStore.Dispose(); + + public Task Put(TK key, TV value) => this.Put(key, value, CancellationToken.None); + + public Task> Get(TK key) => this.Get(key, CancellationToken.None); + + public Task Remove(TK key) => this.Remove(key, CancellationToken.None); + + public Task Contains(TK key) => this.Contains(key, CancellationToken.None); + + public Task> GetFirstEntry() => this.GetFirstEntry(CancellationToken.None); + + public Task> GetLastEntry() => this.GetLastEntry(CancellationToken.None); + + public Task IterateBatch(int batchSize, Func perEntityCallback) => this.IterateBatch(batchSize, perEntityCallback, CancellationToken.None); + + public Task IterateBatch(TK startKey, int batchSize, Func perEntityCallback) => this.IterateBatch(startKey, batchSize, perEntityCallback, CancellationToken.None); + + public Task Put(TK key, TV value, CancellationToken cancellationToken) + { + Func putWithTimeout = cts => this.underlyingKeyValueStore.Put(key, value, cts); + return putWithTimeout.TimeoutAfter(cancellationToken, this.timeout); + } + + public Task> Get(TK key, CancellationToken cancellationToken) + { + Func>> getWithTimeout = cts => this.underlyingKeyValueStore.Get(key, cts); + return getWithTimeout.TimeoutAfter(cancellationToken, this.timeout); + } + + public Task Remove(TK key, CancellationToken cancellationToken) + { + Func removeWithTimeout = cts => this.underlyingKeyValueStore.Remove(key, cts); + return removeWithTimeout.TimeoutAfter(cancellationToken, this.timeout); + } + + public Task Contains(TK key, CancellationToken cancellationToken) + { + Func> containsWithTimeout = cts => this.underlyingKeyValueStore.Contains(key, cts); + return containsWithTimeout.TimeoutAfter(cancellationToken, this.timeout); + } + + public Task> GetFirstEntry(CancellationToken cancellationToken) + { + Func>> getFirstEntryWithTimeout = cts => this.underlyingKeyValueStore.GetFirstEntry(cts); + return getFirstEntryWithTimeout.TimeoutAfter(cancellationToken, this.timeout); + } + + public Task> GetLastEntry(CancellationToken cancellationToken) + { + Func>> getLastEntryWithTimeout = cts => this.underlyingKeyValueStore.GetLastEntry(cts); + return getLastEntryWithTimeout.TimeoutAfter(cancellationToken, this.timeout); + } + + public Task IterateBatch(int batchSize, Func perEntityCallback, CancellationToken cancellationToken) + { + Func iterateWithTimeout = cts => this.underlyingKeyValueStore.IterateBatch(batchSize, perEntityCallback, cts); + return iterateWithTimeout.TimeoutAfter(cancellationToken, this.timeout); + } + + public Task IterateBatch(TK startKey, int batchSize, Func perEntityCallback, CancellationToken cancellationToken) + { + Func iterateWithTimeout = cts => this.underlyingKeyValueStore.IterateBatch(startKey, batchSize, perEntityCallback, cts); + return iterateWithTimeout.TimeoutAfter(cancellationToken, this.timeout); + } + } +} diff --git a/edge-util/src/Microsoft.Azure.Devices.Edge.Util/TaskEx.cs b/edge-util/src/Microsoft.Azure.Devices.Edge.Util/TaskEx.cs index 8be25b36230..72a75787e91 100644 --- a/edge-util/src/Microsoft.Azure.Devices.Edge.Util/TaskEx.cs +++ b/edge-util/src/Microsoft.Azure.Devices.Edge.Util/TaskEx.cs @@ -27,14 +27,14 @@ public static Task WhenCanceled(this CancellationToken cancellationToken) return tcs.Task; } - public async static Task<(T1, T2)> WhenAll(Task t1, Task t2) + public static async Task<(T1, T2)> WhenAll(Task t1, Task t2) { T1 val1 = await t1; T2 val2 = await t2; return (val1, val2); } - public async static Task<(T1, T2, T3)> WhenAll(Task t1, Task t2, Task t3) + public static async Task<(T1, T2, T3)> WhenAll(Task t1, Task t2, Task t3) { T1 val1 = await t1; T2 val2 = await t2; @@ -42,7 +42,7 @@ public static Task WhenCanceled(this CancellationToken cancellationToken) return (val1, val2, val3); } - public async static Task<(T1, T2, T3, T4)> WhenAll(Task t1, Task t2, Task t3, Task t4) + public static async Task<(T1, T2, T3, T4)> WhenAll(Task t1, Task t2, Task t3, Task t4) { T1 val1 = await t1; T2 val2 = await t2; @@ -51,7 +51,7 @@ public static Task WhenCanceled(this CancellationToken cancellationToken) return (val1, val2, val3, val4); } - public async static Task<(T1, T2, T3, T4, T5)> WhenAll(Task t1, Task t2, Task t3, Task t4, Task t5) + public static async Task<(T1, T2, T3, T4, T5)> WhenAll(Task t1, Task t2, Task t3, Task t4, Task t5) { T1 val1 = await t1; T2 val2 = await t2; @@ -61,7 +61,7 @@ public static Task WhenCanceled(this CancellationToken cancellationToken) return (val1, val2, val3, val4, val5); } - public async static Task<(T1, T2, T3, T4, T5, T6)> WhenAll(Task t1, Task t2, Task t3, Task t4, Task t5, Task t6) + public static async Task<(T1, T2, T3, T4, T5, T6)> WhenAll(Task t1, Task t2, Task t3, Task t4, Task t5, Task t6) { T1 val1 = await t1; T2 val2 = await t2; @@ -72,7 +72,7 @@ public static Task WhenCanceled(this CancellationToken cancellationToken) return (val1, val2, val3, val4, val5, val6); } - public async static Task<(T1, T2, T3, T4, T5, T6, T7)> WhenAll(Task t1, Task t2, Task t3, Task t4, Task t5, Task t6, Task t7) + public static async Task<(T1, T2, T3, T4, T5, T6, T7)> WhenAll(Task t1, Task t2, Task t3, Task t4, Task t5, Task t6, Task t7) { T1 val1 = await t1; T2 val2 = await t2; @@ -84,7 +84,7 @@ public static Task WhenCanceled(this CancellationToken cancellationToken) return (val1, val2, val3, val4, val5, val6, val7); } - public async static Task<(T1, T2, T3, T4, T5, T6, T7, T8)> WhenAll(Task t1, Task t2, Task t3, Task t4, Task t5, Task t6, Task t7, Task t8) + public static async Task<(T1, T2, T3, T4, T5, T6, T7, T8)> WhenAll(Task t1, Task t2, Task t3, Task t4, Task t5, Task t6, Task t7, Task t8) { T1 val1 = await t1; T2 val2 = await t2; @@ -97,7 +97,7 @@ public static Task WhenCanceled(this CancellationToken cancellationToken) return (val1, val2, val3, val4, val5, val6, val7, val8); } - public async static Task<(T1, T2, T3, T4, T5, T6, T7, T8, T9)> WhenAll(Task t1, Task t2, Task t3, Task t4, Task t5, Task t6, Task t7, Task t8, Task t9) + public static async Task<(T1, T2, T3, T4, T5, T6, T7, T8, T9)> WhenAll(Task t1, Task t2, Task t3, Task t4, Task t5, Task t6, Task t7, Task t8, Task t9) { T1 val1 = await t1; T2 val2 = await t2; @@ -111,6 +111,111 @@ public static Task WhenCanceled(this CancellationToken cancellationToken) return (val1, val2, val3, val4, val5, val6, val7, val8, val9); } + public static async Task TimeoutAfter(this Task task, TimeSpan timeout) + { + using (var cts = new CancellationTokenSource()) + { + Task timerTask = Task.Delay(timeout, cts.Token); + Task completedTask = await Task.WhenAny(task, timerTask); + if (completedTask == timerTask) + { + throw new TimeoutException("Operation timed out"); + } + + cts.Cancel(); + return await task; + } + } + + public static async Task TimeoutAfter(this Task task, TimeSpan timeout) + { + using (var cts = new CancellationTokenSource()) + { + Task timerTask = Task.Delay(timeout, cts.Token); + Task completedTask = await Task.WhenAny(task, timerTask); + if (completedTask == timerTask) + { + throw new TimeoutException("Operation timed out"); + } + + cts.Cancel(); + await task; + } + } + + public static Task TimeoutAfter(this Func operation, CancellationToken cancellationToken, TimeSpan timeout) + { + using (var cts = new CancellationTokenSource()) + { + try + { + return operation(CancellationTokenSource.CreateLinkedTokenSource(cts.Token, cancellationToken).Token) + .TimeoutAfter(timeout); + } + catch (TimeoutException) + { + cts.Cancel(); + throw; + } + } + } + + public static Task TimeoutAfter(this Func> operation, CancellationToken cancellationToken, TimeSpan timeout) + { + using (var cts = new CancellationTokenSource()) + { + try + { + return operation(CancellationTokenSource.CreateLinkedTokenSource(cts.Token, cancellationToken).Token) + .TimeoutAfter(timeout); + } + catch (TimeoutException) + { + cts.Cancel(); + throw; + } + } + } + + static async Task ExecuteUntilCancelled(this Task task, CancellationToken cancellationToken) + { + var tcs = new TaskCompletionSource(); + cancellationToken.Register( + () => + { + tcs.SetException(new TaskCanceledException(task)); + }); + Task completedTask = await Task.WhenAny(task, tcs.Task); + return await completedTask; + } + + static async Task ExecuteUntilCancelled(this Task task, CancellationToken cancellationToken) + { + var tcs = new TaskCompletionSource(); + cancellationToken.Register( + () => + { + tcs.TrySetCanceled(); + }); + Task completedTask = await Task.WhenAny(task, tcs.Task); + //// Await here to bubble up any exceptions + await completedTask; + } + + public static Task ExecuteUntilCancelled(this Func operation, CancellationToken cancellationToken) + { + Preconditions.CheckNotNull(operation, nameof(operation)); + Task task = Task.Run(operation, cancellationToken); + return task.ExecuteUntilCancelled(cancellationToken); + } + + public static Task ExecuteUntilCancelled(this Action operation, CancellationToken cancellationToken) + { + Preconditions.CheckNotNull(operation, nameof(operation)); + Task task = Task.Run(operation, cancellationToken); + return task.ExecuteUntilCancelled(cancellationToken); + } + public static IAsyncResult ToAsyncResult(this Task task, AsyncCallback callback, object state) { if (task.AsyncState == state) diff --git a/edge-util/test/Microsoft.Azure.Devices.Edge.Storage.RocksDb.Test/ColumnFamilyStoreTest.cs b/edge-util/test/Microsoft.Azure.Devices.Edge.Storage.RocksDb.Test/ColumnFamilyStoreTest.cs index 52c805afbe5..765faf7b018 100644 --- a/edge-util/test/Microsoft.Azure.Devices.Edge.Storage.RocksDb.Test/ColumnFamilyStoreTest.cs +++ b/edge-util/test/Microsoft.Azure.Devices.Edge.Storage.RocksDb.Test/ColumnFamilyStoreTest.cs @@ -11,7 +11,7 @@ namespace Microsoft.Azure.Devices.Edge.Storage.RocksDb.Test [Unit] public class ColumnFamilyStoreTest : IClassFixture { - private readonly TestRocksDbStoreProvider rocksDbStoreProvider; + readonly TestRocksDbStoreProvider rocksDbStoreProvider; public ColumnFamilyStoreTest(TestRocksDbStoreProvider rocksDbStoreProvider) { diff --git a/edge-util/test/Microsoft.Azure.Devices.Edge.Util.Test/TaskExTest.cs b/edge-util/test/Microsoft.Azure.Devices.Edge.Util.Test/TaskExTest.cs index b6ca3c80624..2c079c319b9 100644 --- a/edge-util/test/Microsoft.Azure.Devices.Edge.Util.Test/TaskExTest.cs +++ b/edge-util/test/Microsoft.Azure.Devices.Edge.Util.Test/TaskExTest.cs @@ -61,5 +61,116 @@ public async Task WhenAllTuple() Assert.Equal(7, a7); Assert.Equal(8, a8); } + + [Fact] + [Unit] + public async Task ExecuteFuncUntilCancelledTest() + { + int TestFunc() + { + DateTime end = DateTime.Now + TimeSpan.FromSeconds(5); + while (DateTime.Now < end) + { + // No-op + } + + return 0; + } + + Func operation = () => TestFunc(); + var cts = new CancellationTokenSource(); + Func testCode = () => operation.ExecuteUntilCancelled(cts.Token); + Task assertTask = Assert.ThrowsAsync(testCode); + + await Task.Delay(TimeSpan.FromSeconds(2)); + cts.Cancel(); + + await assertTask; + + // Assert + Assert.True(assertTask.IsCompletedSuccessfully); + } + + [Fact] + [Unit] + public async Task ExecuteActionUntilCancelledTest() + { + void TestAction() + { + DateTime end = DateTime.Now + TimeSpan.FromSeconds(5); + while (DateTime.Now < end) + { + // No-op + } + } + + Action operation = () => TestAction(); + var cts = new CancellationTokenSource(); + Func testCode = () => operation.ExecuteUntilCancelled(cts.Token); + + Task assertTask = Assert.ThrowsAsync(testCode); + + await Task.Delay(TimeSpan.FromSeconds(2)); + cts.Cancel(); + + await assertTask; + + // Assert + Assert.True(assertTask.IsCompletedSuccessfully); + } + + [Fact] + [Unit] + public async Task ActionTimeoutAfterTest() + { + Task TestAction(CancellationToken _) => Task.Delay(TimeSpan.FromSeconds(10)); + Func operation = c => TestAction(c); + var cts = new CancellationTokenSource(); + Func testCode = c => operation.TimeoutAfter(cts.Token, TimeSpan.FromSeconds(3)); + await Assert.ThrowsAsync(() => testCode(cts.Token)); + } + + [Fact] + [Unit] + public async Task FuncTimeoutAfterTest() + { + async Task TestFunc(CancellationToken _) + { + await Task.Delay(TimeSpan.FromSeconds(10)); + return 10; + } + + Func> operation = c => TestFunc(c); + var cts = new CancellationTokenSource(); + Func testCode = c => operation.TimeoutAfter(cts.Token, TimeSpan.FromSeconds(3)); + await Assert.ThrowsAsync(() => testCode(cts.Token)); + } + + [Fact] + [Unit] + public async Task ActionTimeoutAfterCancelTest() + { + Task TestAction(CancellationToken _) => Task.Delay(TimeSpan.FromSeconds(10)); + Func operation = c => TestAction(c); + var cts = new CancellationTokenSource(); + Func testCode = c => operation.TimeoutAfter(cts.Token, TimeSpan.FromSeconds(3)); + await Assert.ThrowsAsync(() => testCode(cts.Token)); + } + + [Fact] + [Unit] + public async Task FuncTimeoutAfterCancelTest() + { + async Task TestFunc(CancellationToken _) + { + await Task.Delay(TimeSpan.FromSeconds(10)); + return 10; + } + + Func> operation = c => TestFunc(c); + var cts = new CancellationTokenSource(); + Func testCode = c => operation.TimeoutAfter(cts.Token, TimeSpan.FromSeconds(3)); + await Assert.ThrowsAsync(() => testCode(cts.Token)); + } } }