Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Constrain NetworkTimeout on IMDS calls from DefaultAzureCredential #24328

Merged
merged 7 commits into from
Oct 1, 2021
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
<ItemGroup>
<PackageReference Include="Azure.Core" />
<PackageReference Include="Azure.Identity" />
<PackageReference Include="Microsoft.AspNetCore.Server.Kestrel" />
<PackageReference Include="Newtonsoft.Json" />
<PackageReference Include="NUnit" />
<PackageReference Include="NUnit3TestAdapter" />
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.DependencyInjection;

namespace Azure.Core.Tests
namespace Azure.Core.TestFramework
{
public class TestServer : IStartup, IDisposable
{
Expand All @@ -21,7 +21,7 @@ public class TestServer : IStartup, IDisposable

public Uri Address => new Uri(_host.ServerFeatures.Get<IServerAddressesFeature>().Addresses.First());

public TestServer(Action<HttpContext> app) : this(context => { app(context); return Task.CompletedTask;})
public TestServer(Action<HttpContext> app) : this(context => { app(context); return Task.CompletedTask; })
{
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using System.Net.Http;
using System.Threading.Tasks;
using Azure.Core.Pipeline;
using Azure.Core.TestFramework;
using NUnit.Framework;

namespace Azure.Core.Tests
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

using System;
using Azure.Core;

namespace Azure.Identity
Expand All @@ -25,7 +26,14 @@ public virtual TokenCredential CreateEnvironmentCredential()

public virtual TokenCredential CreateManagedIdentityCredential(string clientId)
{
return new ManagedIdentityCredential(clientId, Pipeline);
return new ManagedIdentityCredential(new ManagedIdentityClient(
new ManagedIdentityClientOptions
{
ClientId = clientId,
Pipeline = Pipeline,
InitialImdsConnectionTimeout = TimeSpan.FromSeconds(1)
})
);
}

public virtual TokenCredential CreateSharedTokenCacheCredential(string tenantId, string username)
Expand Down
19 changes: 17 additions & 2 deletions sdk/identity/Azure.Identity/src/ImdsManagedIdentitySource.cs
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,12 @@ internal class ImdsManagedIdentitySource : ManagedIdentitySource
private readonly string _clientId;
private readonly Uri _imdsEndpoint;

internal ImdsManagedIdentitySource(CredentialPipeline pipeline, string clientId) : base(pipeline)
private TimeSpan? _imdsNetworkTimeout;

internal ImdsManagedIdentitySource(ManagedIdentityClientOptions options) : base(options.Pipeline)
{
_clientId = clientId;
_clientId = options.ClientId;
_imdsNetworkTimeout = options.InitialImdsConnectionTimeout;

if (!string.IsNullOrEmpty(EnvironmentVariables.PodIdentityEndpoint))
{
Expand Down Expand Up @@ -66,6 +69,15 @@ protected override Request CreateRequest(string[] scopes)
return request;
}

protected override HttpMessage CreateHttpMessage(Request request)
{
HttpMessage message = base.CreateHttpMessage(request);

message.NetworkTimeout = _imdsNetworkTimeout;

return message;
}

public async override ValueTask<AccessToken> AuthenticateAsync(bool async, TokenRequestContext context, CancellationToken cancellationToken)
{
try
Expand All @@ -88,6 +100,9 @@ public async override ValueTask<AccessToken> AuthenticateAsync(bool async, Token

protected override async ValueTask<AccessToken> HandleResponseAsync(bool async, TokenRequestContext context, Response response, CancellationToken cancellationToken)
{
// if we got a response from IMDS we can stop limiting the network timeout
schaabs marked this conversation as resolved.
Show resolved Hide resolved
_imdsNetworkTimeout = null;

// handle error status codes indicating managed identity is not available
var baseMessage = response.Status switch
{
Expand Down
2 changes: 1 addition & 1 deletion sdk/identity/Azure.Identity/src/ManagedIdentityClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ private static ManagedIdentitySource SelectManagedIdentitySource(ManagedIdentity
AzureArcManagedIdentitySource.TryCreate(options) ??
ServiceFabricManagedIdentitySource.TryCreate(options) ??
TokenExchangeManagedIdentitySource.TryCreate(options) ??
new ImdsManagedIdentitySource(options.Pipeline, options.ClientId);
new ImdsManagedIdentitySource(options);
}
}
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

using System;

namespace Azure.Identity
{
internal class ManagedIdentityClientOptions
Expand All @@ -11,6 +13,8 @@ internal class ManagedIdentityClientOptions

public bool PreserveTransport { get; set; }

public TimeSpan? InitialImdsConnectionTimeout { get; set; }

public CredentialPipeline Pipeline { get; set; }
}
}
9 changes: 7 additions & 2 deletions sdk/identity/Azure.Identity/src/ManagedIdentitySource.cs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ protected ManagedIdentitySource(CredentialPipeline pipeline)
public virtual async ValueTask<AccessToken> AuthenticateAsync(bool async, TokenRequestContext context, CancellationToken cancellationToken)
{
using Request request = CreateRequest(context.Scopes);
using HttpMessage message = new HttpMessage(request, _responseClassifier);
using HttpMessage message = CreateHttpMessage(request);
if (async)
{
await Pipeline.HttpPipeline.SendAsync(message, cancellationToken).ConfigureAwait(false);
Expand Down Expand Up @@ -58,9 +58,14 @@ protected virtual async ValueTask<AccessToken> HandleResponseAsync(bool async, T

protected abstract Request CreateRequest(string[] scopes);

protected virtual HttpMessage CreateHttpMessage(Request request)
{
return new HttpMessage(request, _responseClassifier);
}

protected static async Task<string> GetMessageFromResponse(Response response, bool async, CancellationToken cancellationToken)
{
if (response?.ContentStream == null)
if (response?.ContentStream == null || !response.ContentStream.CanRead || response.ContentStream.Length == 0)
{
return null;
}
Expand Down
102 changes: 100 additions & 2 deletions sdk/identity/Azure.Identity/tests/ManagedIdentityCredentialTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
using Azure.Core;
using Azure.Core.TestFramework;
using Azure.Identity.Tests.Mock;
using Microsoft.AspNetCore.Http;
using NUnit.Framework;

namespace Azure.Identity.Tests
Expand Down Expand Up @@ -318,7 +319,7 @@ public async Task VerifyMsiUnavailableOnIMDSAggregateExcpetion()
using var environment = new TestEnvVar(new() { { "MSI_ENDPOINT", null }, { "MSI_SECRET", null }, { "IDENTITY_ENDPOINT", null }, { "IDENTITY_HEADER", null }, { "AZURE_POD_IDENTITY_AUTHORITY_HOST", "http://169.254.169.001/" } });

// setting the delay to 1ms and retry mode to fixed to speed up test
var options = new TokenCredentialOptions() { Retry = { Delay = TimeSpan.FromMilliseconds(1), Mode = RetryMode.Fixed } };
var options = new TokenCredentialOptions() { Retry = { Delay = TimeSpan.FromMilliseconds(1), Mode = RetryMode.Fixed, NetworkTimeout = TimeSpan.FromMilliseconds(100) } };
schaabs marked this conversation as resolved.
Show resolved Hide resolved

var credential = InstrumentClient(new ManagedIdentityCredential(options: options));

Expand All @@ -335,7 +336,7 @@ public async Task VerifyMsiUnavailableOnIMDSRequestFailedExcpetion()
{
using var environment = new TestEnvVar(new() { { "MSI_ENDPOINT", null }, { "MSI_SECRET", null }, { "IDENTITY_ENDPOINT", null }, { "IDENTITY_HEADER", null }, { "AZURE_POD_IDENTITY_AUTHORITY_HOST", "http://169.254.169.001/" } });

var options = new TokenCredentialOptions() { Retry = { MaxRetries = 0 } };
var options = new TokenCredentialOptions() { Retry = { MaxRetries = 0, NetworkTimeout = TimeSpan.FromMilliseconds(100) } };

var credential = InstrumentClient(new ManagedIdentityCredential(options: options));

Expand All @@ -346,6 +347,103 @@ public async Task VerifyMsiUnavailableOnIMDSRequestFailedExcpetion()
await Task.CompletedTask;
}

[NonParallelizable]
[Test]
public async Task VerifyMsiUnavailableOnIMDSGatewayErrorResponse([Values(502, 504)]int statusCode)
{
using var server = new TestServer(context =>
{
context.Response.StatusCode = statusCode;
});

using var environment = new TestEnvVar(new() { { "MSI_ENDPOINT", null }, { "MSI_SECRET", null }, { "IDENTITY_ENDPOINT", null }, { "IDENTITY_HEADER", null }, { "AZURE_POD_IDENTITY_AUTHORITY_HOST", server.Address.AbsoluteUri } });

// setting the delay to 1ms and retry mode to fixed to speed up test
var options = new TokenCredentialOptions() { };

var credential = InstrumentClient(new ManagedIdentityCredential(options: options));

var ex = Assert.ThrowsAsync<CredentialUnavailableException>(async () => await credential.GetTokenAsync(new TokenRequestContext(MockScopes.Default)));

Assert.That(ex.Message, Does.Contain(ImdsManagedIdentitySource.GatewayError));

await Task.CompletedTask;
}

[NonParallelizable]
[Test]
public async Task VerifyInitialImdsConnectionTimeoutHonored()
{
using var server = new TestServer(context =>
{
Task.Delay(8000).Wait();
schaabs marked this conversation as resolved.
Show resolved Hide resolved

context.Response.StatusCode = 418;
});

using var environment = new TestEnvVar(new() { { "MSI_ENDPOINT", null }, { "MSI_SECRET", null }, { "IDENTITY_ENDPOINT", null }, { "IDENTITY_HEADER", null }, { "AZURE_POD_IDENTITY_AUTHORITY_HOST", server.Address.AbsoluteUri } });

// setting the delay to 1ms and retry mode to fixed to speed up test
var options = new TokenCredentialOptions() { Retry = { Delay = TimeSpan.FromMilliseconds(0), Mode = RetryMode.Fixed } };

var pipeline = CredentialPipeline.GetInstance(options);

var miClientOptions = new ManagedIdentityClientOptions { InitialImdsConnectionTimeout = TimeSpan.FromMilliseconds(100), Pipeline = pipeline };

var credential = InstrumentClient(new ManagedIdentityCredential(new ManagedIdentityClient(miClientOptions)));

var startTime = DateTimeOffset.UtcNow;

var ex = Assert.ThrowsAsync<CredentialUnavailableException>(async () => await credential.GetTokenAsync(new TokenRequestContext(MockScopes.Default)));

Assert.That(ex.Message, Does.Contain(ImdsManagedIdentitySource.AggregateError));

Assert.Less(DateTimeOffset.UtcNow - startTime, TimeSpan.FromSeconds(2));

await Task.CompletedTask;
}

[NonParallelizable]
[Test]
public async Task VerifyInitialImdsConnectionTimeoutRelaxed()
{
string token = Guid.NewGuid().ToString();
bool delay = false;

using var server = new TestServer(async context =>
schaabs marked this conversation as resolved.
Show resolved Hide resolved
{
if (delay)
{
Task.Delay(2000).Wait();
schaabs marked this conversation as resolved.
Show resolved Hide resolved
}
else
{
delay = true;
}

await context.Response.WriteAsync($"{{ \"access_token\": \"{token}\", \"expires_on\": \"3600\" }}");
});

using var environment = new TestEnvVar(new() { { "MSI_ENDPOINT", null }, { "MSI_SECRET", null }, { "IDENTITY_ENDPOINT", null }, { "IDENTITY_HEADER", null }, { "AZURE_POD_IDENTITY_AUTHORITY_HOST", server.Address.AbsoluteUri } });

// setting the delay to 1ms and retry mode to fixed to speed up test
var options = new TokenCredentialOptions() { Retry = { Delay = TimeSpan.FromMilliseconds(0), Mode = RetryMode.Fixed } };

var pipeline = CredentialPipeline.GetInstance(options);

var miClientOptions = new ManagedIdentityClientOptions { InitialImdsConnectionTimeout = TimeSpan.FromMilliseconds(1000), Pipeline = pipeline };

var credential = InstrumentClient(new ManagedIdentityCredential(new ManagedIdentityClient(miClientOptions)));

var at = await credential.GetTokenAsync(new TokenRequestContext(MockScopes.Default));

Assert.AreEqual(token, at.Token);

var at2 = await credential.GetTokenAsync(new TokenRequestContext(MockScopes.Default));

Assert.AreEqual(token, at.Token);
}

[Test]
public async Task VerifyClientAuthenticateThrows()
{
Expand Down