Skip to content

Commit

Permalink
EdgeHub: Fix TokenUpdate logic (#206)
Browse files Browse the repository at this point in the history
* Check refresh token validity before returning it

* Fix code and add test

* Fix cloud connection token update logic (#217)

* Adding comments

* Fix CloudConnectionTest

* Fix CloudConnection test
  • Loading branch information
varunpuranik authored Sep 12, 2018
1 parent 4f3ccd6 commit 9d2ba5e
Show file tree
Hide file tree
Showing 2 changed files with 157 additions and 55 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ class CloudConnection : ICloudConnection
const int TokenTimeToLiveSeconds = 3600; // Unused - Token is generated by downstream clients
const int TokenExpiryBufferPercentage = 8; // Assuming a standard token for 1 hr, we set expiry time to around 5 mins.
const uint OperationTimeoutMilliseconds = 1 * 60 * 1000; // 1 min
static readonly TimeSpan TokenRetryWaitTime = TimeSpan.FromSeconds(20);

readonly Action<string, CloudConnectionStatus> connectionStatusChangedHandler;
readonly ITransportSettings[] transportSettingsList;
Expand Down Expand Up @@ -103,15 +104,15 @@ public async Task<ICloudProxy> CreateOrUpdateAsync(IClientCredentials newCredent
if (newCredentials is ITokenCredentials tokenAuth && this.tokenGetter.HasValue)
{
if (IsTokenExpired(tokenAuth.Identity.IotHubHostName, tokenAuth.Token))
{
{
throw new InvalidOperationException($"Token for client {tokenAuth.Identity.Id} is expired");
}

this.tokenGetter.ForEach(tg =>
{
tg.SetResult(tokenAuth.Token);
// First reset the token getter and then set the result.
this.tokenGetter = Option.None<TaskCompletionSource<string>>();
Events.NewTokenObtained(newCredentials.Identity.IotHubHostName, newCredentials.Identity.Id, tokenAuth.Token);
tg.SetResult(tokenAuth.Token);
});
return (cp, false);
}
Expand Down Expand Up @@ -188,7 +189,7 @@ async Task<IClient> CreateDeviceClient(
{
client.SetProductInfo(newCredentials.ProductInfo);
}

Events.CreateDeviceClientSuccess(transportSettings.GetTransportType(), OperationTimeoutMilliseconds, newCredentials.Identity);
return client;
}
Expand Down Expand Up @@ -274,36 +275,63 @@ void InternalConnectionStatusChangesHandler(ConnectionStatus status, ConnectionS
/// If the existing identity has a usable token, then use it.
/// Else, generate a notification of token being near expiry and return a task that
/// can be completed later.
/// Keep retrying till we get a usable token.
/// Note - Don't use this.Identity in this method, as it may not have been set yet!
/// </summary>
async Task<string> GetNewToken(string iotHub, string id, string currentToken, IIdentity currentIdentity)
{
Events.GetNewToken(id);
// We have to catch UnauthorizedAccessException, because on IsTokenUsable, we call parse from
// Device Client and it throws if the token is expired.
if (IsTokenUsable(iotHub, currentToken))
bool retrying = false;
string token = currentToken;
while (true)
{
Events.UsingExistingToken(id);
return currentToken;
}
else
{
Events.TokenExpired(id, currentToken);
}
// We have to catch UnauthorizedAccessException, because on IsTokenUsable, we call parse from
// Device Client and it throws if the token is expired.
if (IsTokenUsable(iotHub, token))
{
if (retrying)
{
Events.NewTokenObtained(iotHub, id, token);
}
else
{
Events.UsingExistingToken(id);
}
return token;
}
else
{
Events.TokenNotUsable(iotHub, id, token);
}

// No need to lock here as the lock is being held by the refresher.
TaskCompletionSource<string> tcs = this.tokenGetter
.GetOrElse(
() =>
bool newTokenGetterCreated = false;
// No need to lock here as the lock is being held by the refresher.
TaskCompletionSource<string> tcs = this.tokenGetter
.GetOrElse(
() =>
{
Events.SafeCreateNewToken(id);
var taskCompletionSource = new TaskCompletionSource<string>();
this.tokenGetter = Option.Some(taskCompletionSource);
newTokenGetterCreated = true;
return taskCompletionSource;
});

// If a new tokenGetter was created, then invoke the connection status changed handler
if (newTokenGetterCreated)
{
// If retrying, wait for some time.
if (retrying)
{
Events.SafeCreateNewToken(id);
var taskCompletionSource = new TaskCompletionSource<string>();
this.tokenGetter = Option.Some(taskCompletionSource);
this.connectionStatusChangedHandler(currentIdentity.Id, CloudConnectionStatus.TokenNearExpiry);
return taskCompletionSource;
});
string newToken = await tcs.Task;
return newToken;
await Task.Delay(TokenRetryWaitTime);
}
this.connectionStatusChangedHandler(currentIdentity.Id, CloudConnectionStatus.TokenNearExpiry);
}

retrying = true;
// this.tokenGetter will be reset when this task returns.
token = await tcs.Task;
}
}

internal static DateTime GetTokenExpiry(string hostName, string token)
Expand Down Expand Up @@ -427,7 +455,6 @@ enum EventIds
CreateNewToken,
UpdatedCloudConnection,
ObtainedNewToken,
TokenExpired,
ErrorRenewingToken,
ErrorCheckingTokenUsability
}
Expand Down Expand Up @@ -486,11 +513,6 @@ internal static void NewTokenObtained(string hostname, string id, string newToke
Log.LogInformation((int)EventIds.ObtainedNewToken, Invariant($"Obtained new token for client {id} that expires in {timeRemaining}"));
}

internal static void TokenExpired(string id, string currentToken)
{
Log.LogDebug((int)EventIds.TokenExpired, Invariant($"Token Expired. Id:{id}, CurrentToken: {currentToken}."));
}

internal static void ErrorRenewingToken(Exception ex)
{
Log.LogDebug((int)EventIds.ErrorRenewingToken, ex, "Critical Error trying to renew Token.");
Expand All @@ -500,6 +522,12 @@ public static void ErrorCheckingTokenUsable(Exception ex)
{
Log.LogDebug((int)EventIds.ErrorCheckingTokenUsability, ex, "Error checking if token is usable.");
}

public static void TokenNotUsable(string hostname, string id, string newToken)
{
TimeSpan timeRemaining = GetTokenExpiryTimeRemaining(hostname, newToken);
Log.LogDebug((int)EventIds.ObtainedNewToken, Invariant($"Token received for client {id} expires in {timeRemaining}, and so is not usable. Getting a fresh token..."));
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ public void GetTokenExpiryBufferSecondsTest()
string token = TokenHelper.CreateSasToken("azure.devices.net");
TimeSpan timeRemaining = CloudConnection.GetTokenExpiryTimeRemaining("foo.azuredevices.net", token);
Assert.True(timeRemaining > TimeSpan.Zero);
}
}

[Unit]
[Fact]
Expand Down Expand Up @@ -81,7 +81,14 @@ public async Task RefreshTokenTest()

IClientCredentials GetClientCredentialsWithExpiringToken()
{
string token = TokenHelper.CreateSasToken(iothubHostName, DateTime.UtcNow.AddSeconds(10));
string token = TokenHelper.CreateSasToken(iothubHostName, DateTime.UtcNow.AddMinutes(3));
var identity = new DeviceIdentity(iothubHostName, deviceId);
return new TokenCredentials(identity, token, string.Empty);
}

IClientCredentials GetClientCredentialsWithNonExpiringToken()
{
string token = TokenHelper.CreateSasToken(iothubHostName, DateTime.UtcNow.AddMinutes(10));
var identity = new DeviceIdentity(iothubHostName, deviceId);
return new TokenCredentials(identity, token, string.Empty);
}
Expand All @@ -106,28 +113,105 @@ IClientCredentials GetClientCredentialsWithExpiringToken()
var deviceAuthenticationWithTokenRefresh = authenticationMethod as DeviceAuthenticationWithTokenRefresh;
Assert.NotNull(deviceAuthenticationWithTokenRefresh);

// Wait for the token to expire
await Task.Delay(TimeSpan.FromSeconds(10));

Task<string> getTokenTask = deviceAuthenticationWithTokenRefresh.GetTokenAsync(iothubHostName);
Assert.False(getTokenTask.IsCompleted);

Assert.Equal(receivedStatus, CloudConnectionStatus.TokenNearExpiry);

IClientCredentials clientCredentialsWithExpiringToken2 = GetClientCredentialsWithExpiringToken();
IClientCredentials clientCredentialsWithExpiringToken2 = GetClientCredentialsWithNonExpiringToken();
ICloudProxy cloudProxy2 = await cloudConnection.CreateOrUpdateAsync(clientCredentialsWithExpiringToken2);

// Wait for the task to complete
await Task.Delay(TimeSpan.FromSeconds(10));

Assert.True(getTokenTask.IsCompletedSuccessfully);
Assert.Equal(cloudProxy2, cloudConnection.CloudProxy.OrDefault());
Assert.True(cloudProxy2.IsActive);
Assert.True(cloudProxy1.IsActive);
Assert.Equal(cloudProxy1, cloudProxy2);
Assert.True(getTokenTask.IsCompletedSuccessfully);
Assert.Equal(getTokenTask.Result, (clientCredentialsWithExpiringToken2 as ITokenCredentials)?.Token);
}

[Fact]
[Unit]
public async Task RefreshTokenWithRetryTest()
{
string iothubHostName = "test.azure-devices.net";
string deviceId = "device1";

IClientCredentials GetClientCredentialsWithExpiringToken()
{
string token = TokenHelper.CreateSasToken(iothubHostName, DateTime.UtcNow.AddMinutes(3));
var identity = new DeviceIdentity(iothubHostName, deviceId);
return new TokenCredentials(identity, token, string.Empty);
}

IClientCredentials GetClientCredentialsWithNonExpiringToken()
{
string token = TokenHelper.CreateSasToken(iothubHostName, DateTime.UtcNow.AddMinutes(10));
var identity = new DeviceIdentity(iothubHostName, deviceId);
return new TokenCredentials(identity, token, string.Empty);
}

IAuthenticationMethod authenticationMethod = null;
IClientProvider clientProvider = GetMockDeviceClientProviderWithToken((s, a, t) => authenticationMethod = a);

var transportSettings = new ITransportSettings[] { new AmqpTransportSettings(TransportType.Amqp_Tcp_Only) };

var receivedStatuses = new List<CloudConnectionStatus>();
void ConnectionStatusHandler(string id, CloudConnectionStatus status) => receivedStatuses.Add(status);
var messageConverterProvider = new MessageConverterProvider(new Dictionary<Type, IMessageConverter> { [typeof(TwinCollection)] = Mock.Of<IMessageConverter>() });

var cloudConnection = new CloudConnection(ConnectionStatusHandler, transportSettings, messageConverterProvider, clientProvider, Mock.Of<ICloudListener>(), TokenProvider, DeviceScopeIdentitiesCache, TimeSpan.FromMinutes(60));

IClientCredentials clientCredentialsWithExpiringToken1 = GetClientCredentialsWithExpiringToken();
ICloudProxy cloudProxy1 = await cloudConnection.CreateOrUpdateAsync(clientCredentialsWithExpiringToken1);
Assert.True(cloudProxy1.IsActive);
Assert.Equal(cloudProxy1, cloudConnection.CloudProxy.OrDefault());

Assert.NotNull(authenticationMethod);
var deviceAuthenticationWithTokenRefresh = authenticationMethod as DeviceAuthenticationWithTokenRefresh;
Assert.NotNull(deviceAuthenticationWithTokenRefresh);

// Try to refresh token but get an expiring token
Task<string> getTokenTask = deviceAuthenticationWithTokenRefresh.GetTokenAsync(iothubHostName);
Assert.False(getTokenTask.IsCompleted);

Assert.Equal(2, receivedStatuses.Count);
Assert.Equal(receivedStatuses[1], CloudConnectionStatus.TokenNearExpiry);

ICloudProxy cloudProxy2 = await cloudConnection.CreateOrUpdateAsync(clientCredentialsWithExpiringToken1);

// Wait for the task to process
await Task.Delay(TimeSpan.FromSeconds(5));

Assert.False(getTokenTask.IsCompletedSuccessfully);
Assert.Equal(cloudProxy2, cloudConnection.CloudProxy.OrDefault());
Assert.True(cloudProxy2.IsActive);
Assert.True(cloudProxy1.IsActive);
Assert.Equal(cloudProxy1, cloudProxy2);

// Wait for 20 secs for retry to happen
await Task.Delay(TimeSpan.FromSeconds(20));

// Check if retry happened
Assert.Equal(3, receivedStatuses.Count);
Assert.Equal(receivedStatuses[2], CloudConnectionStatus.TokenNearExpiry);

IClientCredentials clientCredentialsWithNonExpiringToken = GetClientCredentialsWithNonExpiringToken();
ICloudProxy cloudProxy3 = await cloudConnection.CreateOrUpdateAsync(clientCredentialsWithNonExpiringToken);

// Wait for the task to complete
await Task.Delay(TimeSpan.FromSeconds(5));

Assert.True(getTokenTask.IsCompletedSuccessfully);
Assert.Equal(cloudProxy3, cloudConnection.CloudProxy.OrDefault());
Assert.True(cloudProxy3.IsActive);
Assert.True(cloudProxy1.IsActive);
Assert.Equal(cloudProxy1, cloudProxy3);
Assert.Equal(getTokenTask.Result, (clientCredentialsWithNonExpiringToken as ITokenCredentials)?.Token);
}

[Fact]
[Unit]
public async Task CloudConnectionCallbackTest()
Expand Down Expand Up @@ -205,9 +289,9 @@ public async Task UpdateDeviceConnectionTest()
string hostname = "dummy.azure-devices.net";
string deviceId = "device1";

IClientCredentials GetClientCredentials()
IClientCredentials GetClientCredentials(TimeSpan tokenExpiryDuration)
{
string token = TokenHelper.CreateSasToken(hostname, DateTime.UtcNow.AddSeconds(10));
string token = TokenHelper.CreateSasToken(hostname, DateTime.UtcNow.AddSeconds(tokenExpiryDuration.TotalSeconds));
var identity = new DeviceIdentity(hostname, deviceId);
return new TokenCredentials(identity, token, string.Empty);
}
Expand Down Expand Up @@ -258,7 +342,7 @@ IClient GetMockedDeviceClient()
var credentialsCache = Mock.Of<ICredentialsCache>();
IConnectionManager connectionManager = new ConnectionManager(cloudConnectionProvider, credentialsCache, deviceId, "$edgeHub");

IClientCredentials clientCredentials1 = GetClientCredentials();
IClientCredentials clientCredentials1 = GetClientCredentials(TimeSpan.FromSeconds(10));
Try<ICloudProxy> cloudProxyTry1 = await connectionManager.CreateCloudConnectionAsync(clientCredentials1);
Assert.True(cloudProxyTry1.Success);

Expand All @@ -272,36 +356,26 @@ IClient GetMockedDeviceClient()
Task<string> tokenGetter = deviceTokenRefresher.GetTokenAsync(hostname);
Assert.False(tokenGetter.IsCompleted);

IClientCredentials clientCredentials2 = GetClientCredentials();
IClientCredentials clientCredentials2 = GetClientCredentials(TimeSpan.FromMinutes(2));
Try<ICloudProxy> cloudProxyTry2 = await connectionManager.CreateCloudConnectionAsync(clientCredentials2);
Assert.True(cloudProxyTry2.Success);

IDeviceProxy deviceProxy2 = GetMockDeviceProxy();
await connectionManager.AddDeviceConnection(clientCredentials2.Identity, deviceProxy2);

await Task.Delay(TimeSpan.FromSeconds(3));
Assert.True(tokenGetter.IsCompleted);
Assert.Equal(tokenGetter.Result, (clientCredentials2 as ITokenCredentials)?.Token);

await Task.Delay(TimeSpan.FromSeconds(10));
Assert.NotNull(authenticationMethod);
deviceTokenRefresher = authenticationMethod as DeviceAuthenticationWithTokenRefresh;
Assert.NotNull(deviceTokenRefresher);
tokenGetter = deviceTokenRefresher.GetTokenAsync(hostname);
Assert.False(tokenGetter.IsCompleted);

IClientCredentials clientCredentials3 = GetClientCredentials();
IClientCredentials clientCredentials3 = GetClientCredentials(TimeSpan.FromMinutes(10));
Try<ICloudProxy> cloudProxyTry3 = await connectionManager.CreateCloudConnectionAsync(clientCredentials3);
Assert.True(cloudProxyTry3.Success);

IDeviceProxy deviceProxy3 = GetMockDeviceProxy();
await connectionManager.AddDeviceConnection(clientCredentials3.Identity, deviceProxy3);

await Task.Delay(TimeSpan.FromSeconds(3));
await Task.Delay(TimeSpan.FromSeconds(23));
Assert.True(tokenGetter.IsCompleted);
Assert.Equal(tokenGetter.Result, (clientCredentials3 as ITokenCredentials)?.Token);

Mock.VerifyAll(Mock.Get(deviceProxy1), Mock.Get(deviceProxy2));
}

static async Task GetCloudConnectionTest(Func<IClientCredentials> credentialsGenerator, IClientProvider clientProvider)
Expand Down

0 comments on commit 9d2ba5e

Please sign in to comment.