diff --git a/edge-agent/src/Microsoft.Azure.Devices.Edge.Agent.IoTHub/EdgeAgentConnection.cs b/edge-agent/src/Microsoft.Azure.Devices.Edge.Agent.IoTHub/EdgeAgentConnection.cs index 56f73053c11..24085c7b5fe 100644 --- a/edge-agent/src/Microsoft.Azure.Devices.Edge.Agent.IoTHub/EdgeAgentConnection.cs +++ b/edge-agent/src/Microsoft.Azure.Devices.Edge.Agent.IoTHub/EdgeAgentConnection.cs @@ -3,7 +3,6 @@ namespace Microsoft.Azure.Devices.Edge.Agent.IoTHub { using System; using System.Threading.Tasks; - using System.Timers; using Microsoft.Azure.Devices.Client; using Microsoft.Azure.Devices.Edge.Agent.Core; using Microsoft.Azure.Devices.Edge.Agent.Core.ConfigSources; @@ -17,16 +16,17 @@ namespace Microsoft.Azure.Devices.Edge.Agent.IoTHub public class EdgeAgentConnection : IEdgeAgentConnection { - const string PingMethodName = "ping"; internal static readonly Version ExpectedSchemaVersion = new Version("1.0"); + const string PingMethodName = "ping"; static readonly TimeSpan DefaultConfigRefreshFrequency = TimeSpan.FromHours(1); static readonly Task PingMethodResponse = Task.FromResult(new MethodResponse(200)); static readonly TimeSpan DeviceClientInitializationWaitTime = TimeSpan.FromSeconds(5); + readonly AsyncLock twinLock = new AsyncLock(); readonly ISerde desiredPropertiesSerDe; readonly Task initTask; readonly RetryStrategy retryStrategy; - readonly Timer refreshTimer; + readonly PeriodicTask refreshTwinTask; Option deviceClient; TwinCollection desiredProperties; @@ -64,8 +64,7 @@ internal EdgeAgentConnection( this.reportedProperties = Option.None(); this.deviceClient = Option.None(); this.retryStrategy = Preconditions.CheckNotNull(retryStrategy, nameof(retryStrategy)); - this.refreshTimer = new Timer(refreshConfigFrequency.TotalMilliseconds); - this.refreshTimer.Elapsed += (_, __) => this.RefreshTimerElapsed(); + this.refreshTwinTask = new PeriodicTask(this.ForceRefreshTwin, refreshConfigFrequency, refreshConfigFrequency, Events.Log, "refresh twin config"); this.initTask = this.CreateAndInitDeviceClient(Preconditions.CheckNotNull(moduleClientProvider, nameof(moduleClientProvider))); Events.TwinRefreshInit(refreshConfigFrequency); @@ -82,7 +81,7 @@ public async Task> GetDeploymentConfigInfoAsync() public void Dispose() { this.deviceClient.ForEach(d => d.Dispose()); - this.refreshTimer?.Dispose(); + this.refreshTwinTask.Dispose(); } public async Task UpdateReportedPropertiesAsync(TwinCollection patch) @@ -111,14 +110,6 @@ internal static void ValidateSchemaVersion(string schemaVersion) } } - async void RefreshTimerElapsed() => await this.RefreshTwinAsync(); - - void ResetRefreshTimer() - { - this.refreshTimer.Stop(); - this.refreshTimer.Start(); - } - async Task CreateAndInitDeviceClient(IModuleClientProvider moduleClientProvider) { using (await this.twinLock.LockAsync()) @@ -173,6 +164,7 @@ async Task OnDesiredPropertiesUpdated(TwinCollection desiredPropertiesPatch, obj } } + // This method updates local state and should be called only after acquiring twinLock async Task RefreshTwinAsync() { try @@ -193,7 +185,6 @@ async Task RefreshTwinAsync() this.reportedProperties = Option.Some(twin.Properties.Reported); await this.UpdateDeploymentConfig(); Events.TwinRefreshSuccess(); - this.ResetRefreshTimer(); } catch (Exception ex) when (!ex.IsFatal()) { @@ -202,6 +193,14 @@ async Task RefreshTwinAsync() } } + async Task ForceRefreshTwin() + { + using (await this.twinLock.LockAsync()) + { + await this.RefreshTwinAsync(); + } + } + // This method updates local state and should be called only after acquiring twinLock async Task ApplyPatchAsync(TwinCollection patch) { @@ -211,7 +210,6 @@ async Task ApplyPatchAsync(TwinCollection patch) this.desiredProperties = new TwinCollection(mergedJson); await this.UpdateDeploymentConfig(); Events.DesiredPropertiesPatchApplied(); - this.ResetRefreshTimer(); } catch (Exception ex) when (!ex.IsFatal()) { @@ -270,8 +268,8 @@ async Task WaitForDeviceClientInitialization() => static class Events { + public static readonly ILogger Log = Logger.Factory.CreateLogger(); const int IdStart = AgentEventIds.EdgeAgentConnection; - static readonly ILogger Log = Logger.Factory.CreateLogger(); enum EventIds { diff --git a/edge-agent/test/Microsoft.Azure.Devices.Edge.Agent.IoTHub.Test/EdgeAgentConnectionTest.cs b/edge-agent/test/Microsoft.Azure.Devices.Edge.Agent.IoTHub.Test/EdgeAgentConnectionTest.cs index 3f54f941707..8b6817c6c86 100644 --- a/edge-agent/test/Microsoft.Azure.Devices.Edge.Agent.IoTHub.Test/EdgeAgentConnectionTest.cs +++ b/edge-agent/test/Microsoft.Azure.Devices.Edge.Agent.IoTHub.Test/EdgeAgentConnectionTest.cs @@ -1255,97 +1255,6 @@ public async Task EdgeAgentConnectionRefreshTest() } } - [Fact] - [Unit] - public async Task EdgeAgentConnectionRefreshTest_NoRefresh() - { - // Arrange - var moduleDeserializerTypes = new Dictionary - { - { DockerType, typeof(DockerDesiredModule) } - }; - - var edgeAgentDeserializerTypes = new Dictionary - { - { DockerType, typeof(EdgeAgentDockerModule) } - }; - - var edgeHubDeserializerTypes = new Dictionary - { - { DockerType, typeof(EdgeHubDockerModule) } - }; - - var runtimeInfoDeserializerTypes = new Dictionary - { - { DockerType, typeof(DockerRuntimeInfo) } - }; - - var deserializerTypes = new Dictionary> - { - [typeof(IModule)] = moduleDeserializerTypes, - [typeof(IEdgeAgentModule)] = edgeAgentDeserializerTypes, - [typeof(IEdgeHubModule)] = edgeHubDeserializerTypes, - [typeof(IRuntimeInfo)] = runtimeInfoDeserializerTypes, - }; - - ISerde serde = new TypeSpecificSerDe(deserializerTypes); - - var runtimeInfo = new DockerRuntimeInfo("docker", new DockerRuntimeConfig("1.0", null)); - var edgeAgentDockerModule = new EdgeAgentDockerModule("docker", new DockerConfig("image", string.Empty), null, null); - var edgeHubDockerModule = new EdgeHubDockerModule( - "docker", - ModuleStatus.Running, - RestartPolicy.Always, - new DockerConfig("image", string.Empty), - null, - null); - var deploymentConfig = new DeploymentConfig( - "1.0", - runtimeInfo, - new SystemModules(edgeAgentDockerModule, edgeHubDockerModule), - new Dictionary()); - long version = 1; - string deploymentConfigJson = serde.Serialize(deploymentConfig); - JObject deploymentConfigJobject = JObject.Parse(deploymentConfigJson); - deploymentConfigJobject.Add("$version", JToken.Parse($"{version}")); - var twinCollection = new TwinCollection(deploymentConfigJobject, new JObject()); - var twin = new Twin(new TwinProperties { Desired = new TwinCollection(deploymentConfigJson) }); - - DesiredPropertyUpdateCallback desiredPropertyUpdateCallback = null; - var moduleClient = new Mock(); - moduleClient.Setup(m => m.GetTwinAsync()) - .ReturnsAsync(twin); - moduleClient.Setup(m => m.SetDesiredPropertyUpdateCallbackAsync(It.IsAny())) - .Callback(d => desiredPropertyUpdateCallback = d) - .Returns(Task.CompletedTask); - - var moduleClientProvider = new Mock(); - Func updateModuleClient = null; - moduleClientProvider.Setup(m => m.Create(It.IsAny(), It.IsAny>())) - .Callback>((c, f) => updateModuleClient = f) - .ReturnsAsync(moduleClient.Object); - - // Act - using (var edgeAgentConnection = new EdgeAgentConnection(moduleClientProvider.Object, serde, TimeSpan.FromSeconds(5))) - { - await Task.Delay(TimeSpan.FromSeconds(0.5)); - Assert.NotNull(updateModuleClient); - - await updateModuleClient(moduleClient.Object); - Assert.NotNull(desiredPropertyUpdateCallback); - - await Task.Delay(TimeSpan.FromSeconds(3)); - JObject patchConfigJobject = JObject.Parse(deploymentConfigJson); - patchConfigJobject.Add("$version", JToken.Parse($"{version + 1}")); - var patchTwinCollection = new TwinCollection(deploymentConfigJobject, new JObject()); - await desiredPropertyUpdateCallback(patchTwinCollection, null); - await Task.Delay(TimeSpan.FromSeconds(4)); - - // Assert - moduleClient.Verify(m => m.GetTwinAsync(), Times.Once); - } - } - [Theory] [Unit] [InlineData("1.0", null)] diff --git a/edge-util/src/Microsoft.Azure.Devices.Edge.Util/PeriodicTask.cs b/edge-util/src/Microsoft.Azure.Devices.Edge.Util/PeriodicTask.cs new file mode 100644 index 00000000000..716bc221208 --- /dev/null +++ b/edge-util/src/Microsoft.Azure.Devices.Edge.Util/PeriodicTask.cs @@ -0,0 +1,108 @@ +// Copyright (c) Microsoft. All rights reserved. +namespace Microsoft.Azure.Devices.Edge.Util +{ + using System; + using System.Threading; + using System.Threading.Tasks; + using Microsoft.Extensions.Logging; + + public class PeriodicTask : IDisposable + { + readonly Func work; + readonly TimeSpan frequency; + readonly TimeSpan startAfter; + readonly object stateLock = new object(); + readonly ILogger logger; + readonly string operationName; + readonly Timer checkTimer; + readonly CancellationTokenSource cts = new CancellationTokenSource(); + + Task currentTask; + + public PeriodicTask( + Func work, + TimeSpan frequency, + TimeSpan startAfter, + ILogger logger, + string operationName) + { + Preconditions.CheckArgument(frequency > TimeSpan.Zero, "Frequency should be > 0"); + Preconditions.CheckArgument(startAfter >= TimeSpan.Zero, "startAfter should be >= 0"); + + this.work = Preconditions.CheckNotNull(work, nameof(work)); + this.frequency = frequency; + this.startAfter = startAfter; + this.logger = Preconditions.CheckNotNull(logger, nameof(logger)); + this.operationName = Preconditions.CheckNonWhiteSpace(operationName, nameof(operationName)); + this.currentTask = this.DoWork(); + this.checkTimer = new Timer(this.EnsureWork, null, startAfter, frequency); + this.logger.LogInformation($"Started operation {this.operationName}"); + } + + public PeriodicTask( + Func work, + TimeSpan frequency, + TimeSpan startAfter, + ILogger logger, + string operationName) + : this(_ => Preconditions.CheckNotNull(work, nameof(work))(), frequency, startAfter, logger, operationName) + { + } + + /// + /// Do not dispose the task here in case it hasn't completed. + /// + public void Dispose() + { + this.checkTimer?.Dispose(); + this.cts?.Cancel(); + this.cts?.Dispose(); + } + + /// + /// The current task should never complete, but in case it does, this makes sure it is started again. + /// + void EnsureWork(object state) + { + lock (this.stateLock) + { + if (this.currentTask == null || this.currentTask.IsCompleted) + { + this.logger.LogInformation($"Periodic operation {this.operationName}, is not running. Attempting to start again..."); + this.currentTask = this.DoWork(); + this.logger.LogInformation($"Started operation {this.operationName}"); + } + } + } + + async Task DoWork() + { + try + { + CancellationToken cancellationToken = this.cts.Token; + await Task.Delay(this.startAfter, cancellationToken); + while (!cancellationToken.IsCancellationRequested) + { + try + { + this.logger.LogInformation($"Starting periodic operation {this.operationName}..."); + await this.work(cancellationToken); + this.logger.LogInformation($"Successfully completed periodic operation {this.operationName}"); + } + catch (Exception e) + { + this.logger.LogWarning(e, $"Error in periodic operation {this.operationName}"); + } + + await Task.Delay(this.frequency, cancellationToken); + } + + this.logger.LogDebug($"Periodic operation {this.operationName} cancelled"); + } + catch (Exception ex) + { + this.logger.LogError(ex, $"Unexpected error in periodic operation {this.operationName}"); + } + } + } +} diff --git a/edge-util/test/Microsoft.Azure.Devices.Edge.Util.Test/PeriodicTaskTest.cs b/edge-util/test/Microsoft.Azure.Devices.Edge.Util.Test/PeriodicTaskTest.cs new file mode 100644 index 00000000000..e60fac4e9f6 --- /dev/null +++ b/edge-util/test/Microsoft.Azure.Devices.Edge.Util.Test/PeriodicTaskTest.cs @@ -0,0 +1,99 @@ +// Copyright (c) Microsoft. All rights reserved. +namespace Microsoft.Azure.Devices.Edge.Util.Test +{ + using System; + using System.Threading; + using System.Threading.Tasks; + using Microsoft.Azure.Devices.Edge.Util.Test.Common; + using Microsoft.Extensions.Logging; + using Moq; + using Xunit; + + [Unit] + public class PeriodicTaskTest + { + [Fact] + public async Task TestPeriodicTaskTest() + { + // Arrange + int counter = 0; + Func work = async () => + { + counter++; + await Task.Delay(TimeSpan.FromSeconds(2)); + if (counter % 3 == 0) + { + throw new InvalidOperationException(); + } + }; + + TimeSpan frequency = TimeSpan.FromSeconds(3); + TimeSpan startAfter = TimeSpan.FromSeconds(5); + var logger = Mock.Of(); + + // Act + using (new PeriodicTask(work, frequency, startAfter, logger, "test op")) + { + // Assert + await Task.Delay(TimeSpan.FromSeconds(4)); + Assert.Equal(0, counter); + await Task.Delay(TimeSpan.FromSeconds(2)); + Assert.Equal(1, counter); + for (int i = 0; i < 5; i++) + { + await Task.Delay(TimeSpan.FromSeconds(5)); + Assert.Equal(2 + i, counter); + } + } + } + + [Fact] + public async Task TestPeriodicTaskWithCtsTest() + { + // Arrange + int counter = 0; + bool taskCancelled = false; + Func work = async cts => + { + counter++; + + try + { + await Task.Delay(TimeSpan.FromSeconds(3), cts); + } + catch (TaskCanceledException) + { + taskCancelled = true; + throw; + } + + if (counter % 3 == 0) + { + throw new InvalidOperationException(); + } + }; + + TimeSpan frequency = TimeSpan.FromSeconds(3); + TimeSpan startAfter = TimeSpan.FromSeconds(5); + var logger = Mock.Of(); + + // Act + using (new PeriodicTask(work, frequency, startAfter, logger, "test op")) + { + // Assert + await Task.Delay(TimeSpan.FromSeconds(4)); + Assert.Equal(0, counter); + await Task.Delay(TimeSpan.FromSeconds(2)); + Assert.Equal(1, counter); + for (int i = 0; i < 5; i++) + { + await Task.Delay(TimeSpan.FromSeconds(6)); + Assert.Equal(2 + i, counter); + } + } + + await Task.Delay(TimeSpan.FromSeconds(4)); + Assert.True(taskCancelled); + } + } +}