EdgeHub: Fix TokenUpdate logic #206

Merged
merged 16 commits into from
Sep 12, 2018
Changes from 5 commits
Original file line number Diff line number Diff line change
@@ -25,6 +25,7 @@ class CloudConnection : ICloudConnection
const int TokenTimeToLiveSeconds = 3600; // Unused - Token is generated by downstream clients
const int TokenExpiryBufferPercentage = 8; // Assuming a standard token for 1 hr, we set expiry time to around 5 mins.
const uint OperationTimeoutMilliseconds = 1 * 60 * 1000; // 1 min
static readonly TimeSpan TokenRetryWaitTime = TimeSpan.FromSeconds(20);

readonly Action<string, CloudConnectionStatus> connectionStatusChangedHandler;
readonly ITransportSettings[] transportSettingsList;
@@ -93,15 +94,15 @@ public async Task<ICloudProxy> CreateOrUpdateAsync(IClientCredentials newCredent
if (newCredentials is ITokenCredentials tokenAuth && this.tokenGetter.HasValue)
{
if (IsTokenExpired(tokenAuth.Identity.IotHubHostName, tokenAuth.Token))
{
{
throw new InvalidOperationException($"Token for client {tokenAuth.Identity.Id} is expired");
}

this.tokenGetter.ForEach(tg =>
{
tg.SetResult(tokenAuth.Token);
// First reset the token getter and then set the result.
this.tokenGetter = Option.None<TaskCompletionSource<string>>();
Events.NewTokenObtained(newCredentials.Identity.IotHubHostName, newCredentials.Identity.Id, tokenAuth.Token);
tg.SetResult(tokenAuth.Token);
Member

Why does this order matter?

Contributor Author

Because the code below expects that the tokenGetter will be reset when the value is set. And that also answers your other question, I suppose

Member

In principle, if order is important, then we should do these under a lock. Otherwise order shouldn't matter.

Contributor Author

As discussed offline, the order is important because the caller is waiting for SetResult to be called (not looping), and when it resumes, it expects this.tokenGetter to be reset.
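
A minimal sketch of the ordering point made in this thread, assuming a `TaskCompletionSource<string>` created with default options, where an awaiting caller may resume synchronously inside `SetResult`. The names `Demo`, `pending`, and `WaitForTokenAsync` are hypothetical stand-ins for `this.tokenGetter` and the waiting caller; this is not the EdgeHub code.

```csharp
using System;
using System.Threading.Tasks;

static class Demo
{
    // Hypothetical stand-in for this.tokenGetter; null plays the role of Option.None.
    static TaskCompletionSource<string> pending;

    // Hypothetical caller: awaits the token, then checks that the getter was already cleared.
    static async Task<bool> WaitForTokenAsync(TaskCompletionSource<string> tcs)
    {
        await tcs.Task;
        return pending == null;
    }

    public static bool ResetThenComplete()
    {
        var tcs = new TaskCompletionSource<string>();
        pending = tcs;
        Task<bool> waiter = WaitForTokenAsync(tcs);

        pending = null;          // Reset first...
        tcs.SetResult("token");  // ...because the awaiting caller may resume inside this call.

        return waiter.Result;
    }

    static void Main()
    {
        Console.WriteLine(ResetThenComplete()); // prints "True"
    }
}
```

With the two statements swapped, the awaiting caller could resume before the reset and still observe the old completion source, which is the situation the reordering in this hunk avoids.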

});
return (cp, false);
}
@@ -177,7 +178,7 @@ async Task<IClient> CreateDeviceClient(
{
client.SetProductInfo(newCredentials.ProductInfo);
}

Events.CreateDeviceClientSuccess(transportSettings.GetTransportType(), OperationTimeoutMilliseconds, newCredentials.Identity);
return client;
}
@@ -263,33 +264,55 @@ void InternalConnectionStatusChangesHandler(ConnectionStatus status, ConnectionS
async Task<string> GetNewToken(string iotHub, string id, string currentToken, IIdentity currentIdentity)
Member

Could you please add some comments explaining the retry logic.

Contributor Author

Sure.

{
Events.GetNewToken(id);
// We have to catch UnauthorizedAccessException, because on IsTokenUsable, we call parse from
// Device Client and it throws if the token is expired.
if (IsTokenUsable(iotHub, currentToken))
{
Events.UsingExistingToken(id);
return currentToken;
}
else
bool retrying = false;
string token = currentToken;
while (true)
{
Events.TokenExpired(id, currentToken);
}
// We have to catch UnauthorizedAccessException, because on IsTokenUsable, we call parse from
Contributor

Where does this catch occur?

Contributor Author

@varunpuranik, Aug 24, 2018

In GetTokenExpiry (and IsTokenUsable also catches all exceptions)

// Device Client and it throws if the token is expired.
if (IsTokenUsable(iotHub, token))
{
if (retrying)
{
Events.NewTokenObtained(iotHub, id, token);
}
else
{
Events.UsingExistingToken(id);
}
return token;
}
else
{
Events.TokenNotUsable(iotHub, id, token);
}

// No need to lock here as the lock is being held by the refresher.
TaskCompletionSource<string> tcs = this.tokenGetter
.GetOrElse(
() =>
bool newTokenGetterCreated = false;
// No need to lock here as the lock is being held by the refresher.
TaskCompletionSource<string> tcs = this.tokenGetter
.GetOrElse(
() =>
{
Events.SafeCreateNewToken(id);
var taskCompletionSource = new TaskCompletionSource<string>();
this.tokenGetter = Option.Some(taskCompletionSource);
newTokenGetterCreated = true;
return taskCompletionSource;
});

if (newTokenGetterCreated)
{
if (retrying)
Member

It is not clear to me when and how many times 'newTokenGetterCreated' and 'retrying' would both be true. Perhaps some comments could help clarify when tokenGetters are set and cleared.

{
Events.SafeCreateNewToken(id);
var taskCompletionSource = new TaskCompletionSource<string>();
this.tokenGetter = Option.Some(taskCompletionSource);
this.connectionStatusChangedHandler(currentIdentity.Id, CloudConnectionStatus.TokenNearExpiry);
return taskCompletionSource;
});
string newToken = await tcs.Task;
return newToken;
}
await Task.Delay(TokenRetryWaitTime);
}
this.connectionStatusChangedHandler(currentIdentity.Id, CloudConnectionStatus.TokenNearExpiry);
}

retrying = true;
Member

Are we iterating the loop again only to log a message? Could we log the NewTokenObtained event here and return the new token right away?

Contributor Author

We could, but wanted to avoid the duplicate code. This seemed cleaner :)

token = await tcs.Task;
}
}

internal static DateTime GetTokenExpiry(string hostName, string token)
{
@@ -412,7 +435,6 @@ enum EventIds
CreateNewToken,
UpdatedCloudConnection,
ObtainedNewToken,
TokenExpired,
ErrorRenewingToken,
ErrorCheckingTokenUsability
}
@@ -471,11 +493,6 @@ internal static void NewTokenObtained(string hostname, string id, string newToke
Log.LogInformation((int)EventIds.ObtainedNewToken, Invariant($"Obtained new token for client {id} that expires in {timeRemaining}"));
}

internal static void TokenExpired(string id, string currentToken)
{
Log.LogDebug((int)EventIds.TokenExpired, Invariant($"Token Expired. Id:{id}, CurrentToken: {currentToken}."));
}

internal static void ErrorRenewingToken(Exception ex)
{
Log.LogDebug((int)EventIds.ErrorRenewingToken, ex, "Critical Error trying to renew Token.");
@@ -485,6 +502,12 @@ public static void ErrorCheckingTokenUsable(Exception ex)
{
Log.LogDebug((int)EventIds.ErrorCheckingTokenUsability, ex, "Error checking if token is usable.");
}

public static void TokenNotUsable(string hostname, string id, string newToken)
{
TimeSpan timeRemaining = GetTokenExpiryTimeRemaining(hostname, newToken);
Log.LogDebug((int)EventIds.ObtainedNewToken, Invariant($"Token received for client {id} expires in {timeRemaining}, and so is not usable. Getting a fresh token..."));
}
}
}
}
Original file line number Diff line number Diff line change
@@ -26,7 +26,7 @@ public void GetTokenExpiryBufferSecondsTest()
string token = TokenHelper.CreateSasToken("azure.devices.net");
TimeSpan timeRemaining = CloudConnection.GetTokenExpiryTimeRemaining("foo.azuredevices.net", token);
Assert.True(timeRemaining > TimeSpan.Zero);
}
}

[Unit]
[Fact]
@@ -79,7 +79,14 @@ public async Task RefreshTokenTest()

IClientCredentials GetClientCredentialsWithExpiringToken()
{
string token = TokenHelper.CreateSasToken(iothubHostName, DateTime.UtcNow.AddSeconds(10));
string token = TokenHelper.CreateSasToken(iothubHostName, DateTime.UtcNow.AddMinutes(3));
var identity = new DeviceIdentity(iothubHostName, deviceId);
return new TokenCredentials(identity, token, string.Empty);
}

IClientCredentials GetClientCredentialsWithNonExpiringToken()
{
string token = TokenHelper.CreateSasToken(iothubHostName, DateTime.UtcNow.AddMinutes(10));
var identity = new DeviceIdentity(iothubHostName, deviceId);
return new TokenCredentials(identity, token, string.Empty);
}
@@ -104,28 +111,105 @@ IClientCredentials GetClientCredentialsWithExpiringToken()
var deviceAuthenticationWithTokenRefresh = authenticationMethod as DeviceAuthenticationWithTokenRefresh;
Assert.NotNull(deviceAuthenticationWithTokenRefresh);

// Wait for the token to expire
await Task.Delay(TimeSpan.FromSeconds(10));

Task<string> getTokenTask = deviceAuthenticationWithTokenRefresh.GetTokenAsync(iothubHostName);
Assert.False(getTokenTask.IsCompleted);

Assert.Equal(receivedStatus, CloudConnectionStatus.TokenNearExpiry);

IClientCredentials clientCredentialsWithExpiringToken2 = GetClientCredentialsWithExpiringToken();
IClientCredentials clientCredentialsWithExpiringToken2 = GetClientCredentialsWithNonExpiringToken();
ICloudProxy cloudProxy2 = await cloudConnection.CreateOrUpdateAsync(clientCredentialsWithExpiringToken2);

// Wait for the task to complete
await Task.Delay(TimeSpan.FromSeconds(10));

Assert.True(getTokenTask.IsCompletedSuccessfully);
Assert.Equal(cloudProxy2, cloudConnection.CloudProxy.OrDefault());
Assert.True(cloudProxy2.IsActive);
Assert.True(cloudProxy1.IsActive);
Assert.Equal(cloudProxy1, cloudProxy2);
Assert.True(getTokenTask.IsCompletedSuccessfully);
Assert.Equal(getTokenTask.Result, (clientCredentialsWithExpiringToken2 as ITokenCredentials)?.Token);
}

[Fact]
[Unit]
public async Task RefreshTokenWithRetryTest()
{
string iothubHostName = "test.azure-devices.net";
string deviceId = "device1";

IClientCredentials GetClientCredentialsWithExpiringToken()
{
string token = TokenHelper.CreateSasToken(iothubHostName, DateTime.UtcNow.AddMinutes(3));
var identity = new DeviceIdentity(iothubHostName, deviceId);
return new TokenCredentials(identity, token, string.Empty);
}

IClientCredentials GetClientCredentialsWithNonExpiringToken()
{
string token = TokenHelper.CreateSasToken(iothubHostName, DateTime.UtcNow.AddMinutes(10));
var identity = new DeviceIdentity(iothubHostName, deviceId);
return new TokenCredentials(identity, token, string.Empty);
}

IAuthenticationMethod authenticationMethod = null;
IClientProvider clientProvider = GetMockDeviceClientProviderWithToken((s, a, t) => authenticationMethod = a);

var transportSettings = new ITransportSettings[] { new AmqpTransportSettings(TransportType.Amqp_Tcp_Only) };

var receivedStatuses = new List<CloudConnectionStatus>();
void ConnectionStatusHandler(string id, CloudConnectionStatus status) => receivedStatuses.Add(status);
var messageConverterProvider = new MessageConverterProvider(new Dictionary<Type, IMessageConverter>());

var cloudConnection = new CloudConnection(ConnectionStatusHandler, transportSettings, messageConverterProvider, clientProvider);

IClientCredentials clientCredentialsWithExpiringToken1 = GetClientCredentialsWithExpiringToken();
ICloudProxy cloudProxy1 = await cloudConnection.CreateOrUpdateAsync(clientCredentialsWithExpiringToken1);
Assert.True(cloudProxy1.IsActive);
Assert.Equal(cloudProxy1, cloudConnection.CloudProxy.OrDefault());

Assert.NotNull(authenticationMethod);
var deviceAuthenticationWithTokenRefresh = authenticationMethod as DeviceAuthenticationWithTokenRefresh;
Assert.NotNull(deviceAuthenticationWithTokenRefresh);

// Try to refresh token but get an expiring token
Task<string> getTokenTask = deviceAuthenticationWithTokenRefresh.GetTokenAsync(iothubHostName);
Assert.False(getTokenTask.IsCompleted);

Assert.Equal(2, receivedStatuses.Count);
Assert.Equal(receivedStatuses[1], CloudConnectionStatus.TokenNearExpiry);

ICloudProxy cloudProxy2 = await cloudConnection.CreateOrUpdateAsync(clientCredentialsWithExpiringToken1);

// Wait for the task to process
await Task.Delay(TimeSpan.FromSeconds(5));

Assert.False(getTokenTask.IsCompletedSuccessfully);
Assert.Equal(cloudProxy2, cloudConnection.CloudProxy.OrDefault());
Assert.True(cloudProxy2.IsActive);
Assert.True(cloudProxy1.IsActive);
Assert.Equal(cloudProxy1, cloudProxy2);

// Wait for 20 secs for retry to happen
await Task.Delay(TimeSpan.FromSeconds(20));

// Check if retry happened
Assert.Equal(3, receivedStatuses.Count);
Assert.Equal(receivedStatuses[2], CloudConnectionStatus.TokenNearExpiry);

IClientCredentials clientCredentialsWithNonExpiringToken = GetClientCredentialsWithNonExpiringToken();
ICloudProxy cloudProxy3 = await cloudConnection.CreateOrUpdateAsync(clientCredentialsWithNonExpiringToken);

// Wait for the task to complete
await Task.Delay(TimeSpan.FromSeconds(5));

Assert.True(getTokenTask.IsCompletedSuccessfully);
Assert.Equal(cloudProxy3, cloudConnection.CloudProxy.OrDefault());
Assert.True(cloudProxy3.IsActive);
Assert.True(cloudProxy1.IsActive);
Assert.Equal(cloudProxy1, cloudProxy3);
Assert.Equal(getTokenTask.Result, (clientCredentialsWithNonExpiringToken as ITokenCredentials)?.Token);
}

[Fact]
[Unit]
public async Task CloudConnectionCallbackTest()
@@ -204,9 +288,9 @@ public async Task UpdateDeviceConnectionTest()
string hostname = "dummy.azure-devices.net";
string deviceId = "device1";

IClientCredentials GetClientCredentials()
IClientCredentials GetClientCredentials(TimeSpan tokenExpiryDuration)
{
string token = TokenHelper.CreateSasToken(hostname, DateTime.UtcNow.AddSeconds(10));
string token = TokenHelper.CreateSasToken(hostname, DateTime.UtcNow.AddSeconds(tokenExpiryDuration.TotalSeconds));
var identity = new DeviceIdentity(hostname, deviceId);
return new TokenCredentials(identity, token, string.Empty);
}
@@ -256,7 +340,7 @@ IClient GetMockedDeviceClient()
cloudConnectionProvider.BindEdgeHub(Mock.Of<IEdgeHub>());
IConnectionManager connectionManager = new ConnectionManager(cloudConnectionProvider);

IClientCredentials clientCredentials1 = GetClientCredentials();
IClientCredentials clientCredentials1 = GetClientCredentials(TimeSpan.FromSeconds(10));
Try<ICloudProxy> cloudProxyTry1 = await connectionManager.CreateCloudConnectionAsync(clientCredentials1);
Assert.True(cloudProxyTry1.Success);

@@ -271,7 +355,7 @@ IClient GetMockedDeviceClient()
Task<string> tokenGetter = deviceTokenRefresher.GetTokenAsync(hostname);
Assert.False(tokenGetter.IsCompleted);

IClientCredentials clientCredentials2 = GetClientCredentials();
IClientCredentials clientCredentials2 = GetClientCredentials(TimeSpan.FromMinutes(2));
Try<ICloudProxy> cloudProxyTry2 = await connectionManager.CreateCloudConnectionAsync(clientCredentials2);
Assert.True(cloudProxyTry2.Success);

@@ -280,29 +364,19 @@ IClient GetMockedDeviceClient()
connectionManager.BindDeviceProxy(clientCredentials2.Identity, deviceProxy2);

await Task.Delay(TimeSpan.FromSeconds(3));
Assert.True(tokenGetter.IsCompleted);
Assert.Equal(tokenGetter.Result, (clientCredentials2 as ITokenCredentials)?.Token);

await Task.Delay(TimeSpan.FromSeconds(10));
Assert.NotNull(authenticationMethod);
deviceTokenRefresher = authenticationMethod as DeviceAuthenticationWithTokenRefresh;
Assert.NotNull(deviceTokenRefresher);
tokenGetter = deviceTokenRefresher.GetTokenAsync(hostname);
Assert.False(tokenGetter.IsCompleted);

IClientCredentials clientCredentials3 = GetClientCredentials();
IClientCredentials clientCredentials3 = GetClientCredentials(TimeSpan.FromMinutes(10));
Try<ICloudProxy> cloudProxyTry3 = await connectionManager.CreateCloudConnectionAsync(clientCredentials3);
Assert.True(cloudProxyTry3.Success);

IDeviceProxy deviceProxy3 = GetMockDeviceProxy();
await connectionManager.AddDeviceConnection(clientCredentials3);
connectionManager.BindDeviceProxy(clientCredentials3.Identity, deviceProxy3);

await Task.Delay(TimeSpan.FromSeconds(3));
await Task.Delay(TimeSpan.FromSeconds(23));
Assert.True(tokenGetter.IsCompleted);
Assert.Equal(tokenGetter.Result, (clientCredentials3 as ITokenCredentials)?.Token);

Mock.VerifyAll(Mock.Get(deviceProxy1), Mock.Get(deviceProxy2));
}

static async Task GetCloudConnectionTest(Func<IClientCredentials> credentialsGenerator, IClientProvider clientProvider)